diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py
index 797d94018..ae975b026 100644
--- a/tests/qlora/test_hf_qlora_train_and_merge.py
+++ b/tests/qlora/test_hf_qlora_train_and_merge.py
@@ -54,37 +54,37 @@ if __name__ == "__main__":
seed = 42
batch_size = 5
num_generations = 5
- tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
+ tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
temperature = 0.8
max_new_tokens = 20
- peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
- model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
+ peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
+ model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
prompt = tokenizer.apply_chat_template(
- [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+ [USER_MESSAGE], tokenize = False, add_generation_prompt = True
)
with header_footer_context("Test Prompt and Answer"):
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
dataset: Dataset = create_dataset(
- tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+ tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
)
with header_footer_context("Dataset"):
print(f"Dataset: {next(iter(dataset))}")
training_args = SFTConfig(
- output_dir=output_dir,
- max_steps=max_steps,
- per_device_train_batch_size=batch_size,
- log_level="info",
- report_to="none",
- num_train_epochs=1,
- logging_steps=1,
- seed=seed,
- bf16=dtype == torch.bfloat16,
- fp16=dtype == torch.float16,
- save_strategy="no",
+ output_dir = output_dir,
+ max_steps = max_steps,
+ per_device_train_batch_size = batch_size,
+ log_level = "info",
+ report_to = "none",
+ num_train_epochs = 1,
+ logging_steps = 1,
+ seed = seed,
+ bf16 = dtype == torch.bfloat16,
+ fp16 = dtype == torch.float16,
+ save_strategy = "no",
)
with header_footer_context("Train Args"):
@@ -92,7 +92,7 @@ if __name__ == "__main__":
print(peft_config)
trainer = setup_trainer(
- model, tokenizer, dataset, training_args, peft_config=peft_config
+ model, tokenizer, dataset, training_args, peft_config = peft_config
)
with header_footer_context("Model"):
@@ -108,11 +108,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses before training"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
with header_footer_context("Peft Weights before training"):
for name, stats in itertools.islice(describe_peft_weights(model), 2):
@@ -129,11 +129,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after training"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
model_copy = deepcopy(model)
@@ -142,18 +142,18 @@ if __name__ == "__main__":
responses = sample_responses(
merged_model,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after custom merging to 16bit"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
merged_model_peft = model_copy.merge_and_unload()
responses = sample_responses(
merged_model_peft,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after peft merge_and_unload"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py
index 59fa813fa..9040ad793 100644
--- a/tests/qlora/test_unsloth_qlora_train_and_merge.py
+++ b/tests/qlora/test_unsloth_qlora_train_and_merge.py
@@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer(
dtype: torch.dtype = torch.bfloat16,
):
return FastLanguageModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- load_in_4bit=load_in_4bit,
- fast_inference=fast_inference,
- max_lora_rank=max_lora_rank,
- gpu_memory_utilization=gpu_memory_utilization,
- dtype=dtype,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ load_in_4bit = load_in_4bit,
+ fast_inference = fast_inference,
+ max_lora_rank = max_lora_rank,
+ gpu_memory_utilization = gpu_memory_utilization,
+ dtype = dtype,
)
@@ -69,11 +69,11 @@ def get_unsloth_peft_model(
):
return FastLanguageModel.get_peft_model(
model,
- r=lora_rank,
- target_modules=target_modules,
- lora_alpha=lora_rank,
- use_gradient_checkpointing=use_gradient_checkpointing,
- random_state=random_state,
+ r = lora_rank,
+ target_modules = target_modules,
+ lora_alpha = lora_rank,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ random_state = random_state,
)
@@ -101,48 +101,48 @@ if __name__ == "__main__":
model, tokenizer = get_unsloth_model_and_tokenizer(
model_name,
- max_seq_length=512,
- load_in_4bit=True,
- fast_inference=False,
- max_lora_rank=lora_rank,
- dtype=dtype,
+ max_seq_length = 512,
+ load_in_4bit = True,
+ fast_inference = False,
+ max_lora_rank = lora_rank,
+ dtype = dtype,
)
temperature = 0.8
max_new_tokens = 20
model = get_unsloth_peft_model(
model,
- lora_rank=lora_rank,
- target_modules=target_modules,
- use_gradient_checkpointing=gradient_checkpointing,
- random_state=seed,
+ lora_rank = lora_rank,
+ target_modules = target_modules,
+ use_gradient_checkpointing = gradient_checkpointing,
+ random_state = seed,
)
prompt = tokenizer.apply_chat_template(
- [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+ [USER_MESSAGE], tokenize = False, add_generation_prompt = True
)
with header_footer_context("Test Prompt and Answer"):
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
dataset: Dataset = create_dataset(
- tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+ tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
)
with header_footer_context("Dataset"):
print(f"Dataset: {next(iter(dataset))}")
training_args = SFTConfig(
- output_dir=output_dir,
- max_steps=max_steps,
- per_device_train_batch_size=batch_size,
- log_level="info",
- report_to="none",
- num_train_epochs=1,
- logging_steps=1,
- seed=seed,
- bf16=dtype == torch.bfloat16,
- fp16=dtype == torch.float16,
- save_strategy="no",
+ output_dir = output_dir,
+ max_steps = max_steps,
+ per_device_train_batch_size = batch_size,
+ log_level = "info",
+ report_to = "none",
+ num_train_epochs = 1,
+ logging_steps = 1,
+ seed = seed,
+ bf16 = dtype == torch.bfloat16,
+ fp16 = dtype == torch.float16,
+ save_strategy = "no",
)
with header_footer_context("Train Args"):
@@ -163,11 +163,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses before training"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
with header_footer_context("Peft Weights before training"):
for name, stats in itertools.islice(describe_peft_weights(model), 2):
print(f"{name}:\n{stats}")
@@ -183,29 +183,29 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after training"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
model.save_pretrained_merged(
unsloth_merged_path,
tokenizer,
- save_method="merged_16bit",
+ save_method = "merged_16bit",
)
merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
unsloth_merged_path,
- max_seq_length=512,
- load_in_4bit=False,
- fast_inference=False,
- dtype=dtype,
+ max_seq_length = 512,
+ load_in_4bit = False,
+ fast_inference = False,
+ dtype = dtype,
)
responses = sample_responses(
merged_model_unsloth,
tokenizer,
- prompt=prompt,
+ prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after unsloth merge to 16bit"):
- check_responses(responses, answer=ANSWER, prompt=prompt)
+ check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
index 622e72d25..d63bb9fe0 100644
--- a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
+++ b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
@@ -78,17 +78,17 @@ def formatting_prompts_func(examples):
}
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from tests.utils.perplexity_eval import ppl_model
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name="./unsloth_out/merged_qwen_text_model",
- max_seq_length=2048,
- load_in_4bit=load_in_4bit,
- load_in_8bit=load_in_8bit,
+ model_name = "./unsloth_out/merged_qwen_text_model",
+ max_seq_length = 2048,
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
)
# Set up tokenizer
# merged_tokenizer = get_chat_template(
@@ -98,7 +98,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split="eval"
+ "allenai/openassistant-guanaco-reformatted", split = "eval"
)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -146,7 +146,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
"text": texts,
}
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+ dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@@ -172,7 +172,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
- mp.set_start_method("spawn", force=True)
+ mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@@ -182,31 +182,31 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="unsloth/Qwen2.5-7B-Instruct",
- max_seq_length=2048,
- dtype=compute_dtype,
- load_in_4bit=True,
- load_in_8bit=False,
- full_finetuning=False,
- attn_implementation=attn_implementation,
+ model_name = "unsloth/Qwen2.5-7B-Instruct",
+ max_seq_length = 2048,
+ dtype = compute_dtype,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ attn_implementation = attn_implementation,
)
dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split="train"
+ "allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split="eval"
+ "allenai/openassistant-guanaco-reformatted", split = "eval"
)
- dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+ dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+ dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
- r=16,
- target_modules=[
+ r = 16,
+ target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -215,40 +215,40 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
- lora_alpha=16,
- lora_dropout=0,
- bias="none",
- use_gradient_checkpointing="unsloth",
- random_state=3407,
- use_rslora=False,
- loftq_config=None,
+ lora_alpha = 16,
+ lora_dropout = 0,
+ bias = "none",
+ use_gradient_checkpointing = "unsloth",
+ random_state = 3407,
+ use_rslora = False,
+ loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- train_dataset=dataset_train,
- dataset_text_field="text",
- max_seq_length=2048,
- data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
- dataset_num_proc=2,
- packing=False,
- args=TrainingArguments(
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- warmup_ratio=0.1,
- max_steps=200,
- learning_rate=2e-4,
- fp16=not is_bfloat16_supported(),
- bf16=is_bfloat16_supported(),
- logging_steps=50,
- optim="adamw_8bit",
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="outputs",
- report_to="none",
+ model = model,
+ tokenizer = tokenizer,
+ train_dataset = dataset_train,
+ dataset_text_field = "text",
+ max_seq_length = 2048,
+ data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+ dataset_num_proc = 2,
+ packing = False,
+ args = TrainingArguments(
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ warmup_ratio = 0.1,
+ max_steps = 200,
+ learning_rate = 2e-4,
+ fp16 = not is_bfloat16_supported(),
+ bf16 = is_bfloat16_supported(),
+ logging_steps = 50,
+ optim = "adamw_8bit",
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "outputs",
+ report_to = "none",
),
)
@@ -260,7 +260,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
- save_directory="./unsloth_out/merged_qwen_text_model", tokenizer=tokenizer
+ save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer
)
# print("cleaning")
@@ -272,10 +272,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name="./unsloth_out/merged_qwen_text_model",
- max_seq_length=2048,
- load_in_4bit=True,
- load_in_8bit=False,
+ model_name = "./unsloth_out/merged_qwen_text_model",
+ max_seq_length = 2048,
+ load_in_4bit = True,
+ load_in_8bit = False,
)
add_to_comparison(
@@ -284,7 +284,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
- p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+ p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@@ -293,10 +293,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name="./unsloth_out/merged_qwen_text_model",
- max_seq_length=2048,
- load_in_4bit=False,
- load_in_8bit=False,
+ model_name = "./unsloth_out/merged_qwen_text_model",
+ max_seq_length = 2048,
+ load_in_4bit = False,
+ load_in_8bit = False,
)
add_to_comparison(
diff --git a/tests/saving/language_models/test_push_to_hub_merged.py b/tests/saving/language_models/test_push_to_hub_merged.py
index b92842b90..58d589305 100644
--- a/tests/saving/language_models/test_push_to_hub_merged.py
+++ b/tests/saving/language_models/test_push_to_hub_merged.py
@@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
- convo, tokenize=False, add_generation_prompt=False
+ convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -52,34 +52,34 @@ else:
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="unsloth/Llama-3.2-1B-Instruct",
- max_seq_length=2048,
- dtype=compute_dtype,
- load_in_4bit=True,
- load_in_8bit=False,
- full_finetuning=False,
- attn_implementation=attn_implementation,
+ model_name = "unsloth/Llama-3.2-1B-Instruct",
+ max_seq_length = 2048,
+ dtype = compute_dtype,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
- chat_template="llama-3.1",
+ chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
+dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
-dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
- r=16,
- target_modules=[
+ r = 16,
+ target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
- lora_alpha=16,
- lora_dropout=0,
- bias="none",
- use_gradient_checkpointing="unsloth",
- random_state=3407,
- use_rslora=False,
- loftq_config=None,
+ lora_alpha = 16,
+ lora_dropout = 0,
+ bias = "none",
+ use_gradient_checkpointing = "unsloth",
+ random_state = 3407,
+ use_rslora = False,
+ loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- train_dataset=dataset_train,
- dataset_text_field="text",
- max_seq_length=2048,
- data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
- dataset_num_proc=2,
- packing=False,
- args=TrainingArguments(
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- warmup_ratio=0.1,
- max_steps=30,
- learning_rate=2e-4,
- fp16=not is_bfloat16_supported(),
- bf16=is_bfloat16_supported(),
- logging_steps=50,
- optim="adamw_8bit",
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="outputs",
- report_to="none",
+ model = model,
+ tokenizer = tokenizer,
+ train_dataset = dataset_train,
+ dataset_text_field = "text",
+ max_seq_length = 2048,
+ data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+ dataset_num_proc = 2,
+ packing = False,
+ args = TrainingArguments(
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ warmup_ratio = 0.1,
+ max_steps = 30,
+ learning_rate = 2e-4,
+ fp16 = not is_bfloat16_supported(),
+ bf16 = is_bfloat16_supported(),
+ logging_steps = 50,
+ optim = "adamw_8bit",
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "outputs",
+ report_to = "none",
),
)
@@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
- instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
- response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+ response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# run training
@@ -160,7 +160,7 @@ try:
print("\n" + "=" * 80)
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
- model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+ model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
diff --git a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
index 1f07ce643..038565d17 100644
--- a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
+++ b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
@@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
- convo, tokenize=False, add_generation_prompt=False
+ convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -52,34 +52,34 @@ else:
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="unsloth/Llama-3.1-8B-Instruct",
- max_seq_length=2048,
- dtype=compute_dtype,
- load_in_4bit=True,
- load_in_8bit=False,
- full_finetuning=False,
- attn_implementation=attn_implementation,
+ model_name = "unsloth/Llama-3.1-8B-Instruct",
+ max_seq_length = 2048,
+ dtype = compute_dtype,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
- chat_template="llama-3.1",
+ chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
+dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
-dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
- r=16,
- target_modules=[
+ r = 16,
+ target_modules = [
"k_proj",
"q_proj",
"v_proj",
@@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
- lora_alpha=16,
- lora_dropout=0,
- bias="none",
- use_gradient_checkpointing="unsloth",
- random_state=3407,
- use_rslora=False,
- loftq_config=None,
+ lora_alpha = 16,
+ lora_dropout = 0,
+ bias = "none",
+ use_gradient_checkpointing = "unsloth",
+ random_state = 3407,
+ use_rslora = False,
+ loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- train_dataset=dataset_train,
- dataset_text_field="text",
- max_seq_length=2048,
- data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
- dataset_num_proc=2,
- packing=False,
- args=TrainingArguments(
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- warmup_ratio=0.1,
- max_steps=30,
- learning_rate=2e-4,
- fp16=not is_bfloat16_supported(),
- bf16=is_bfloat16_supported(),
- logging_steps=50,
- optim="adamw_8bit",
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="outputs",
- report_to="none",
+ model = model,
+ tokenizer = tokenizer,
+ train_dataset = dataset_train,
+ dataset_text_field = "text",
+ max_seq_length = 2048,
+ data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+ dataset_num_proc = 2,
+ packing = False,
+ args = TrainingArguments(
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ warmup_ratio = 0.1,
+ max_steps = 30,
+ learning_rate = 2e-4,
+ fp16 = not is_bfloat16_supported(),
+ bf16 = is_bfloat16_supported(),
+ logging_steps = 50,
+ optim = "adamw_8bit",
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "outputs",
+ report_to = "none",
),
)
@@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
- instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
- response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+ response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# run training
@@ -161,7 +161,7 @@ try:
print("\n" + "=" * 80)
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
- model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+ model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
@@ -173,8 +173,8 @@ try:
print("\n" + "=" * 80)
print("=== VERIFYING REPO CONTENTS ===".center(80))
print("=" * 80 + "\n")
- fs = HfFileSystem(token=hf_token)
- file_list = fs.ls(repo_name, detail=True)
+ fs = HfFileSystem(token = hf_token)
+ file_list = fs.ls(repo_name, detail = True)
safetensors_found = any(
file["name"].endswith("model.safetensors.index.json") for file in file_list
)
diff --git a/tests/saving/language_models/test_save_merged_grpo_model.py b/tests/saving/language_models/test_save_merged_grpo_model.py
index e93b8add9..67b649305 100644
--- a/tests/saving/language_models/test_save_merged_grpo_model.py
+++ b/tests/saving/language_models/test_save_merged_grpo_model.py
@@ -24,7 +24,7 @@ max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
-def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
+def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):
from unsloth import FastLanguageModel
from tests.utils.aime_eval import evaluate_model_aime
@@ -32,12 +32,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="./final_merged_model",
- max_seq_length=max_seq_length,
- load_in_4bit=True, # False for LoRA 16bit
- fast_inference=True, # Enable vLLM fast inference
- max_lora_rank=lora_rank,
- gpu_memory_utilization=0.8, # Reduce if out of memory
+ model_name = "./final_merged_model",
+ max_seq_length = max_seq_length,
+ load_in_4bit = True, # False for LoRA 16bit
+ fast_inference = True, # Enable vLLM fast inference
+ max_lora_rank = lora_rank,
+ gpu_memory_utilization = 0.8, # Reduce if out of memory
)
print(f"\n{'='*60}")
@@ -53,14 +53,14 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
print(f"{'='*60}")
evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type=model_type,
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768,
- top_p=0.95,
- seed=0,
+ model = model,
+ tokenizer = tokenizer,
+ model_type = model_type,
+ temperature = 0.3,
+ n_sampling = 8,
+ max_tokens = 32768,
+ top_p = 0.95,
+ seed = 0,
)
result_queue.put(results)
@@ -74,12 +74,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
# Main execution code should be wrapped in this guard
def training_run(result_queue):
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="meta-llama/Llama-3.2-3B-Instruct",
- max_seq_length=max_seq_length,
- load_in_4bit=False, # False for LoRA 16bit
- fast_inference=True, # Enable vLLM fast inference
- max_lora_rank=lora_rank,
- gpu_memory_utilization=0.8, # Reduce if out of memory
+ model_name = "meta-llama/Llama-3.2-3B-Instruct",
+ max_seq_length = max_seq_length,
+ load_in_4bit = False, # False for LoRA 16bit
+ fast_inference = True, # Enable vLLM fast inference
+ max_lora_rank = lora_rank,
+ gpu_memory_utilization = 0.8, # Reduce if out of memory
)
"""### Helper Functions
@@ -166,10 +166,10 @@ def training_run(result_queue):
lengths = dataset.map(
lambda x: {
"tokens": tokenizer.apply_chat_template(
- x["prompt"], add_generation_prompt=True, tokenize=True
+ x["prompt"], add_generation_prompt = True, tokenize = True
)
},
- batched=True,
+ batched = True,
).map(lambda x: {"length": len(x["tokens"])})["length"]
max_length = max(lengths)
@@ -181,7 +181,7 @@ def training_run(result_queue):
)
return max_length, avg_length
- def extract_unsloth_answer(text, start_tag="", end_tag=""):
+ def extract_unsloth_answer(text, start_tag = "", end_tag = ""):
"""Extract answer from Unsloth SOLUTION tags"""
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
matches = re.findall(pattern, text, re.DOTALL)
@@ -213,10 +213,10 @@ def training_run(result_queue):
"""Count tokens in text"""
if not text:
return 0
- encoding = tokenizer_instance(text, return_tensors="pt")
+ encoding = tokenizer_instance(text, return_tensors = "pt")
return len(encoding["input_ids"][0])
- def check_format_compliance(text, format_type="unsloth"):
+ def check_format_compliance(text, format_type = "unsloth"):
"""Check if response follows expected format"""
if format_type == "unsloth":
reasoning_start = ""
@@ -419,11 +419,11 @@ def training_run(result_queue):
# Save comparison
comparison_data = {
"summary": all_results,
- "best_model": max(all_results, key=lambda x: x["exact_match_pct"]),
+ "best_model": max(all_results, key = lambda x: x["exact_match_pct"]),
}
with open("model_comparison_comprehensive.json", "w") as f:
- json.dump(comparison_data, f, indent=4)
+ json.dump(comparison_data, f, indent = 4)
print(
f"\nBest performing model: {comparison_data['best_model']['model_type']} "
@@ -449,10 +449,10 @@ def training_run(result_queue):
from datasets import load_dataset
# Load GSM8K
- gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train")
+ gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train")
# Load LIMO (adjust this based on your access method)
- limo_train = load_dataset("GAIR/LIMO", split="train")
+ limo_train = load_dataset("GAIR/LIMO", split = "train")
# Prepare datasets
gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)
@@ -466,28 +466,28 @@ def training_run(result_queue):
# Single temperature evaluation on combined dataset
results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type="base",
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768,
- top_p=0.95,
- seed=0,
+ model = model,
+ tokenizer = tokenizer,
+ model_type = "base",
+ temperature = 0.3,
+ n_sampling = 8,
+ max_tokens = 32768,
+ top_p = 0.95,
+ seed = 0,
)
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
tokenizer,
- chat_template="llama-3.1",
+ chat_template = "llama-3.1",
)
def formatting_prompts_func(examples):
convos = examples["prompt"]
texts = [
tokenizer.apply_chat_template(
- convo, tokenize=False, add_generation_prompt=False
+ convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@@ -497,7 +497,7 @@ def training_run(result_queue):
limo_train = limo_train.map(
formatting_prompts_func,
- batched=True,
+ batched = True,
)
from trl import SFTTrainer
@@ -510,8 +510,8 @@ def training_run(result_queue):
model = FastLanguageModel.get_peft_model(
model,
- r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules=[
+ r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ target_modules = [
"q_proj",
"k_proj",
"v_proj",
@@ -520,37 +520,37 @@ def training_run(result_queue):
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
- lora_alpha=lora_rank,
- use_gradient_checkpointing="unsloth", # Enable long context finetuning
- random_state=3407,
+ lora_alpha = lora_rank,
+ use_gradient_checkpointing = "unsloth", # Enable long context finetuning
+ random_state = 3407,
)
if limo_train is not None:
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- train_dataset=limo_train,
- dataset_text_field="text",
- max_seq_length=max_seq_length,
- data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
- dataset_num_proc=2,
- packing=False, # Can make training 5x faster for short sequences.
- args=TrainingArguments(
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- warmup_steps=5,
- num_train_epochs=1, # Set this for 1 full training run.
+ model = model,
+ tokenizer = tokenizer,
+ train_dataset = limo_train,
+ dataset_text_field = "text",
+ max_seq_length = max_seq_length,
+ data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+ dataset_num_proc = 2,
+ packing = False, # Can make training 5x faster for short sequences.
+ args = TrainingArguments(
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ warmup_steps = 5,
+ num_train_epochs = 1, # Set this for 1 full training run.
# max_steps = 60,
- learning_rate=2e-4,
- fp16=not is_bfloat16_supported(),
- bf16=is_bfloat16_supported(),
- logging_steps=1,
- optim="adamw_8bit",
- weight_decay=0.01,
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="outputs",
- report_to="none", # Use this for WandB etc
+ learning_rate = 2e-4,
+ fp16 = not is_bfloat16_supported(),
+ bf16 = is_bfloat16_supported(),
+ logging_steps = 1,
+ optim = "adamw_8bit",
+ weight_decay = 0.01,
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "outputs",
+ report_to = "none", # Use this for WandB etc
),
)
@@ -558,8 +558,8 @@ def training_run(result_queue):
trainer = train_on_responses_only(
trainer,
- instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
- response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+ response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# Train
@@ -588,7 +588,7 @@ def training_run(result_queue):
PRINT_EVERY_STEPS = 5
match_numbers = re.compile(
- solution_start + r".*?([\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL
+ solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL
)
def check_numbers(prompts, completions, answer, **kwargs):
@@ -642,37 +642,37 @@ def training_run(result_queue):
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
- learning_rate=5e-6,
- weight_decay=0.1,
- warmup_ratio=0.1,
- lr_scheduler_type="cosine",
- optim="adamw_torch_fused",
- logging_steps=1,
- per_device_train_batch_size=1,
- gradient_accumulation_steps=4, # Increase to 4 for smoother training
- num_generations=8, # Decrease if out of memory
- max_prompt_length=max_prompt_length,
- max_completion_length=max_seq_length - max_prompt_length,
+ learning_rate = 5e-6,
+ weight_decay = 0.1,
+ warmup_ratio = 0.1,
+ lr_scheduler_type = "cosine",
+ optim = "adamw_torch_fused",
+ logging_steps = 1,
+ per_device_train_batch_size = 1,
+ gradient_accumulation_steps = 4, # Increase to 4 for smoother training
+ num_generations = 8, # Decrease if out of memory
+ max_prompt_length = max_prompt_length,
+ max_completion_length = max_seq_length - max_prompt_length,
# num_train_epochs = 1, # Set to 1 for a full training run
# max_steps = 250,
- max_steps=1000,
- save_steps=250,
- max_grad_norm=0.1,
- report_to="none", # Can use Weights & Biases
- output_dir="outputs",
+ max_steps = 1000,
+ save_steps = 250,
+ max_grad_norm = 0.1,
+ report_to = "none", # Can use Weights & Biases
+ output_dir = "outputs",
)
trainer = GRPOTrainer(
- model=model,
- processing_class=tokenizer,
- reward_funcs=[
+ model = model,
+ processing_class = tokenizer,
+ reward_funcs = [
match_format_exactly,
match_format_approximately,
check_answer_correctness,
check_numbers,
],
- args=training_args,
- train_dataset=gsm8k_train,
+ args = training_args,
+ train_dataset = gsm8k_train,
)
# Train
@@ -696,14 +696,14 @@ def training_run(result_queue):
print(f"{'='*60}")
grpo_results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type="grpo",
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768,
- top_p=0.95,
- seed=0,
+ model = model,
+ tokenizer = tokenizer,
+ model_type = "grpo",
+ temperature = 0.3,
+ n_sampling = 8,
+ max_tokens = 32768,
+ top_p = 0.95,
+ seed = 0,
)
all_results.append(grpo_results)
@@ -716,7 +716,7 @@ def training_run(result_queue):
# Save as merged model
try:
model.save_pretrained_merged(
- "final_merged_model", tokenizer, save_method="merged_16bit"
+ "final_merged_model", tokenizer, save_method = "merged_16bit"
)
print("✅ Merged model saved to: final_merged_model/")
except Exception as e:
@@ -774,12 +774,12 @@ def training_run(result_queue):
if __name__ == "__main__":
- mp.set_start_method("spawn", force=True)
+ mp.set_start_method("spawn", force = True)
result_queue = mp.Queue()
all_results = []
# run main finetuning and grpo loop
- p = mp.Process(target=training_run, args=(result_queue,))
+ p = mp.Process(target = training_run, args = (result_queue,))
p.start()
p.join()
@@ -787,7 +787,7 @@ if __name__ == "__main__":
all_results = results
# evaluate merged model loaded 16bits
- p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
+ p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))
p.start()
p.join()
@@ -796,7 +796,7 @@ if __name__ == "__main__":
safe_remove_directory("./unsloth_compiled_cache")
# Merged model load 8 bits model AIME eval
- p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))
+ p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))
p.start()
p.join()
@@ -806,7 +806,7 @@ if __name__ == "__main__":
safe_remove_directory("./unsloth_compiled_cache")
# Merged model load 4 bits model AIME eval
- p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
+ p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))
p.start()
p.join()
diff --git a/tests/saving/test_unsloth_save.py b/tests/saving/test_unsloth_save.py
index 5ec2a943e..35fdad6ba 100644
--- a/tests/saving/test_unsloth_save.py
+++ b/tests/saving/test_unsloth_save.py
@@ -43,31 +43,31 @@ tokenizer_files = [
]
-@pytest.fixture(scope="session", params=model_to_test)
+@pytest.fixture(scope = "session", params = model_to_test)
def loaded_model_tokenizer(request):
model_name = request.param
print("Loading model and tokenizer...")
model, tokenizer = FastModel.from_pretrained(
model_name, # use small model
- max_seq_length=128,
- dtype=None,
- load_in_4bit=True,
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = True,
)
# Apply LoRA
model = FastModel.get_peft_model(
model,
- r=16,
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
- lora_alpha=16,
- use_gradient_checkpointing="unsloth",
+ r = 16,
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
+ lora_alpha = 16,
+ use_gradient_checkpointing = "unsloth",
)
return model, tokenizer
-@pytest.fixture(scope="session", params=torchao_models)
+@pytest.fixture(scope = "session", params = torchao_models)
def fp16_model_tokenizer(request):
"""Load model in FP16 for TorchAO quantization"""
model_name = request.param
@@ -75,29 +75,29 @@ def fp16_model_tokenizer(request):
model, tokenizer = FastModel.from_pretrained(
model_name,
- max_seq_length=128,
- dtype=None,
- load_in_4bit=False, # No BnB quantization
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = False, # No BnB quantization
)
# Apply LoRA
model = FastModel.get_peft_model(
model,
- r=16,
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
- lora_alpha=16,
- use_gradient_checkpointing="unsloth",
+ r = 16,
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
+ lora_alpha = 16,
+ use_gradient_checkpointing = "unsloth",
)
return model, tokenizer
-@pytest.fixture(scope="session")
+@pytest.fixture(scope = "session")
def model(loaded_model_tokenizer):
return loaded_model_tokenizer[0]
-@pytest.fixture(scope="session")
+@pytest.fixture(scope = "session")
def tokenizer(loaded_model_tokenizer):
return loaded_model_tokenizer[1]
@@ -133,7 +133,7 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
)
model.save_pretrained_merged(
- save_path, tokenizer=tokenizer, save_method="merged_16bit"
+ save_path, tokenizer = tokenizer, save_method = "merged_16bit"
)
# Check model files
@@ -172,9 +172,9 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
# Test loading the model from the saved path
loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
save_path,
- max_seq_length=128,
- dtype=None,
- load_in_4bit=True,
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = True,
)
@@ -186,7 +186,7 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
)
model.save_pretrained_merged(
- save_path, tokenizer=tokenizer, save_method="merged_4bit_forced"
+ save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced"
)
# Check model files
@@ -230,15 +230,15 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
# Test loading the model from the saved path
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
save_path,
- max_seq_length=128,
- dtype=None,
- load_in_4bit=True,
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = True,
)
@pytest.mark.skipif(
importlib.util.find_spec("torchao") is None,
- reason="require torchao to be installed",
+ reason = "require torchao to be installed",
)
def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
model, tokenizer = fp16_model_tokenizer
@@ -251,9 +251,9 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
torchao_config = Int8DynamicActivationInt8WeightConfig()
model.save_pretrained_torchao(
save_path,
- tokenizer=tokenizer,
- torchao_config=torchao_config,
- push_to_hub=False,
+ tokenizer = tokenizer,
+ torchao_config = torchao_config,
+ push_to_hub = False,
)
weight_files_16bit = [
@@ -316,15 +316,15 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
with torch.serialization.safe_globals([getattr]):
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
torchao_save_path,
- max_seq_length=128,
- dtype=None,
- load_in_4bit=False,
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = False,
)
@pytest.mark.skipif(
importlib.util.find_spec("torchao") is None,
- reason="require torchao to be installed",
+ reason = "require torchao to be installed",
)
def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
model, tokenizer = fp16_model_tokenizer
@@ -343,9 +343,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
# Save with TorchAO
model.save_pretrained_torchao(
save_path,
- tokenizer=tokenizer,
- torchao_config=torchao_config,
- push_to_hub=False,
+ tokenizer = tokenizer,
+ torchao_config = torchao_config,
+ push_to_hub = False,
)
torchao_save_path = save_path + "-torchao"
@@ -361,9 +361,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
with torch.serialization.safe_globals([getattr]):
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
torchao_save_path,
- max_seq_length=128,
- dtype=None,
- load_in_4bit=False,
+ max_seq_length = 128,
+ dtype = None,
+ load_in_4bit = False,
)
FastModel.for_inference(loaded_model) # Enable native 2x faster inference
@@ -376,24 +376,24 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
]
inputs = loaded_tokenizer.apply_chat_template(
messages,
- tokenize=True,
- add_generation_prompt=True, # Must add for generation
- return_tensors="pt",
+ tokenize = True,
+ add_generation_prompt = True, # Must add for generation
+ return_tensors = "pt",
).to("cuda")
outputs = loaded_model.generate( # ← Use loaded_model, not model
- input_ids=inputs,
- max_new_tokens=64,
- use_cache=False, # Avoid cache issues
- temperature=1.5,
- min_p=0.1,
- do_sample=True,
- pad_token_id=loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
+ input_ids = inputs,
+ max_new_tokens = 64,
+ use_cache = False, # Avoid cache issues
+ temperature = 1.5,
+ min_p = 0.1,
+ do_sample = True,
+ pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
)
# Decode with the LOADED tokenizer
- generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
- input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens=True)
+ generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)
+ input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)
response_part = generated_text[len(input_text) :].strip()
print(f"Input: {input_text}")
diff --git a/tests/saving/text_to_speech_models/test_csm.py b/tests/saving/text_to_speech_models/test_csm.py
index 7f82ca63a..c1a892a8d 100644
--- a/tests/saving/text_to_speech_models/test_csm.py
+++ b/tests/saving/text_to_speech_models/test_csm.py
@@ -26,11 +26,11 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
- model_name="unsloth/csm-1b",
- max_seq_length=2048, # Choose any for long context!
- dtype=None, # Leave as None for auto-detection
- auto_model=CsmForConditionalGeneration,
- load_in_4bit=False, # Select True for 4bit - reduces memory usage
+ model_name = "unsloth/csm-1b",
+ max_seq_length = 2048, # Choose any for long context!
+ dtype = None, # Leave as None for auto-detection
+ auto_model = CsmForConditionalGeneration,
+ load_in_4bit = False, # Select True for 4bit - reduces memory usage
)
@@ -39,8 +39,8 @@ base_model_class = model.__class__.__name__
model = FastModel.get_peft_model(
model,
- r=32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules=[
+ r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ target_modules = [
"q_proj",
"k_proj",
"v_proj",
@@ -49,14 +49,14 @@ model = FastModel.get_peft_model(
"up_proj",
"down_proj",
],
- lora_alpha=32,
- lora_dropout=0, # Supports any, but = 0 is optimized
- bias="none", # Supports any, but = "none" is optimized
+ lora_alpha = 32,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
- random_state=3407,
- use_rslora=False, # We support rank stabilized LoRA
- loftq_config=None, # And LoftQ
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
)
print("✅ Model and LoRA adapters loaded successfully!")
@@ -110,11 +110,11 @@ print(f"{'='*80}")
model, processor = FastModel.from_pretrained(
- model_name="./csm",
- max_seq_length=2048, # Choose any for long context!
- dtype=None, # Leave as None for auto-detection
- auto_model=CsmForConditionalGeneration,
- load_in_4bit=False, # Select True for 4bit - reduces memory usage
+ model_name = "./csm",
+ max_seq_length = 2048, # Choose any for long context!
+ dtype = None, # Leave as None for auto-detection
+ auto_model = CsmForConditionalGeneration,
+ load_in_4bit = False, # Select True for 4bit - reduces memory usage
)
from transformers import AutoProcessor
@@ -138,19 +138,19 @@ try:
"We just finished fine tuning a text to speech model... and it's pretty good!"
)
speaker_id = 0
- inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
+ inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda")
audio_values = model.generate(
**inputs,
- max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
+ max_new_tokens = 125, # 125 tokens is 10 seconds of audio, for longer speech increase this
# play with these parameters to get the best results
- depth_decoder_temperature=0.6,
- depth_decoder_top_k=0,
- depth_decoder_top_p=0.9,
- temperature=0.8,
- top_k=50,
- top_p=1.0,
+ depth_decoder_temperature = 0.6,
+ depth_decoder_top_k = 0,
+ depth_decoder_top_p = 0.9,
+ temperature = 0.8,
+ top_k = 50,
+ top_p = 1.0,
#########################################################
- output_audio=True,
+ output_audio = True,
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)
diff --git a/tests/saving/text_to_speech_models/test_lasa.py b/tests/saving/text_to_speech_models/test_lasa.py
index e9e05af08..804ff512f 100644
--- a/tests/saving/text_to_speech_models/test_lasa.py
+++ b/tests/saving/text_to_speech_models/test_lasa.py
@@ -42,10 +42,10 @@ print(f"{'='*80}")
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="unsloth/Llasa-1B",
- max_seq_length=max_seq_length,
- dtype=None, # Select None for auto detection
- load_in_4bit=False, # Choose True for 4bit which reduces memory
+ model_name = "unsloth/Llasa-1B",
+ max_seq_length = max_seq_length,
+ dtype = None, # Select None for auto detection
+ load_in_4bit = False, # Choose True for 4bit which reduces memory
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -54,16 +54,16 @@ base_model_class = model.__class__.__name__
model = FastLanguageModel.get_peft_model(
model,
- r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules=["q_proj", "v_proj"],
- lora_alpha=128,
- lora_dropout=0, # Supports any, but = 0 is optimized
- bias="none", # Supports any, but = "none" is optimized
+ r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ target_modules = ["q_proj", "v_proj"],
+ lora_alpha = 128,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
- random_state=3407,
- use_rslora=False, # We support rank stabilized LoRA
- loftq_config=None, # And LoftQ
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
)
print("✅ Model and LoRA adapters loaded successfully!")
@@ -117,10 +117,10 @@ print(f"{'='*80}")
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name="./lasa",
- max_seq_length=max_seq_length,
- dtype=None, # Select None for auto detection
- load_in_4bit=False, # Choose True for 4bit which reduces memory
+ model_name = "./lasa",
+ max_seq_length = max_seq_length,
+ dtype = None, # Select None for auto detection
+ load_in_4bit = False, # Choose True for 4bit which reduces memory
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -166,7 +166,7 @@ def extract_speech_ids(speech_tokens_str):
# TTS start!
with torch.inference_mode():
- with torch.amp.autocast("cuda", dtype=model.dtype):
+ with torch.amp.autocast("cuda", dtype = model.dtype):
formatted_text = (
f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
)
@@ -178,7 +178,7 @@ with torch.inference_mode():
]
input_ids = tokenizer.apply_chat_template(
- chat, tokenize=True, return_tensors="pt", continue_final_message=True
+ chat, tokenize = True, return_tensors = "pt", continue_final_message = True
)
input_ids = input_ids.to("cuda")
@@ -187,16 +187,16 @@ with torch.inference_mode():
# Generate the speech autoregressively
outputs = model.generate(
input_ids,
- max_length=2048, # We trained our model with a max length of 2048
- eos_token_id=speech_end_id,
- do_sample=True,
- top_p=1.2, # Adjusts the diversity of generated content
- temperature=1.2, # Controls randomness in output
+ max_length = 2048, # We trained our model with a max length of 2048
+ eos_token_id = speech_end_id,
+ do_sample = True,
+ top_p = 1.2, # Adjusts the diversity of generated content
+ temperature = 1.2, # Controls randomness in output
)
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1] : -1]
- speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)
# Convert token <|s_23456|> to int 23456
speech_tokens = extract_speech_ids(speech_tokens)
diff --git a/tests/saving/text_to_speech_models/test_whisper.py b/tests/saving/text_to_speech_models/test_whisper.py
index 002d136cb..55f6d98ca 100644
--- a/tests/saving/text_to_speech_models/test_whisper.py
+++ b/tests/saving/text_to_speech_models/test_whisper.py
@@ -28,12 +28,12 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
- model_name="unsloth/whisper-large-v3",
- dtype=None, # Leave as None for auto detection
- load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
- auto_model=WhisperForConditionalGeneration,
- whisper_language="English",
- whisper_task="transcribe",
+ model_name = "unsloth/whisper-large-v3",
+ dtype = None, # Leave as None for auto detection
+ load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
+ auto_model = WhisperForConditionalGeneration,
+ whisper_language = "English",
+ whisper_task = "transcribe",
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -46,17 +46,17 @@ model.generation_config.forced_decoder_ids = None
model = FastModel.get_peft_model(
model,
- r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules=["q_proj", "v_proj"],
- lora_alpha=64,
- lora_dropout=0, # Supports any, but = 0 is optimized
- bias="none", # Supports any, but = "none" is optimized
+ r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ target_modules = ["q_proj", "v_proj"],
+ lora_alpha = 64,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
- random_state=3407,
- use_rslora=False, # We support rank stabilized LoRA
- loftq_config=None, # And LoftQ
- task_type=None, # ** MUST set this for Whisper **
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
+ task_type = None, # ** MUST set this for Whisper **
)
print("✅ Model and LoRA adapters loaded successfully!")
@@ -110,12 +110,12 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
- model_name="./whisper",
- dtype=None, # Leave as None for auto detection
- load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
- auto_model=WhisperForConditionalGeneration,
- whisper_language="English",
- whisper_task="transcribe",
+ model_name = "./whisper",
+ dtype = None, # Leave as None for auto detection
+ load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
+ auto_model = WhisperForConditionalGeneration,
+ whisper_language = "English",
+ whisper_task = "transcribe",
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@@ -135,7 +135,7 @@ try:
headers = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
- response = requests.get(audio_url, headers=headers)
+ response = requests.get(audio_url, headers = headers)
response.raise_for_status()
with open(audio_file, "wb") as f:
f.write(response.content)
@@ -156,12 +156,12 @@ model.eval()
# Create pipeline without specifying the device
whisper = pipeline(
"automatic-speech-recognition",
- model=model,
- tokenizer=tokenizer.tokenizer,
- feature_extractor=tokenizer.feature_extractor,
- processor=tokenizer,
- return_language=True,
- torch_dtype=torch.float16, # Remove the device parameter
+ model = model,
+ tokenizer = tokenizer.tokenizer,
+ feature_extractor = tokenizer.feature_extractor,
+ processor = tokenizer,
+ return_language = True,
+ torch_dtype = torch.float16, # Remove the device parameter
)
# Example usage
audio_file = "Speech_12dB_s16.flac"
diff --git a/tests/saving/vision_models/test_index_file_sharded_model.py b/tests/saving/vision_models/test_index_file_sharded_model.py
index 39cadb433..f73716984 100644
--- a/tests/saving/vision_models/test_index_file_sharded_model.py
+++ b/tests/saving/vision_models/test_index_file_sharded_model.py
@@ -21,7 +21,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
## Dataset Preparation"""
print("\n📊 Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@@ -81,11 +81,11 @@ print("🤖 Loading base vision model...")
try:
model, tokenizer = FastVisionModel.from_pretrained(
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
- model_name="unsloth/Qwen2-VL-7B-Instruct",
- max_seq_length=2048, # Choose any for long context!
- load_in_4bit=True, # 4 bit quantization to reduce memory
- load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning=False, # [NEW!] We have full finetuning now!
+ model_name = "unsloth/Qwen2-VL-7B-Instruct",
+ max_seq_length = 2048, # Choose any for long context!
+ load_in_4bit = True, # 4 bit quantization to reduce memory
+ load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
+ full_finetuning = False, # [NEW!] We have full finetuning now!
)
except Exception as e:
print(f"❌ Failed to load base model: {e}")
@@ -96,18 +96,18 @@ print("\n🔧 Setting up LoRA configuration...")
try:
model = FastVisionModel.get_peft_model(
model,
- finetune_vision_layers=True, # Turn off for just text!
- finetune_language_layers=True, # Should leave on!
- finetune_attention_modules=True, # Attention good for GRPO
- finetune_mlp_modules=True, # SHould leave on always!
- r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- lora_alpha=32,
- lora_dropout=0, # Supports any, but = 0 is optimized
- bias="none", # Supports any, but = "none" is optimized
- use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
- random_state=3407,
- use_rslora=False, # We support rank stabilized LoRA
- loftq_config=None, # And LoftQ
+ finetune_vision_layers = True, # Turn off for just text!
+ finetune_language_layers = True, # Should leave on!
+ finetune_attention_modules = True, # Attention good for GRPO
+ finetune_mlp_modules = True, # SHould leave on always!
+ r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ lora_alpha = 32,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
)
print("✅ LoRA configuration applied successfully!")
print(f" 🎯 LoRA rank (r): 16")
@@ -128,40 +128,40 @@ FastVisionModel.for_training(model) # Enable for training!
try:
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- data_collator=UnslothVisionDataCollator(model, tokenizer),
- train_dataset=train_dataset,
- args=SFTConfig(
+ model = model,
+ tokenizer = tokenizer,
+ data_collator = UnslothVisionDataCollator(model, tokenizer),
+ train_dataset = train_dataset,
+ args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- gradient_checkpointing=True,
- gradient_checkpointing_kwargs={
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ gradient_checkpointing = True,
+ gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
- max_grad_norm=0.3, # max gradient norm based on QLoRA paper
- warmup_ratio=0.03,
+ max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
+ warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps=10,
- learning_rate=2e-4,
- fp16=not is_bf16_supported(),
- bf16=is_bf16_supported(),
- logging_steps=5,
- save_strategy="epoch",
- optim="adamw_torch_fused",
- weight_decay=0.01,
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="checkpoints",
- report_to="none", # For Weights and Biases
+ max_steps = 10,
+ learning_rate = 2e-4,
+ fp16 = not is_bf16_supported(),
+ bf16 = is_bf16_supported(),
+ logging_steps = 5,
+ save_strategy = "epoch",
+ optim = "adamw_torch_fused",
+ weight_decay = 0.01,
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "checkpoints",
+ report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
- remove_unused_columns=False,
- dataset_text_field="",
- dataset_kwargs={"skip_prepare_dataset": True},
- dataset_num_proc=4,
- max_seq_length=2048,
+ remove_unused_columns = False,
+ dataset_text_field = "",
+ dataset_kwargs = {"skip_prepare_dataset": True},
+ dataset_num_proc = 4,
+ max_seq_length = 2048,
),
)
print("✅ Trainer setup completed!")
@@ -221,7 +221,7 @@ try:
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
print(f"🚀 Uploading to repository: {repo_name}")
- model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+ model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
@@ -233,8 +233,8 @@ try:
print("\n" + "=" * 80)
print("=== VERIFYING REPO CONTENTS ===".center(80))
print("=" * 80 + "\n")
- fs = HfFileSystem(token=hf_token)
- file_list = fs.ls(repo_name, detail=True)
+ fs = HfFileSystem(token = hf_token)
+ file_list = fs.ls(repo_name, detail = True)
safetensors_found = any(
file["name"].endswith("model.safetensors.index.json") for file in file_list
)
diff --git a/tests/saving/vision_models/test_push_to_hub_merged.py b/tests/saving/vision_models/test_push_to_hub_merged.py
index 8cea6ad26..74fa05898 100644
--- a/tests/saving/vision_models/test_push_to_hub_merged.py
+++ b/tests/saving/vision_models/test_push_to_hub_merged.py
@@ -22,7 +22,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
## Dataset Preparation"""
print("\n📊 Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@@ -82,11 +82,11 @@ print("🤖 Loading base vision model...")
try:
model, tokenizer = FastVisionModel.from_pretrained(
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
- model_name="unsloth/Qwen2-VL-2B-Instruct",
- max_seq_length=2048, # Choose any for long context!
- load_in_4bit=True, # 4 bit quantization to reduce memory
- load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning=False, # [NEW!] We have full finetuning now!
+ model_name = "unsloth/Qwen2-VL-2B-Instruct",
+ max_seq_length = 2048, # Choose any for long context!
+ load_in_4bit = True, # 4 bit quantization to reduce memory
+ load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
+ full_finetuning = False, # [NEW!] We have full finetuning now!
)
except Exception as e:
print(f"❌ Failed to load base model: {e}")
@@ -97,18 +97,18 @@ print("\n🔧 Setting up LoRA configuration...")
try:
model = FastVisionModel.get_peft_model(
model,
- finetune_vision_layers=True, # Turn off for just text!
- finetune_language_layers=True, # Should leave on!
- finetune_attention_modules=True, # Attention good for GRPO
- finetune_mlp_modules=True, # SHould leave on always!
- r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- lora_alpha=32,
- lora_dropout=0, # Supports any, but = 0 is optimized
- bias="none", # Supports any, but = "none" is optimized
- use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
- random_state=3407,
- use_rslora=False, # We support rank stabilized LoRA
- loftq_config=None, # And LoftQ
+ finetune_vision_layers = True, # Turn off for just text!
+ finetune_language_layers = True, # Should leave on!
+ finetune_attention_modules = True, # Attention good for GRPO
+ finetune_mlp_modules = True, # SHould leave on always!
+ r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+ lora_alpha = 32,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ use_rslora = False, # We support rank stabilized LoRA
+ loftq_config = None, # And LoftQ
)
print("✅ LoRA configuration applied successfully!")
print(f" 🎯 LoRA rank (r): 16")
@@ -129,40 +129,40 @@ FastVisionModel.for_training(model) # Enable for training!
try:
trainer = SFTTrainer(
- model=model,
- tokenizer=tokenizer,
- data_collator=UnslothVisionDataCollator(model, tokenizer),
- train_dataset=train_dataset,
- args=SFTConfig(
+ model = model,
+ tokenizer = tokenizer,
+ data_collator = UnslothVisionDataCollator(model, tokenizer),
+ train_dataset = train_dataset,
+ args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
- per_device_train_batch_size=2,
- gradient_accumulation_steps=4,
- gradient_checkpointing=True,
- gradient_checkpointing_kwargs={
+ per_device_train_batch_size = 2,
+ gradient_accumulation_steps = 4,
+ gradient_checkpointing = True,
+ gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
- max_grad_norm=0.3, # max gradient norm based on QLoRA paper
- warmup_ratio=0.03,
+ max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
+ warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps=10,
- learning_rate=2e-4,
- fp16=not is_bf16_supported(),
- bf16=is_bf16_supported(),
- logging_steps=5,
- save_strategy="epoch",
- optim="adamw_torch_fused",
- weight_decay=0.01,
- lr_scheduler_type="linear",
- seed=3407,
- output_dir="checkpoints",
- report_to="none", # For Weights and Biases
+ max_steps = 10,
+ learning_rate = 2e-4,
+ fp16 = not is_bf16_supported(),
+ bf16 = is_bf16_supported(),
+ logging_steps = 5,
+ save_strategy = "epoch",
+ optim = "adamw_torch_fused",
+ weight_decay = 0.01,
+ lr_scheduler_type = "linear",
+ seed = 3407,
+ output_dir = "checkpoints",
+ report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
- remove_unused_columns=False,
- dataset_text_field="",
- dataset_kwargs={"skip_prepare_dataset": True},
- dataset_num_proc=4,
- max_seq_length=2048,
+ remove_unused_columns = False,
+ dataset_text_field = "",
+ dataset_kwargs = {"skip_prepare_dataset": True},
+ dataset_num_proc = 4,
+ max_seq_length = 2048,
),
)
print("✅ Trainer setup completed!")
@@ -221,7 +221,7 @@ try:
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
print(f"🚀 Uploading to repository: {repo_name}")
- model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+ model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
diff --git a/tests/utils/aime_eval.py b/tests/utils/aime_eval.py
index 281950155..131da3e50 100644
--- a/tests/utils/aime_eval.py
+++ b/tests/utils/aime_eval.py
@@ -24,7 +24,7 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
"test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl",
}
- os.makedirs(data_dir, exist_ok=True)
+ os.makedirs(data_dir, exist_ok = True)
combined_filepath = os.path.join(data_dir, "aime.jsonl")
# Check if combined file already exists
@@ -67,9 +67,9 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
# Write combined dataset
if all_problems:
- with open(combined_filepath, "w", encoding="utf-8") as f:
+ with open(combined_filepath, "w", encoding = "utf-8") as f:
for problem in all_problems:
- f.write(json.dumps(problem, ensure_ascii=False) + "\n")
+ f.write(json.dumps(problem, ensure_ascii = False) + "\n")
print(f"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets")
print(f" Saved to: {combined_filepath}")
@@ -92,7 +92,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
filepath = download_and_combine_aime_datasets(data_dir)
examples = []
- with open(filepath, "r", encoding="utf-8") as f:
+ with open(filepath, "r", encoding = "utf-8") as f:
for line_num, line in enumerate(f):
line = line.strip()
if line:
@@ -188,20 +188,20 @@ def get_num_tokens(text, tokenizer_instance):
"""Count tokens in text"""
if not text:
return 0
- encoding = tokenizer_instance(text, return_tensors="pt")
+ encoding = tokenizer_instance(text, return_tensors = "pt")
return len(encoding["input_ids"][0])
def evaluate_model_aime(
model,
tokenizer,
- model_type="base",
- lora_request=None,
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768,
- top_p=0.95,
- seed=0,
+ model_type = "base",
+ lora_request = None,
+ temperature = 0.3,
+ n_sampling = 8,
+ max_tokens = 32768,
+ top_p = 0.95,
+ seed = 0,
):
"""Evaluate model on combined AIME dataset with official configuration"""
@@ -237,11 +237,11 @@ def evaluate_model_aime(
# Setup sampling parameters (AIME configuration)
sampling_params = SamplingParams(
- temperature=temperature,
- top_p=top_p,
- max_tokens=max_tokens,
- n=n_sampling, # Multiple samples per question
- seed=seed,
+ temperature = temperature,
+ top_p = top_p,
+ max_tokens = max_tokens,
+ n = n_sampling, # Multiple samples per question
+ seed = seed,
)
print(f"\n🔧 Configuration:")
@@ -272,13 +272,13 @@ def evaluate_model_aime(
# Main evaluation loop
with tqdm(
- total=len(eval_dataset), desc="Processing AIME problems", unit="problem"
+ total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem"
) as pbar:
for task_id, item in enumerate(eval_dataset):
try:
# Prepare prompt
prompt_text = tokenizer.apply_chat_template(
- item["prompt"], add_generation_prompt=True, tokenize=False
+ item["prompt"], add_generation_prompt = True, tokenize = False
)
input_tokens.append(get_num_tokens(prompt_text, tokenizer))
@@ -286,9 +286,9 @@ def evaluate_model_aime(
# Generate multiple responses
outputs = model.fast_generate(
[prompt_text],
- sampling_params=sampling_params,
- lora_request=lora_request,
- use_tqdm=False,
+ sampling_params = sampling_params,
+ lora_request = lora_request,
+ use_tqdm = False,
)[0].outputs
# Process all generated responses
@@ -413,8 +413,8 @@ def evaluate_model_aime(
# Save results
filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json"
- with open(filename, "w", encoding="utf-8") as f:
- json.dump({"results": results, "records": records}, f, indent=4)
+ with open(filename, "w", encoding = "utf-8") as f:
+ json.dump({"results": results, "records": records}, f, indent = 4)
# Print comprehensive summary
print(f"\n{'='*70}")
@@ -517,27 +517,27 @@ def compare_aime_results(all_results):
if all_results and "source_accuracies" in all_results[0]:
datasets = list(all_results[0]["source_accuracies"].keys())
- print(f"{'Model':<15}", end="")
+ print(f"{'Model':<15}", end = "")
for dataset in datasets:
- print(f"{dataset:<15}", end="")
+ print(f"{dataset:<15}", end = "")
print()
print("-" * (15 + 15 * len(datasets)))
for result in all_results:
- print(f"{result['model_type']:<15}", end="")
+ print(f"{result['model_type']:<15}", end = "")
for dataset in datasets:
accuracy = result["source_accuracies"].get(dataset, 0)
- print(f"{accuracy:<15.1f}", end="")
+ print(f"{accuracy:<15.1f}", end = "")
print()
# Save comparison
comparison_data = {
"summary": all_results,
- "best_model": max(all_results, key=lambda x: x["accuracy"]),
+ "best_model": max(all_results, key = lambda x: x["accuracy"]),
}
with open("aime_model_comparison.json", "w") as f:
- json.dump(comparison_data, f, indent=4)
+ json.dump(comparison_data, f, indent = 4)
print(
f"\nBest performing model: {comparison_data['best_model']['model_type']} "
diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py
index cc5edce02..8ad6d5ad0 100644
--- a/tests/utils/hf_utils.py
+++ b/tests/utils/hf_utils.py
@@ -79,27 +79,27 @@ def generate_responses(
skip_special_tokens: bool = True,
dtype: torch.dtype = None,
):
- inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)]
+ inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)]
keys = inputs[0].keys()
batched_inputs = {
- key: torch.cat([input[key] for input in inputs], dim=0).to(model.device)
+ key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)
for key in keys
}
if dtype is not None:
- inference_context = torch.autocast(device_type="cuda", dtype=dtype)
+ inference_context = torch.autocast(device_type = "cuda", dtype = dtype)
else:
inference_context = nullcontext()
with inference_context:
outputs = model.generate(
**batched_inputs,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
+ max_new_tokens = max_new_tokens,
+ do_sample = do_sample,
+ temperature = temperature,
)
- responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+ responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)
return responses
@@ -117,11 +117,11 @@ def sample_responses(
model,
tokenizer,
prompt,
- temperature=temperature,
- num_generations=num_generations,
- max_new_tokens=max_new_tokens,
- skip_special_tokens=skip_special_tokens,
- dtype=dtype,
+ temperature = temperature,
+ num_generations = num_generations,
+ max_new_tokens = max_new_tokens,
+ skip_special_tokens = skip_special_tokens,
+ dtype = dtype,
)
return responses
@@ -136,32 +136,32 @@ def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
def setup_model(
model_name,
quantize: bool = True,
- dtype=torch.bfloat16,
- peft_config=None,
+ dtype = torch.bfloat16,
+ peft_config = None,
autocast_adapter: bool = True,
):
if quantize:
bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=dtype,
+ load_in_4bit = True,
+ bnb_4bit_use_double_quant = True,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_compute_dtype = dtype,
)
else:
bnb_config = None
model = AutoModelForCausalLM.from_pretrained(
model_name,
- device_map="cuda:0",
- attn_implementation="sdpa",
- quantization_config=bnb_config,
- torch_dtype=dtype,
+ device_map = "cuda:0",
+ attn_implementation = "sdpa",
+ quantization_config = bnb_config,
+ torch_dtype = dtype,
)
model = prepare_model_for_kbit_training(model) if quantize else model
if peft_config is not None:
model = get_peft_model(
- model, peft_config, autocast_adapter_dtype=autocast_adapter
+ model, peft_config, autocast_adapter_dtype = autocast_adapter
)
return model
@@ -169,19 +169,19 @@ def setup_model(
def get_peft_config(
lora_rank,
- lora_alpha=None,
- lora_dropout=0.0,
- bias="none",
- target_modules="all-linear",
+ lora_alpha = None,
+ lora_dropout = 0.0,
+ bias = "none",
+ target_modules = "all-linear",
):
lora_alpha = lora_alpha or 2 * lora_rank
peft_config = LoraConfig(
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- r=lora_rank,
- bias=bias,
- target_modules=target_modules,
- task_type="CAUSAL_LM",
+ lora_alpha = lora_alpha,
+ lora_dropout = lora_dropout,
+ r = lora_rank,
+ bias = bias,
+ target_modules = target_modules,
+ task_type = "CAUSAL_LM",
)
return peft_config
@@ -191,18 +191,18 @@ def setup_trainer(
tokenizer,
dataset,
train_args,
- peft_config=None,
- formatting_func=None,
- collator=None,
+ peft_config = None,
+ formatting_func = None,
+ collator = None,
):
return SFTTrainer(
- model=model,
- peft_config=peft_config,
- train_dataset=dataset,
- processing_class=tokenizer,
- formatting_func=formatting_func,
- data_collator=collator,
- args=train_args,
+ model = model,
+ peft_config = peft_config,
+ train_dataset = dataset,
+ processing_class = tokenizer,
+ formatting_func = formatting_func,
+ data_collator = collator,
+ args = train_args,
)
@@ -212,17 +212,17 @@ def setup_lora(
dataset,
peft_config,
train_args,
- formatting_func=None,
- collator=None,
+ formatting_func = None,
+ collator = None,
):
return LoraConfig(
- model=model,
- peft_config=peft_config,
- train_dataset=dataset,
- processing_class=tokenizer,
- formatting_func=formatting_func,
- data_collator=collator,
- args=train_args,
+ model = model,
+ peft_config = peft_config,
+ train_dataset = dataset,
+ processing_class = tokenizer,
+ formatting_func = formatting_func,
+ data_collator = collator,
+ args = train_args,
)
@@ -236,7 +236,7 @@ def convert_weights_back_to_dtype(model, dtype):
param.data = param.data.to(dtype)
-def fix_llama3_tokenizer(tokenizer, padding_side="right"):
+def fix_llama3_tokenizer(tokenizer, padding_side = "right"):
tokenizer.padding_side = padding_side
added_vocab = tokenizer.get_added_vocab()
pad_token = [w for w in added_vocab if "pad" in w]
@@ -276,12 +276,12 @@ def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
w_dq = w_dq.to(original_dtype)
new_module = torch.nn.Linear(
- w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None
+ w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None
)
- new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False)
+ new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)
if module.lora_bias[adapter_name]:
bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
- new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False)
+ new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)
return new_module
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index c9433faaa..340bcee5e 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -39,7 +39,7 @@ if already_imported:
f"to ensure all optimizations are applied. Your code may run slower or encounter "
f"memory issues without these optimizations.\n\n"
f"Please restructure your imports with 'import unsloth' at the top of your file.",
- stacklevel=2,
+ stacklevel = 2,
)
del already_imported, critical_modules
@@ -141,7 +141,7 @@ if DEVICE_TYPE == "cuda":
old_is_bf16_supported = torch.cuda.is_bf16_supported
if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
- def is_bf16_supported(including_emulation=False):
+ def is_bf16_supported(including_emulation = False):
return old_is_bf16_supported(including_emulation)
torch.cuda.is_bf16_supported = is_bf16_supported
diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py
index a8384ca3b..60d0c318c 100644
--- a/unsloth/kernels/fast_lora.py
+++ b/unsloth/kernels/fast_lora.py
@@ -86,7 +86,7 @@ class LoRA_MLP(torch.autograd.Function):
downS,
_forward_function,
_backward_function,
- inplace=True,
+ inplace = True,
):
dtype = X.dtype
@@ -169,39 +169,39 @@ class LoRA_MLP(torch.autograd.Function):
# d_downB = (downA.t() @ h.t()) @ dY
# d_downA *= downS
# d_downB *= downS
- d_downA.addmm_(h.t(), dY @ downB.t(), alpha=downS, beta=0)
- d_downB.addmm_(downA.t() @ h.t(), dY, alpha=downS, beta=0)
+ d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
+ d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)
# Up projection LoRA weights
# d_upA = X.t() @ (df @ upB.t())
# d_upB = (upA.t() @ X.t()) @ df
# d_upA *= upS
# d_upB *= upS
- d_upA.addmm_(X.t(), df @ upB.t(), alpha=upS, beta=0)
- d_upB.addmm_(upA.t() @ X.t(), df, alpha=upS, beta=0)
+ d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
+ d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)
# Gate projection LoRA weights
# d_gateA = X.t() @ (de @ gateB.t())
# d_gateB = (gateA.t() @ X.t()) @ de
# d_gateA *= gateS
# d_gateB *= gateS
- d_gateA.addmm_(X.t(), de @ gateB.t(), alpha=gateS, beta=0)
- d_gateB.addmm_(gateA.t() @ X.t(), de, alpha=gateS, beta=0)
+ d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
+ d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
- dX = torch.matmul(df, upW.t(), out=X if ctx.inplace else None)
+ dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
- dX.addmm_(df @ upB.t(), upA.t(), alpha=upS)
+ dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)
gateW = fast_dequantize(gateW.t(), gateW_quant)
# dX += de @ gateW.t()
dX.addmm_(de, gateW.t())
del gateW
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
- dX.addmm_(de @ gateB.t(), gateA.t(), alpha=gateS)
+ dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
@@ -232,7 +232,7 @@ class LoRA_MLP(torch.autograd.Function):
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
-def apply_lora_mlp_swiglu(self, X, inplace=True):
+def apply_lora_mlp_swiglu(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
@@ -264,7 +264,7 @@ def apply_lora_mlp_swiglu(self, X, inplace=True):
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
-def apply_lora_mlp_geglu_exact(self, X, inplace=True):
+def apply_lora_mlp_geglu_exact(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
@@ -375,7 +375,7 @@ class LoRA_QKV(torch.autograd.Function):
VA,
VB,
VS,
- inplace=True,
+ inplace = True,
):
dtype = X.dtype
@@ -452,32 +452,32 @@ class LoRA_QKV(torch.autograd.Function):
# d_QB = (QA.t() @ X.t()) @ dQ
# d_QA *= QS
# d_QB *= QS
- d_QA.addmm_(X.t(), dQ @ QB.t(), alpha=QS, beta=0)
- d_QB.addmm_(QA.t() @ X.t(), dQ, alpha=QS, beta=0)
+ d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
+ d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)
# K Projection
# d_KA = X.t() @ (dK @ KB.t())
# d_KB = (KA.t() @ X.t()) @ dK
# d_KA *= KS
# d_KB *= KS
- d_KA.addmm_(X.t(), dK @ KB.t(), alpha=KS, beta=0)
- d_KB.addmm_(KA.t() @ X.t(), dK, alpha=KS, beta=0)
+ d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
+ d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)
# V Projection
# d_VA = X.t() @ (dV @ VB.t())
# d_VB = (VA.t() @ X.t()) @ dV
# d_VA *= VS
# d_VB *= VS
- d_VA.addmm_(X.t(), dV @ VB.t(), alpha=VS, beta=0)
- d_VB.addmm_(VA.t() @ X.t(), dV, alpha=VS, beta=0)
+ d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
+ d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)
# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
- dX = torch.matmul(dQ, QW.t(), out=X if ctx.inplace else None)
+ dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
- dX.addmm_(dQ @ QB.t(), QA.t(), alpha=QS)
+ dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)
# dK
KW = fast_dequantize(KW.t(), KW_quant)
@@ -485,7 +485,7 @@ class LoRA_QKV(torch.autograd.Function):
dX.addmm_(dK, KW.t())
del KW
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
- dX.addmm_(dK @ KB.t(), KA.t(), alpha=KS)
+ dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)
# dV
VW = fast_dequantize(VW.t(), VW_quant)
@@ -493,7 +493,7 @@ class LoRA_QKV(torch.autograd.Function):
dX.addmm_(dV, VW.t())
del VW
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
- dX.addmm_(dV @ VB.t(), VA.t(), alpha=VS)
+ dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
@@ -519,7 +519,7 @@ class LoRA_QKV(torch.autograd.Function):
)
-def apply_lora_qkv(self, X, inplace=True):
+def apply_lora_qkv(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.q_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
@@ -611,15 +611,15 @@ class LoRA_W(torch.autograd.Function):
# d_B = (A.t() @ X.t()) @ dY
# d_A *= S
# d_B *= S
- d_A.addmm_(X.t(), dY @ B.t(), alpha=S, beta=0)
- d_B.addmm_(A.t() @ X.t(), dY, alpha=S, beta=0)
+ d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
+ d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)
# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
- dX.addmm_(dY @ B.t(), A.t(), alpha=S)
+ dX.addmm_(dY @ B.t(), A.t(), alpha = S)
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
@@ -649,7 +649,7 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(
- x, *args, adapter_names=adapter_names, **kwargs
+ x, *args, adapter_names = adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
@@ -705,11 +705,11 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = result + self.lora_magnitude_vector[active_adapter](
x,
- lora_A=lora_A,
- lora_B=lora_B,
- scaling=scaling,
- base_layer=self.get_base_layer(),
- base_result=base_result,
+ lora_A = lora_A,
+ lora_B = lora_B,
+ scaling = scaling,
+ base_layer = self.get_base_layer(),
+ base_result = base_result,
)
if requires_conversion:
result = result.to(expected_dtype)
diff --git a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
index 2fe2afa1e..074cc5a56 100644
--- a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
+++ b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
@@ -56,7 +56,7 @@ def run_benchmark_forward(
hidden_size = config.hidden_size
X = torch.randn(
- bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
+ bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
)
# Forward
@@ -89,7 +89,7 @@ def run_benchmark_backward(
config: AutoConfig,
seqlen: int,
dtype: torch.dtype,
- bs=1,
+ bs = 1,
):
torch.manual_seed(
SEED
@@ -98,7 +98,7 @@ def run_benchmark_backward(
hidden_size = config.hidden_size
X = torch.randn(
- bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
+ bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
)
X_test = X.detach().clone().requires_grad_(True)
@@ -114,14 +114,14 @@ def run_benchmark_backward(
# Bench
grad_output = torch.randn_like(output)
- bench_backward_ref = lambda: output.backward(grad_output, retain_graph=True) # noqa: E731
- bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph=True) # noqa: E731
+ bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True) # noqa: E731
+ bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True) # noqa: E731
ref_backward_time = do_bench(
- bench_backward_ref, grad_to_none=[X, *ref_model.parameters()]
+ bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]
)
fused_backward_time = do_bench(
- bench_backward_fused, grad_to_none=[X_test, *tt_model.parameters()]
+ bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]
)
print(
f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x"
@@ -138,10 +138,10 @@ def setup_model(
kernel_config_fwd,
kernel_config_bwd_dW,
kernel_config_bwd_dX,
- dX_only=False,
- dW_only=False,
- overlap_router_shared=False,
- device="cuda",
+ dX_only = False,
+ dW_only = False,
+ overlap_router_shared = False,
+ device = "cuda",
):
if isinstance(config, Qwen3MoeConfig):
ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
@@ -149,29 +149,29 @@ def setup_model(
# Triton kernel grouped gemm version of MoE Block -- this is what we're testing
tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
ref_model,
- permute_x=permute_x,
- permute_y=permute_y,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
- dX_only=dX_only,
- dW_only=dW_only,
+ permute_x = permute_x,
+ permute_y = permute_y,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
+ dX_only = dX_only,
+ dW_only = dW_only,
).to(device, dtype)
elif isinstance(config, Llama4TextConfig):
ref_model = Llama4TextMoe(config).to(device, dtype)
tt_model = Llama4TritonTextMoe(
config,
- overlap_router_shared=overlap_router_shared,
- permute_x=permute_x,
- permute_y=permute_y,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
- dX_only=dX_only,
- dW_only=dW_only,
+ overlap_router_shared = overlap_router_shared,
+ permute_x = permute_x,
+ permute_y = permute_y,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
+ dX_only = dX_only,
+ dW_only = dW_only,
).to(device, dtype)
else:
@@ -205,31 +205,31 @@ def run_benchmark(
ref_model, tt_model = setup_model(
model_config,
- dtype=dtype,
- permute_x=permute_x,
- permute_y=permute_y,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
- dX_only=dX_only,
- dW_only=dW_only,
- overlap_router_shared=overlap_router_shared,
+ dtype = dtype,
+ permute_x = permute_x,
+ permute_y = permute_y,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
+ dX_only = dX_only,
+ dW_only = dW_only,
+ overlap_router_shared = overlap_router_shared,
)
if mode == "forward":
ref_time, fused_time = run_benchmark_forward(
ref_model,
tt_model,
- config=model_config,
- seqlen=seqlen,
- dtype=dtype,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
+ config = model_config,
+ seqlen = seqlen,
+ dtype = dtype,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
)
else:
ref_time, fused_time = run_benchmark_backward(
- ref_model, tt_model, config=model_config, seqlen=seqlen, dtype=dtype
+ ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype
)
if autotune:
@@ -251,60 +251,60 @@ def run_benchmark(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--results_dir", type=str, default="benchmark_results")
- parser.add_argument("--model", type=str, choices=["llama4", "qwen3"], required=True)
- parser.add_argument("--seqlen", type=int, default=1024)
+ parser.add_argument("--results_dir", type = str, default = "benchmark_results")
+ parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True)
+ parser.add_argument("--seqlen", type = int, default = 1024)
parser.add_argument(
- "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
+ "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
)
- parser.add_argument("--permute_x", action="store_true")
- parser.add_argument("--permute_y", action="store_true")
- parser.add_argument("--autotune", action="store_true")
- parser.add_argument("--overlap_router_shared", action="store_true")
+ parser.add_argument("--permute_x", action = "store_true")
+ parser.add_argument("--permute_y", action = "store_true")
+ parser.add_argument("--autotune", action = "store_true")
+ parser.add_argument("--overlap_router_shared", action = "store_true")
parser.add_argument(
"--BLOCK_SIZE_M",
- nargs=2,
- type=int,
- default=[DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
+ nargs = 2,
+ type = int,
+ default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--BLOCK_SIZE_N",
- nargs=2,
- type=int,
- default=[DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
+ nargs = 2,
+ type = int,
+ default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--BLOCK_SIZE_K",
- nargs=2,
- type=int,
- default=[DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
+ nargs = 2,
+ type = int,
+ default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--num_warps",
- nargs=2,
- type=int,
- default=[DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
+ nargs = 2,
+ type = int,
+ default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
)
parser.add_argument(
"--num_stages",
- nargs=2,
- type=int,
- default=[DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
+ nargs = 2,
+ type = int,
+ default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
)
parser.add_argument(
- "--use_tma_load_w", action="store_true"
+ "--use_tma_load_w", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
- "--use_tma_load_x", action="store_true"
+ "--use_tma_load_x", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
- "--use_tma_load_dy", action="store_true"
+ "--use_tma_load_dy", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
"--mode",
- type=str,
- choices=["forward", "backward", "dW", "dX"],
- default="forward",
+ type = str,
+ choices = ["forward", "backward", "dW", "dX"],
+ default = "forward",
)
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
@@ -324,13 +324,13 @@ if __name__ == "__main__":
ref_time, fused_time = run_benchmark(
args.mode,
model_config,
- seqlen=args.seqlen,
- dtype=args.dtype,
- permute_x=args.permute_x,
- permute_y=args.permute_y,
- autotune=args.autotune,
- overlap_router_shared=args.overlap_router_shared,
- results_dir=args.results_dir,
+ seqlen = args.seqlen,
+ dtype = args.dtype,
+ permute_x = args.permute_x,
+ permute_y = args.permute_y,
+ autotune = args.autotune,
+ overlap_router_shared = args.overlap_router_shared,
+ results_dir = args.results_dir,
)
end_time = time.time()
print(f"Total time: {end_time - start_time:.4f} seconds")
@@ -343,13 +343,13 @@ if __name__ == "__main__":
kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)
print(f"Running {len(kernel_configs)} kernel configs")
default_kernel_config_fwd = KernelConfigForward(
- permute_x=args.permute_x, permute_y=args.permute_y
+ permute_x = args.permute_x, permute_y = args.permute_y
)
default_kernel_config_bwd_dW = KernelConfigBackward_dW(
- permute_x=args.permute_x, permute_y=args.permute_y
+ permute_x = args.permute_x, permute_y = args.permute_y
)
default_kernel_config_bwd_dX = KernelConfigBackward_dX(
- permute_x=args.permute_x, permute_y=args.permute_y
+ permute_x = args.permute_x, permute_y = args.permute_y
)
results = []
for kernel_config in kernel_configs:
@@ -374,21 +374,21 @@ if __name__ == "__main__":
ref_time, fused_time = run_benchmark(
args.mode,
model_config,
- seqlen=args.seqlen,
- dtype=args.dtype,
- permute_x=kernel_config.permute_x,
- permute_y=kernel_config.permute_y,
- autotune=False,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
+ seqlen = args.seqlen,
+ dtype = args.dtype,
+ permute_x = kernel_config.permute_x,
+ permute_y = kernel_config.permute_y,
+ autotune = False,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
)
results.append(
KernelResult(
- torch_time=ref_time,
- triton_time=fused_time,
- speedup=ref_time / fused_time,
- kernel_config=kernel_config,
+ torch_time = ref_time,
+ triton_time = fused_time,
+ speedup = ref_time / fused_time,
+ kernel_config = kernel_config,
)
)
df = post_process_results(
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
index d57105cee..a185b5fd3 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
@@ -37,15 +37,15 @@ def convert_args_to_list(args):
def get_forward_configs(
- BLOCK_M=DEFAULT_M_BLOCK_SIZES,
- BLOCK_N=DEFAULT_N_BLOCK_SIZES,
- BLOCK_K=DEFAULT_K_BLOCK_SIZES,
- TMA_LOAD_X=True,
- TMA_LOAD_W=True,
- TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
- num_warps=DEFAULT_NUM_WARPS,
- num_stages=DEFAULT_NUM_STAGES,
- num_ctas=DEFAULT_NUM_CTAS,
+ BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+ BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+ BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+ TMA_LOAD_X = True,
+ TMA_LOAD_W = True,
+ TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
+ num_warps = DEFAULT_NUM_WARPS,
+ num_stages = DEFAULT_NUM_STAGES,
+ num_ctas = DEFAULT_NUM_CTAS,
):
(
BLOCK_M,
@@ -95,16 +95,16 @@ def get_forward_configs(
kernel_configs.append(
triton.Config(
dict(
- BLOCK_SIZE_M=block_m,
- BLOCK_SIZE_N=block_n,
- BLOCK_SIZE_K=block_k,
- USE_TMA_LOAD_X=tma_load_x,
- USE_TMA_LOAD_W=tma_load_w,
- USE_TMA_STORE=tma_store,
+ BLOCK_SIZE_M = block_m,
+ BLOCK_SIZE_N = block_n,
+ BLOCK_SIZE_K = block_k,
+ USE_TMA_LOAD_X = tma_load_x,
+ USE_TMA_LOAD_W = tma_load_w,
+ USE_TMA_STORE = tma_store,
),
- num_warps=w,
- num_stages=s,
- num_ctas=num_ctas,
+ num_warps = w,
+ num_stages = s,
+ num_ctas = num_ctas,
)
)
@@ -112,15 +112,15 @@ def get_forward_configs(
def get_dX_kernel_configs(
- BLOCK_M=DEFAULT_M_BLOCK_SIZES,
- BLOCK_N=DEFAULT_N_BLOCK_SIZES,
- BLOCK_K=DEFAULT_K_BLOCK_SIZES,
- TMA_LOAD_dY=True,
- TMA_LOAD_W=True,
- TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
- num_warps=DEFAULT_NUM_WARPS,
- num_stages=DEFAULT_NUM_STAGES,
- num_ctas=DEFAULT_NUM_CTAS,
+ BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+ BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+ BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+ TMA_LOAD_dY = True,
+ TMA_LOAD_W = True,
+ TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
+ num_warps = DEFAULT_NUM_WARPS,
+ num_stages = DEFAULT_NUM_STAGES,
+ num_ctas = DEFAULT_NUM_CTAS,
):
(
BLOCK_M,
@@ -170,16 +170,16 @@ def get_dX_kernel_configs(
kernel_configs.append(
triton.Config(
dict(
- BLOCK_SIZE_M=block_m,
- BLOCK_SIZE_N=block_n,
- BLOCK_SIZE_K=block_k,
- USE_TMA_LOAD_dY=tma_load_dy,
- USE_TMA_LOAD_W=tma_load_w,
- USE_TMA_STORE=tma_store,
+ BLOCK_SIZE_M = block_m,
+ BLOCK_SIZE_N = block_n,
+ BLOCK_SIZE_K = block_k,
+ USE_TMA_LOAD_dY = tma_load_dy,
+ USE_TMA_LOAD_W = tma_load_w,
+ USE_TMA_STORE = tma_store,
),
- num_warps=w,
- num_stages=s,
- num_ctas=num_ctas,
+ num_warps = w,
+ num_stages = s,
+ num_ctas = num_ctas,
)
)
@@ -187,15 +187,15 @@ def get_dX_kernel_configs(
def get_dW_kernel_configs(
- BLOCK_M=DEFAULT_M_BLOCK_SIZES,
- BLOCK_N=DEFAULT_N_BLOCK_SIZES,
- BLOCK_K=DEFAULT_K_BLOCK_SIZES,
- num_warps=DEFAULT_NUM_WARPS,
- num_stages=DEFAULT_NUM_STAGES,
- num_ctas=DEFAULT_NUM_CTAS,
- TMA_LOAD_dY=True,
- TMA_LOAD_X=True,
- TMA_STORE=False,
+ BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+ BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+ BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+ num_warps = DEFAULT_NUM_WARPS,
+ num_stages = DEFAULT_NUM_STAGES,
+ num_ctas = DEFAULT_NUM_CTAS,
+ TMA_LOAD_dY = True,
+ TMA_LOAD_X = True,
+ TMA_STORE = False,
):
(
BLOCK_M,
@@ -245,16 +245,16 @@ def get_dW_kernel_configs(
kernel_configs.append(
triton.Config(
dict(
- BLOCK_SIZE_M=block_m,
- BLOCK_SIZE_N=block_n,
- BLOCK_SIZE_K=block_k,
- USE_TMA_LOAD_dY=tma_load_dy,
- USE_TMA_LOAD_X=tma_load_x,
- USE_TMA_STORE=tma_store,
+ BLOCK_SIZE_M = block_m,
+ BLOCK_SIZE_N = block_n,
+ BLOCK_SIZE_K = block_k,
+ USE_TMA_LOAD_dY = tma_load_dy,
+ USE_TMA_LOAD_X = tma_load_x,
+ USE_TMA_STORE = tma_store,
),
- num_warps=w,
- num_stages=s,
- num_ctas=num_ctas,
+ num_warps = w,
+ num_stages = s,
+ num_ctas = num_ctas,
)
)
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
index a05fb4d5d..d8bdcb57e 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
@@ -84,18 +84,18 @@ def _grouped_gemm_dX_kernel(
if USE_TMA_LOAD_dY:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
- shape=[TOTAL_TOKENS, N],
- strides=[N, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+ shape = [TOTAL_TOKENS, N],
+ strides = [N, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
if USE_TMA_LOAD_W:
expert_stride = N * K
w_desc = tl._experimental_make_tensor_descriptor(
w_ptr,
- shape=[NUM_EXPERTS, N, K],
- strides=[expert_stride, K, 1],
- block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+ shape = [NUM_EXPERTS, N, K],
+ strides = [expert_stride, K, 1],
+ block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
m_end = 0
@@ -104,7 +104,7 @@ def _grouped_gemm_dX_kernel(
n_block_range = tl.arange(0, BLOCK_SIZE_N)
k_block_range = tl.arange(0, BLOCK_SIZE_K)
- for expert_idx in range(NUM_EXPERTS, flatten=FLATTEN):
+ for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):
m_start = m_end
m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
m_end = m_start + m_size
@@ -125,9 +125,9 @@ def _grouped_gemm_dX_kernel(
)
dX_desc = tl._experimental_make_tensor_descriptor(
dX_ptr,
- shape=[m_end, K],
- strides=[K, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+ shape = [m_end, K],
+ strides = [K, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
# Lower bound and upper bound are defined relative to the total tiles processed so far
@@ -152,7 +152,7 @@ def _grouped_gemm_dX_kernel(
)
expert_token_idx = tl.load(
gather_indices_ptr + indices_to_gather,
- mask=indices_to_gather < TOTAL_TOKENS,
+ mask = indices_to_gather < TOTAL_TOKENS,
)
expert_token_offsets = expert_token_idx[:, None]
@@ -210,13 +210,13 @@ def _grouped_gemm_dX_kernel(
# col_mask = offs_bk[None, :] < K
store_mask = row_mask # & col_mask
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)
# GEMM main loop
for n_offset in range(0, N, BLOCK_SIZE_N):
# dY block [M, N]
if not USE_TMA_LOAD_dY:
- dY = tl.load(dY_ptrs, mask=row_mask)
+ dY = tl.load(dY_ptrs, mask = row_mask)
else:
dY = dY_desc.load(
[m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]
@@ -253,7 +253,7 @@ def _grouped_gemm_dX_kernel(
tl.store(
dX_ptr + store_idx + offs_bk[None, :],
dX,
- mask=store_mask,
+ mask = store_mask,
)
# Move to the next tile within this expert group
@@ -264,9 +264,9 @@ def _grouped_gemm_dX_kernel(
_autotuned_grouped_gemm_dX_kernel = triton.autotune(
- configs=get_dX_kernel_configs(),
- prune_configs_by={"early_config_prune": prune_dX_configs},
- key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
+ configs = get_dX_kernel_configs(),
+ prune_configs_by = {"early_config_prune": prune_dX_configs},
+ key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(_grouped_gemm_dX_kernel)
"""
@@ -324,17 +324,17 @@ def _grouped_gemm_dW_kernel(
if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
- shape=[TOTAL_TOKENS, N],
- strides=[N, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+ shape = [TOTAL_TOKENS, N],
+ strides = [N, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:
x_desc = tl._experimental_make_tensor_descriptor(
x_ptr,
- shape=[TOTAL_TOKENS, K],
- strides=[K, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+ shape = [TOTAL_TOKENS, K],
+ strides = [K, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
# Output tiles per expert, since each expert weight matrix is [N, K]
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
@@ -351,9 +351,9 @@ def _grouped_gemm_dW_kernel(
tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
dW_desc = tl._experimental_make_tensor_descriptor(
dW_ptr,
- shape=[NUM_EXPERTS, N, K],
- strides=[N * K, K, 1],
- block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+ shape = [NUM_EXPERTS, N, K],
+ strides = [N * K, K, 1],
+ block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
for tile_idx in range(
@@ -377,7 +377,7 @@ def _grouped_gemm_dW_kernel(
m_end = 0
for expert_idx in range(NUM_EXPERTS):
# We need to instantiate a fresh accumulator for each expert
- accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=acc_dtype)
+ accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)
m_start = m_end
# Need to figure out why this cast is needed, otherwise compiler complains about mismatching types
@@ -392,16 +392,16 @@ def _grouped_gemm_dW_kernel(
if TMA_LOAD_BOTH:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
- shape=[m_end, N],
- strides=[N, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+ shape = [m_end, N],
+ strides = [N, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
x_desc = tl._experimental_make_tensor_descriptor(
x_ptr,
- shape=[m_end, K],
- strides=[K, 1],
- block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+ shape = [m_end, K],
+ strides = [K, 1],
+ block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):
@@ -425,7 +425,7 @@ def _grouped_gemm_dW_kernel(
# indices_to_gather = m_start + gather_offsets
expert_token_idx = tl.load(
gather_indices_ptr + indices_to_gather,
- mask=indices_to_gather < TOTAL_TOKENS,
+ mask = indices_to_gather < TOTAL_TOKENS,
)
expert_token_offsets = expert_token_idx[:, None]
@@ -461,7 +461,7 @@ def _grouped_gemm_dW_kernel(
x_ptr
+ x_row_load_idx
+ (k_offset + block_range_k)[None, :],
- mask=mk_mask,
+ mask = mk_mask,
)
if USE_TMA_LOAD_dY:
@@ -471,7 +471,7 @@ def _grouped_gemm_dW_kernel(
dY_ptr
+ dY_row_load_idx
+ (n_offset + block_range_n)[None, :],
- mask=mn_mask,
+ mask = mn_mask,
)
accumulator += tl.dot(
@@ -491,12 +491,12 @@ def _grouped_gemm_dW_kernel(
+ store_row_offs[:, None] * K
+ (k_offset + block_range_k)[None, :],
y,
- mask=nk_mask,
+ mask = nk_mask,
)
_autotuned_grouped_gemm_dW_kernel = triton.autotune(
- configs=get_dW_kernel_configs(),
- prune_configs_by={"early_config_prune": prune_kernel_configs_backward_dW},
- key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
+ configs = get_dW_kernel_configs(),
+ prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW},
+ key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(_grouped_gemm_dW_kernel)
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
index 1474c95bd..4010c77ce 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
@@ -51,9 +51,9 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
def __init__(
self,
config: Llama4TextConfig,
- overlap_router_shared=False,
- verbose=False,
- debug=False,
+ overlap_router_shared = False,
+ verbose = False,
+ debug = False,
):
super().__init__(config)
self.overlap_router_shared = overlap_router_shared
@@ -136,7 +136,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
routing_weights, selected_experts = torch.topk(
- router_logits, self.top_k, dim=-1
+ router_logits, self.top_k, dim = -1
)
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@@ -195,7 +195,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
)
if self.top_k > 1:
- hidden_states = hidden_states.sum(dim=1)
+ hidden_states = hidden_states.sum(dim = 1)
hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@@ -212,7 +212,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
# Start expert computation
first_gemm = torch_grouped_gemm(
- X=hidden_states, W=self.experts.gate_up_proj, m_sizes=token_counts_by_expert
+ X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert
)
assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)
@@ -221,7 +221,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
# See comment above
second_gemm = torch_grouped_gemm(
- X=intermediate, W=self.experts.down_proj, m_sizes=token_counts_by_expert
+ X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert
)
assert second_gemm.shape == (total_tokens, hidden_dim)
@@ -234,17 +234,17 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
result = (
Llama4MoeResult(
- token_counts_by_expert=token_counts_by_expert,
- gather_indices=gather_indices,
- topk_weights=routing_weights,
- hidden_states_after_weight_merge=hidden_states_after_weight_merge,
- first_gemm=first_gemm,
- intermediate=intermediate,
- second_gemm=second_gemm,
- hidden_states_unpermute=hidden_states_unpermute,
- shared_expert_out=shared_expert_out,
- final_out=final_out,
- router_logits=router_logits,
+ token_counts_by_expert = token_counts_by_expert,
+ gather_indices = gather_indices,
+ topk_weights = routing_weights,
+ hidden_states_after_weight_merge = hidden_states_after_weight_merge,
+ first_gemm = first_gemm,
+ intermediate = intermediate,
+ second_gemm = second_gemm,
+ hidden_states_unpermute = hidden_states_unpermute,
+ shared_expert_out = shared_expert_out,
+ final_out = final_out,
+ router_logits = router_logits,
)
if self.debug
else (final_out, routing_weights)
@@ -257,7 +257,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
def __init__(
self,
config: Llama4TextConfig,
- overlap_router_shared=False,
+ overlap_router_shared = False,
permute_x: bool = False,
permute_y: bool = True,
autotune: bool = True,
@@ -266,9 +266,9 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
kernel_config_bwd_dX: KernelConfigBackward_dX = None,
dW_only: bool = False,
dX_only: bool = False,
- verbose=False,
+ verbose = False,
):
- super().__init__(config, overlap_router_shared=overlap_router_shared)
+ super().__init__(config, overlap_router_shared = overlap_router_shared)
assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
self.permute_x = permute_x
self.permute_y = permute_y
@@ -321,7 +321,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
routing_weights, selected_experts = torch.topk(
- router_logits, self.top_k, dim=-1
+ router_logits, self.top_k, dim = -1
)
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@@ -380,7 +380,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
)
if self.top_k > 1:
- hidden_states = hidden_states.sum(dim=1)
+ hidden_states = hidden_states.sum(dim = 1)
hidden_states = hidden_states.view(-1, hidden_dim)
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@@ -395,37 +395,37 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
# Start expert computation
hidden_states = grouped_gemm(
- X=hidden_states,
- W=self.experts.gate_up_proj,
- m_sizes=token_counts_by_expert,
- gather_indices=gather_indices,
- topk=self.top_k,
- permute_x=self.permute_x,
- permute_y=False, # output of first grouped gemm should never be permuted
- autotune=self.autotune,
- kernel_config_fwd=self.kernel_config_fwd,
- kernel_config_bwd_dW=self.kernel_config_bwd_dW,
- kernel_config_bwd_dX=self.kernel_config_bwd_dX,
- is_first_gemm=True,
- dW_only=self.dW_only,
- dX_only=self.dX_only,
+ X = hidden_states,
+ W = self.experts.gate_up_proj,
+ m_sizes = token_counts_by_expert,
+ gather_indices = gather_indices,
+ topk = self.top_k,
+ permute_x = self.permute_x,
+ permute_y = False, # output of first grouped gemm should never be permuted
+ autotune = self.autotune,
+ kernel_config_fwd = self.kernel_config_fwd,
+ kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+ kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+ is_first_gemm = True,
+ dW_only = self.dW_only,
+ dX_only = self.dX_only,
)
hidden_states = self.act_and_mul(hidden_states)
hidden_states = grouped_gemm(
- X=hidden_states,
- W=self.experts.down_proj,
- m_sizes=token_counts_by_expert,
- gather_indices=gather_indices,
- topk=self.top_k,
- permute_x=False,
- permute_y=self.permute_y,
- autotune=self.autotune,
- kernel_config_fwd=self.kernel_config_fwd,
- kernel_config_bwd_dW=self.kernel_config_bwd_dW,
- kernel_config_bwd_dX=self.kernel_config_bwd_dX,
- is_first_gemm=False,
- dW_only=self.dW_only,
- dX_only=self.dX_only,
+ X = hidden_states,
+ W = self.experts.down_proj,
+ m_sizes = token_counts_by_expert,
+ gather_indices = gather_indices,
+ topk = self.top_k,
+ permute_x = False,
+ permute_y = self.permute_y,
+ autotune = self.autotune,
+ kernel_config_fwd = self.kernel_config_fwd,
+ kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+ kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+ is_first_gemm = False,
+ dW_only = self.dW_only,
+ dX_only = self.dX_only,
)
# Post-processing
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
index cbccf19cb..0d497f380 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
@@ -80,14 +80,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
gate,
gate_up_proj,
down_proj,
- permute_x=permute_x,
- permute_y=permute_y,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
- dW_only=dW_only,
- dX_only=dX_only,
+ permute_x = permute_x,
+ permute_y = permute_y,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
+ dW_only = dW_only,
+ dX_only = dX_only,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -112,37 +112,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states = permute(hidden_states, gather_indices, self.top_k)
# Start expert computation
hidden_states = grouped_gemm(
- X=hidden_states,
- W=self.gate_up_proj,
- m_sizes=token_counts_by_expert,
- gather_indices=gather_indices,
- topk=self.top_k,
- permute_x=self.permute_x,
- permute_y=False, # output of first grouped gemm should never be permuted
- autotune=self.autotune,
- kernel_config_fwd=self.kernel_config_fwd,
- kernel_config_bwd_dW=self.kernel_config_bwd_dW,
- kernel_config_bwd_dX=self.kernel_config_bwd_dX,
- is_first_gemm=True,
- dW_only=self.dW_only,
- dX_only=self.dX_only,
+ X = hidden_states,
+ W = self.gate_up_proj,
+ m_sizes = token_counts_by_expert,
+ gather_indices = gather_indices,
+ topk = self.top_k,
+ permute_x = self.permute_x,
+ permute_y = False, # output of first grouped gemm should never be permuted
+ autotune = self.autotune,
+ kernel_config_fwd = self.kernel_config_fwd,
+ kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+ kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+ is_first_gemm = True,
+ dW_only = self.dW_only,
+ dX_only = self.dX_only,
)
hidden_states = self.act_and_mul(hidden_states)
hidden_states = grouped_gemm(
- X=hidden_states,
- W=self.down_proj,
- m_sizes=token_counts_by_expert,
- gather_indices=gather_indices,
- topk=self.top_k,
- permute_x=False,
- permute_y=self.permute_y,
- autotune=self.autotune,
- kernel_config_fwd=self.kernel_config_fwd,
- kernel_config_bwd_dW=self.kernel_config_bwd_dW,
- kernel_config_bwd_dX=self.kernel_config_bwd_dX,
- is_first_gemm=False,
- dW_only=self.dW_only,
- dX_only=self.dX_only,
+ X = hidden_states,
+ W = self.down_proj,
+ m_sizes = token_counts_by_expert,
+ gather_indices = gather_indices,
+ topk = self.top_k,
+ permute_x = False,
+ permute_y = self.permute_y,
+ autotune = self.autotune,
+ kernel_config_fwd = self.kernel_config_fwd,
+ kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+ kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+ is_first_gemm = False,
+ dW_only = self.dW_only,
+ dX_only = self.dX_only,
)
# Post-processing
@@ -155,7 +155,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states.view(num_tokens, self.top_k, hidden_dim)
* routing_weights[..., None]
)
- hidden_states = hidden_states.sum(dim=1)
+ hidden_states = hidden_states.sum(dim = 1)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return hidden_states, router_logits
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
index 821b06c97..46d9c3c51 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
@@ -56,7 +56,7 @@ def calculate_topk(
gating_output.dtype
)
else:
- scores = F.softmax(gating_output.to(torch.float32), dim=1).to(
+ scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(
gating_output.dtype
)
@@ -67,13 +67,13 @@ def calculate_topk(
else:
scores = gating_output
- topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1)
+ topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)
if post_act:
topk_weights = _activation(topk_weights)
if renormalize:
- topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to(
+ topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(
gating_output.dtype
)
@@ -94,12 +94,12 @@ def get_routing_indices(
# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
token_counts_by_expert = torch.histc(
selected_experts.view(-1),
- bins=num_experts,
- min=0,
- max=num_experts,
+ bins = num_experts,
+ min = 0,
+ max = num_experts,
)
# token_indices_experts_sorted shape (bs*slen*top_k,)
- gather_indices = torch.argsort(selected_experts.view(-1), stable=True)
+ gather_indices = torch.argsort(selected_experts.view(-1), stable = True)
if return_scatter_indices:
scatter_indices = gather_indices.argsort()
return token_counts_by_expert, gather_indices, scatter_indices
@@ -107,7 +107,7 @@ def get_routing_indices(
return token_counts_by_expert, gather_indices
-def torch_grouped_gemm(X, W, m_sizes, transpose=True):
+def torch_grouped_gemm(X, W, m_sizes, transpose = True):
"""
X: [M, K] if forward, else [M, N]
W: [E, N, K]
@@ -127,7 +127,7 @@ def torch_grouped_gemm(X, W, m_sizes, transpose=True):
N = W.shape[1]
- result = torch.zeros((M, N), dtype=X.dtype, device=X.device)
+ result = torch.zeros((M, N), dtype = X.dtype, device = X.device)
m_start = 0
for g in range(E):
diff --git a/unsloth/kernels/moe/tests/test_llama4_moe.py b/unsloth/kernels/moe/tests/test_llama4_moe.py
index 27d5d99a2..13ad552bf 100644
--- a/unsloth/kernels/moe/tests/test_llama4_moe.py
+++ b/unsloth/kernels/moe/tests/test_llama4_moe.py
@@ -37,7 +37,7 @@ NUM_AUTOTUNE_CONFIGS = 50
@contextmanager
-def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
+def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
print(char * num_chars)
print(prelude)
yield
@@ -81,7 +81,7 @@ def prep_triton_kernel_traits(autotune):
def sparse_to_dense(t: torch.Tensor):
- t = t.sum(dim=0).view(-1)
+ t = t.sum(dim = 0).view(-1)
return t
@@ -91,9 +91,9 @@ def _check_diff(
t2: torch.Tensor,
atol,
rtol,
- precision=".6f",
- verbose=False,
- msg="",
+ precision = ".6f",
+ verbose = False,
+ msg = "",
):
t2 = t2.view_as(t1)
diff = t1.sub(t2).abs().max().item()
@@ -101,7 +101,7 @@ def _check_diff(
if msg == "":
msg = "diff"
print(f"{msg}: {diff:{precision}}")
- assert torch.allclose(t1, t2, atol=atol, rtol=rtol)
+ assert torch.allclose(t1, t2, atol = atol, rtol = rtol)
def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):
@@ -115,19 +115,19 @@ def _check_grads(
m2: torch.nn.Module,
atol,
rtol,
- precision=".6f",
- verbose=False,
- msg="",
+ precision = ".6f",
+ verbose = False,
+ msg = "",
):
for name, param in m1.named_parameters():
_check_diff(
param.grad,
m2.get_parameter(name).grad,
- atol=atol,
- rtol=rtol,
- precision=precision,
- verbose=verbose,
- msg=f"{msg}:{name}.grad",
+ atol = atol,
+ rtol = rtol,
+ precision = precision,
+ verbose = verbose,
+ msg = f"{msg}:{name}.grad",
)
@@ -139,19 +139,19 @@ def model_config():
@pytest.mark.parametrize(
"overlap_router_shared",
[False, True],
- ids=lambda x: "overlap_router_shared" if x else "no_overlap",
+ ids = lambda x: "overlap_router_shared" if x else "no_overlap",
)
@pytest.mark.parametrize(
- "permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y"
+ "permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y"
)
@pytest.mark.parametrize(
- "permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x"
+ "permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x"
) # Llama4 does not support permute_x
@pytest.mark.parametrize(
- "autotune", [True], ids=lambda x: "autotune" if x else "manual"
+ "autotune", [True], ids = lambda x: "autotune" if x else "manual"
)
-@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
-@pytest.mark.parametrize("dtype", DTYPES, ids=str)
+@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
+@pytest.mark.parametrize("dtype", DTYPES, ids = str)
def test_llama4_ref(
dtype: torch.dtype,
seqlen,
@@ -161,9 +161,9 @@ def test_llama4_ref(
overlap_router_shared: bool,
model_config: Llama4TextConfig, # test fixture
bs: int = 1,
- device="cuda",
- precision=".6f",
- verbose=False,
+ device = "cuda",
+ precision = ".6f",
+ verbose = False,
):
torch.manual_seed(
SEED
@@ -172,24 +172,24 @@ def test_llama4_ref(
hidden_dim = model_config.hidden_size
atol, rtol = TOLERANCES[dtype]
check_diff = partial(
- _check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose
+ _check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose
)
check_grads = partial(
- _check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose
+ _check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose
)
# Reference op -- HF
- llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device)
+ llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)
# Torch grouped gemm impl
llama4_gg_ref = Llama4GroupedGemmTextMoe(
- model_config, overlap_router_shared=overlap_router_shared
- ).to(dtype=dtype, device=device)
+ model_config, overlap_router_shared = overlap_router_shared
+ ).to(dtype = dtype, device = device)
llama4_gg_ref.copy_weights(llama4_ref)
llama4_gg_ref.check_weights(llama4_ref)
x_ref = torch.randn(
- bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True
+ bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True
)
x_torch_gg = x_ref.detach().clone().requires_grad_()
x_triton = x_ref.detach().clone().requires_grad_()
@@ -198,9 +198,9 @@ def test_llama4_ref(
y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)
assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}"
with annotated_context("Testing torch grouped gemm Llama4TextMoe"):
- check_diff(y_ref, y_torch_gg, msg="y_torch_gg")
+ check_diff(y_ref, y_torch_gg, msg = "y_torch_gg")
check_diff(
- sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg"
+ sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg"
)
kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (
@@ -209,38 +209,38 @@ def test_llama4_ref(
llama4_triton = Llama4TritonTextMoe(
model_config,
- overlap_router_shared=overlap_router_shared,
- permute_x=permute_x,
- permute_y=permute_y,
- autotune=autotune,
- kernel_config_fwd=kernel_config_fwd,
- kernel_config_bwd_dW=kernel_config_bwd_dW,
- kernel_config_bwd_dX=kernel_config_bwd_dX,
- ).to(device=device, dtype=dtype)
+ overlap_router_shared = overlap_router_shared,
+ permute_x = permute_x,
+ permute_y = permute_y,
+ autotune = autotune,
+ kernel_config_fwd = kernel_config_fwd,
+ kernel_config_bwd_dW = kernel_config_bwd_dW,
+ kernel_config_bwd_dX = kernel_config_bwd_dX,
+ ).to(device = device, dtype = dtype)
llama4_triton.copy_weights(llama4_ref)
llama4_triton.check_weights(llama4_ref)
y_triton, routing_triton = llama4_triton(x_triton)
with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"):
- check_diff(y_ref, y_triton, msg="y_triton")
- check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton")
+ check_diff(y_ref, y_triton, msg = "y_triton")
+ check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton")
ref_grad = torch.randn_like(y_ref)
run_backwards(y_ref, ref_grad, llama4_ref)
run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)
with annotated_context("Testing torch group gemm Llama4TextMoe backward"):
- check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg")
+ check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg")
run_backwards(y_triton, ref_grad, llama4_triton)
with annotated_context("Testing triton group gemm Llama4TextMoe backward"):
- check_grads(llama4_ref, llama4_triton, msg="triton")
+ check_grads(llama4_ref, llama4_triton, msg = "triton")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--seqlen", type=int, default=1024)
+ parser.add_argument("--seqlen", type = int, default = 1024)
parser.add_argument(
- "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
+ "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
)
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
@@ -251,12 +251,12 @@ if __name__ == "__main__":
text_config: Llama4TextConfig = get_text_config(model_id)
for overlap in [False, True]:
test_llama4_ref(
- seqlen=args.seqlen,
- model_config=text_config,
- dtype=args.dtype,
- autotune=True,
- permute_x=False,
- permute_y=True,
- overlap_router_shared=overlap,
- verbose=True,
+ seqlen = args.seqlen,
+ model_config = text_config,
+ dtype = args.dtype,
+ autotune = True,
+ permute_x = False,
+ permute_y = True,
+ overlap_router_shared = overlap,
+ verbose = True,
)
diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py
index 359b9a41b..accbed11b 100644
--- a/unsloth/kernels/rms_layernorm.py
+++ b/unsloth/kernels/rms_layernorm.py
@@ -45,16 +45,16 @@ def _rms_layernorm_forward(
X += row_idx * X_row_stride
r += row_idx * r_row_stride
- X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask=mask, other=0) # .to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0) # .to(tl.float32)
- row_var = tl.sum(X_row * X_row, axis=0) / n_cols
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
- tl.store(Y + col_offsets, output, mask=mask)
+ tl.store(Y + col_offsets, output, mask = mask)
def _rms_layernorm_backward(
@@ -92,9 +92,9 @@ def _rms_layernorm_backward(
else:
dX = dY
- dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
- X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
# Get saved row variance
inv_var = tl.load(r).to(tl.float32)
@@ -105,9 +105,9 @@ def _rms_layernorm_backward(
else:
dY_W = dY_row * W_row
- rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
+ rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
- tl.store(dX + col_offsets, output, mask=mask)
+ tl.store(dX + col_offsets, output, mask = mask)
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
@@ -143,16 +143,16 @@ def _gemma_rms_layernorm_forward(
X += row_idx * X_row_stride
r += row_idx * r_row_stride
- X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
- row_var = tl.sum(X_row * X_row, axis=0) / n_cols
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
- tl.store(Y + col_offsets, output, mask=mask)
+ tl.store(Y + col_offsets, output, mask = mask)
class Fast_RMS_Layernorm(torch.autograd.Function):
@@ -169,8 +169,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device
- Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
- r = torch.empty(n_rows, dtype=torch.float32, device=device)
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
+ r = torch.empty(n_rows, dtype = torch.float32, device = device)
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
with torch_gpu_device(device):
@@ -185,8 +185,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
r.stride(0),
n_cols,
eps,
- BLOCK_SIZE=BLOCK_SIZE,
- num_warps=num_warps,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
@@ -222,9 +222,9 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
# dW, dW.stride(0),
n_cols,
ctx.eps,
- GEMMA=ctx.GEMMA,
- BLOCK_SIZE=ctx.BLOCK_SIZE,
- num_warps=ctx.num_warps,
+ GEMMA = ctx.GEMMA,
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
+ num_warps = ctx.num_warps,
)
dX = dX.view(*shape)
return dX, None, None, None
@@ -248,7 +248,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
def forward(self, X):
- return fast_rms_layernorm(self, X, gemma=False)
+ return fast_rms_layernorm(self, X, gemma = False)
try:
@@ -256,7 +256,7 @@ try:
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
def forward(self, X):
- return fast_rms_layernorm(self, X, gemma=False)
+ return fast_rms_layernorm(self, X, gemma = False)
except:
@@ -292,25 +292,25 @@ def unpatch_rms_layernorm():
def test_rms_layernorm(
- dim=1024,
- eps=1e-5,
- dtype=torch.float16,
- bsz=21,
- random_state=3407,
- seqlen=3341,
+ dim = 1024,
+ eps = 1e-5,
+ dtype = torch.float16,
+ bsz = 21,
+ random_state = 3407,
+ seqlen = 3341,
):
from transformers.models.llama.modeling_llama import LlamaRMSNorm
- layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda")
+ layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
torch.cuda.manual_seed(random_state)
torch.manual_seed(random_state)
torch.nn.init.uniform_(layernorm.weight)
- X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda")
+ X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
XX = X.clone()
X.requires_grad_(True)
XX.requires_grad_(True)
Y = layernorm(X)
- YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True)
+ YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
Y.backward(YY)
correct_grad = X.grad.clone()
# from unsloth.kernels import fast_rms_layernorm
@@ -322,14 +322,14 @@ def test_rms_layernorm(
def testing_suite_layernorm():
for dim in [512, 1024, 2048]:
for dtype in [torch.float16, torch.bfloat16]:
- with torch.autocast(device_type="cuda", dtype=dtype):
+ with torch.autocast(device_type = "cuda", dtype = dtype):
for seqlen in [3341, 2048, 349]:
for random_state in [3407, 42]:
test_rms_layernorm(
- dim=dim,
- eps=1e-5,
- dtype=dtype,
- bsz=21,
- random_state=random_state,
- seqlen=seqlen,
+ dim = dim,
+ eps = 1e-5,
+ dtype = dtype,
+ bsz = 21,
+ random_state = random_state,
+ seqlen = seqlen,
)
diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index 90e151373..8eecec10c 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -50,16 +50,16 @@ def _rope_embedding(
+ (row_position % seqlen) * sin_row_stride
+ half_head_dim * 0
+ col_offsets,
- mask=mask,
- other=0,
+ mask = mask,
+ other = 0,
)
cos1 = tl.load(
cos
+ (row_position % seqlen) * cos_row_stride
+ half_head_dim * 0
+ col_offsets,
- mask=mask,
- other=0,
+ mask = mask,
+ other = 0,
)
if BACKWARD_PASS:
@@ -78,11 +78,11 @@ def _rope_embedding(
)
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
- Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype)
- Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype)
+ Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
+ Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
- tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)
- tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)
+ tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)
+ tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)
_rope_embedding = triton.jit(_rope_embedding)
@@ -134,9 +134,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
seq_len,
head_dim,
n_heads,
- BACKWARD_PASS=False,
- BLOCK_SIZE=BLOCK_SIZE,
- num_warps=num_warps,
+ BACKWARD_PASS = False,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
@@ -177,9 +177,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
seq_len,
head_dim,
n_heads,
- BACKWARD_PASS=True,
- BLOCK_SIZE=ctx.BLOCK_SIZE,
- num_warps=ctx.num_warps,
+ BACKWARD_PASS = True,
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
+ num_warps = ctx.num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
return (
@@ -211,7 +211,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
# Q * cos + rotate_half(Q) * sin
half = Q.shape[-1] // 2
- RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim=-1)
+ RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
@@ -224,7 +224,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
cos, sin = ctx.saved_tensors
# Q * cos + rotate_half.T(Q) * sin
half = dY.shape[-1] // 2
- RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim=-1)
+ RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py
index 7547b9ca6..b321f5179 100644
--- a/unsloth/kernels/swiglu.py
+++ b/unsloth/kernels/swiglu.py
@@ -43,8 +43,8 @@ def _fg_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
- e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
- g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
@@ -53,13 +53,13 @@ def _fg_kernel(
h_row = f_row * g_row
# Store h
- tl.store(h + offsets, h_row, mask=mask)
+ tl.store(h + offsets, h_row, mask = mask)
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
- h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=e.device)
+ h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch_gpu_device(e.device):
_fg_kernel[grid](
@@ -67,8 +67,8 @@ def swiglu_fg_kernel(e, g):
g,
h,
n_elements,
- BLOCK_SIZE=BLOCK_SIZE,
- LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
+ BLOCK_SIZE = BLOCK_SIZE,
+ LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return h
@@ -101,9 +101,9 @@ def _DWf_DW_dfg_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
- DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
- e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
- g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
@@ -122,9 +122,9 @@ def _DWf_DW_dfg_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
- tl.store(DW + offsets, h_row, mask=mask) # h = f * g
- tl.store(e + offsets, df_row, mask=mask) # df = DW * f
- tl.store(g + offsets, de_row, mask=mask) # de
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
@@ -137,7 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
e,
g,
n_elements,
- BLOCK_SIZE=BLOCK_SIZE,
- LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
+ BLOCK_SIZE = BLOCK_SIZE,
+ LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index c3f743cc0..5dcc7c232 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -47,12 +47,12 @@ if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
if DEVICE_TYPE == "xpu":
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# tl.math.tanh now is libdevice.tanh
@@ -349,7 +349,7 @@ def _maybe_fake_quantize_activations(
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode
- def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
if isinstance(W, Float8Tensor):
return W.dequantize()
@@ -390,13 +390,13 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
- size, dtype=dtype, device=device, requires_grad=False
+ size, dtype = dtype, device = device, requires_grad = False
)
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
n_elements_absmax,
- dtype=torch.float32,
- device=device,
- requires_grad=False,
+ dtype = torch.float32,
+ device = device,
+ requires_grad = False,
)
if size > WEIGHT_BUFFER.numel():
@@ -409,16 +409,16 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
else:
if out is None:
out = torch_empty(
- shape, dtype=dtype, device=device, requires_grad=False
+ shape, dtype = dtype, device = device, requires_grad = False
)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
n_elements_absmax,
- dtype=torch_float32,
- device=device,
- requires_grad=False,
+ dtype = torch_float32,
+ device = device,
+ requires_grad = False,
)
# NF4 dequantization of statistics
@@ -458,7 +458,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
@torch.inference_mode
- def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
return W.dequantize()
if quant_state is None:
@@ -500,13 +500,13 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
- size, dtype=dtype, device=device, requires_grad=False
+ size, dtype = dtype, device = device, requires_grad = False
)
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
n_elements_absmax,
- dtype=torch_float32,
- device=device,
- requires_grad=False,
+ dtype = torch_float32,
+ device = device,
+ requires_grad = False,
)
if size > WEIGHT_BUFFER.numel():
@@ -519,16 +519,16 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
else:
if out is None:
out = torch_empty(
- shape, dtype=dtype, device=device, requires_grad=False
+ shape, dtype = dtype, device = device, requires_grad = False
)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
n_elements_absmax,
- dtype=torch_float32,
- device=device,
- requires_grad=False,
+ dtype = torch_float32,
+ device = device,
+ requires_grad = False,
)
pass
@@ -570,7 +570,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
else:
@torch.inference_mode
- def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
return W.dequantize()
if quant_state is None:
@@ -601,12 +601,12 @@ else:
# Create weight matrix
if out is None:
- out = torch_empty(shape, dtype=dtype, device=device, requires_grad=False)
+ out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
- n_elements_absmax, dtype=torch_float32, device=device, requires_grad=False
+ n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
)
# Do dequantization
@@ -645,9 +645,9 @@ else:
# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
- def fast_gemv(X, W, quant_state, out=None):
+ def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
- return torch_matmul(X, W, out=out)
+ return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@@ -686,8 +686,8 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
1,
bout,
),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@@ -710,7 +710,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
- df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
+ df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
with torch_gpu_device(device):
cdequantize_blockwise_fp32(
get_ptr(code2),
@@ -751,9 +751,9 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
- def fast_gemv(X, W, quant_state, out=None):
+ def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
- return torch_matmul(X, W, out=out)
+ return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@@ -793,8 +793,8 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
1,
bout,
),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@@ -813,7 +813,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
- df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
+ df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
with torch_gpu_device(device):
cdequantize_blockwise_fp32(
get_ptr(code2),
@@ -856,9 +856,9 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
pass
else:
- def fast_gemv(X, W, quant_state, out=None):
+ def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
- return torch_matmul(X, W, out=out)
+ return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@@ -894,8 +894,8 @@ else:
1,
bout,
),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@@ -914,7 +914,7 @@ else:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
- df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
+ df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
@@ -953,21 +953,21 @@ else:
pass
-def fast_linear_forward(proj, X, temp_lora=None, out=None):
+def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1:
return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
- out = torch_matmul(X, W.t(), out=out)
+ out = torch_matmul(X, W.t(), out = out)
elif W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant, bias)
elif bsz == 1 and q_len == 1:
- out = fast_gemv(X, W, W_quant, out=out)
+ out = fast_gemv(X, W, W_quant, out = out)
else:
- W = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
- out = torch_matmul(X, W, out=out)
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
+ out = torch_matmul(X, W, out = out)
# Add in LoRA weights
if lora_A is not None:
@@ -980,14 +980,14 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
if bsz == 1:
out = out.view(out_dim)
- temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out=temp_lora)
- out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)
+ temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
+ out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch_mm(
- X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora
+ X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
)
- out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)
+ out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
out = out.view(bsz, 1, out_dim)
if bias is not None:
@@ -996,7 +996,7 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
return out
-def matmul_lora(X, W, W_quant, A, B, s, out=None):
+def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
if X.dim() == 3:
@@ -1015,12 +1015,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
W = W.dequantize()
else:
W = W.contiguous()
- out = torch_matmul(X, W.t(), out=out)
+ out = torch_matmul(X, W.t(), out = out)
elif W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant)
else:
- W = fast_dequantize(W, W_quant, use_global_buffer=True)
- out = torch_matmul(X, W.t(), out=out)
+ W = fast_dequantize(W, W_quant, use_global_buffer = True)
+ out = torch_matmul(X, W.t(), out = out)
if W_quant is not None:
del W
@@ -1028,7 +1028,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
# LoRA is enabled
A, B = A.t(), B.t()
XA = torch_matmul(X, A.to(dtype))
- out.addmm_(XA, B.to(dtype), alpha=s)
+ out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index e0766b352..d3c4a8d73 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -150,23 +150,23 @@ for temporary_patch in TEMPORARY_PATCHES:
# =============================================
# Disable some warnings which can get annoying
-warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
-warnings.filterwarnings(action="ignore", category=FutureWarning, module="torch")
-warnings.filterwarnings(action="ignore", category=UserWarning, module="huggingface_hub")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(
- action="ignore", category=FutureWarning, module="huggingface_hub"
+ action = "ignore", category = FutureWarning, module = "huggingface_hub"
)
-warnings.filterwarnings(action="ignore", category=UserWarning, module="trl")
-warnings.filterwarnings(action="ignore", category=FutureWarning, module="trl")
-warnings.filterwarnings(action="ignore", category=FutureWarning, module="xformers")
-warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="subprocess")
-warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers")
-warnings.filterwarnings(action="ignore", category=FutureWarning, module="accelerate")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
+warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(
- action="ignore", category=RuntimeWarning, module="multiprocessing"
+ action = "ignore", category = RuntimeWarning, module = "multiprocessing"
)
-warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="multiprocess")
-warnings.filterwarnings(action="ignore", category=UserWarning, module="triton")
+warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
@@ -345,10 +345,10 @@ except:
# You passed `quantization_config` or equivalent parameters
try:
warnings.filterwarnings(
- action="ignore",
- message=r".*quantization_config.*",
- category=UserWarning,
- append=True,
+ action = "ignore",
+ message = r".*quantization_config.*",
+ category = UserWarning,
+ append = True,
)
except:
pass
@@ -357,10 +357,10 @@ except:
# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
try:
warnings.filterwarnings(
- action="ignore",
- message=r".*Logical operators 'and' and 'or'.*",
- category=UserWarning,
- append=True,
+ action = "ignore",
+ message = r".*Logical operators 'and' and 'or'.*",
+ category = UserWarning,
+ append = True,
)
except:
pass
@@ -465,7 +465,7 @@ def extract_quant_model_param_count(model):
return count
-def get_model_param_count(model, trainable_only=False):
+def get_model_param_count(model, trainable_only = False):
"""
Calculate model's total param count. If trainable_only is True then count only those requiring grads
"""
@@ -596,14 +596,14 @@ if DEVICE_TYPE in ("cuda", "hip"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
elif DEVICE_TYPE == "xpu":
if Version(torch_version) < Version("2.6.0"):
raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
else:
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# =============================================
# =============================================
@@ -934,9 +934,9 @@ import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
patch_torch_compile(
- debug=UNSLOTH_COMPILE_DEBUG,
- O3=UNSLOTH_COMPILE_MAXIMUM,
- ignore_errors=UNSLOTH_COMPILE_IGNORE_ERRORS,
+ debug = UNSLOTH_COMPILE_DEBUG,
+ O3 = UNSLOTH_COMPILE_MAXIMUM,
+ ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,
)
torch_compile_options = {
@@ -983,9 +983,9 @@ def patch_regional_compilation():
[
torch.compile(
x,
- dynamic=True,
- options=torch_compile_options,
- fullgraph=False,
+ dynamic = True,
+ options = torch_compile_options,
+ fullgraph = False,
)
for x in args[0]
]
@@ -1008,14 +1008,14 @@ def prepare_model_for_kbit_training(
use_reentrant: Optional[bool] = True,
) -> Any:
return prepare_model_for_training(
- model=model,
- use_gradient_checkpointing=use_gradient_checkpointing,
- use_reentrant=use_reentrant,
- full_finetuning=False,
- train_layernorms=False,
- train_embedding=False,
- train_lm_head=False,
- float32_mixed_precision=True,
+ model = model,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ use_reentrant = use_reentrant,
+ full_finetuning = False,
+ train_layernorms = False,
+ train_embedding = False,
+ train_lm_head = False,
+ float32_mixed_precision = True,
)
@@ -1033,7 +1033,7 @@ if Version(peft_version) < Version("0.12.0"):
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
- spaces = re.findall(r"^([ ]{1,})break", source, flags=re.MULTILINE)[0]
+ spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start:end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
@@ -1070,7 +1070,7 @@ import socket
@functools.lru_cache(1)
-def has_internet(host="8.8.8.8", port=53, timeout=3):
+def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
return False
try:
@@ -1084,12 +1084,12 @@ def has_internet(host="8.8.8.8", port=53, timeout=3):
import psutil
-def _get_statistics(statistics=None, force_download=True):
+def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
- n_cpus = psutil.cpu_count(logical=False)
+ n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
# Check modelscope for down detection
global USE_MODELSCOPE
@@ -1150,12 +1150,12 @@ def _get_statistics(statistics=None, force_download=True):
if has_internet():
def stats_check():
- with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as f:
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
snapshot_download(
f"unslothai/{statistics}",
- force_download=True,
- cache_dir=f,
- local_dir=f,
+ force_download = True,
+ cache_dir = f,
+ local_dir = f,
)
time_limited_stats_check = execute_with_time_limit(120)(stats_check)
@@ -1178,7 +1178,7 @@ def _get_statistics(statistics=None, force_download=True):
stats_check()
-def get_statistics(local_files_only=False):
+def get_statistics(local_files_only = False):
# We log some basic stats about which environment is being used.
# This is also to check if HuggingFace is down or not!
# We simply download a README.md file from HF - all data is made public.
@@ -1201,7 +1201,7 @@ def get_statistics(local_files_only=False):
disable_progress_bars()
disabled = True
_get_statistics(None)
- _get_statistics("repeat", force_download=False)
+ _get_statistics("repeat", force_download = False)
total_memory = (
torch.xpu.get_device_properties(0).total_memory
if DEVICE_TYPE == "xpu"
@@ -1242,7 +1242,7 @@ BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
- flags=re.MULTILINE,
+ flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
@@ -1322,12 +1322,12 @@ def offload_to_disk(
torch.save(
W,
filename,
- pickle_module=pickle,
- pickle_protocol=pickle.HIGHEST_PROTOCOL,
+ pickle_module = pickle,
+ pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
# We must use weights_only = False due to pickling
offloaded_W = torch.load(
- filename, map_location="cpu", mmap=True, weights_only=False
+ filename, map_location = "cpu", mmap = True, weights_only = False
)
offloaded_W._offloaded_file_location = filename
return offloaded_W
@@ -1352,7 +1352,7 @@ def offload_output_embeddings(
model.get_output_embeddings(), model, "output_embeddings", temporary_location
)
- new_output_embeddings = torch.nn.Linear(1, 1, bias=None)
+ new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
@@ -1376,10 +1376,10 @@ def is_vLLM_available():
# Patches models to add RoPE Scaling
def patch_linear_scaling(
- model_name="gemma2",
- rope_module=None,
- scaled_rope_module=None,
- attention_module=None,
+ model_name = "gemma2",
+ rope_module = None,
+ scaled_rope_module = None,
+ attention_module = None,
):
assert rope_module is not None and scaled_rope_module is not None
assert attention_module is not None
@@ -1430,13 +1430,13 @@ def patch_linear_scaling(
pass
"""
fix_rope_function = fix_rope_function.format(
- rope_function=rope_module.__name__,
- scaled_rope_function=scaled_rope_module.__name__,
+ rope_function = rope_module.__name__,
+ scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
- flags=re.DOTALL | re.MULTILINE,
+ flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, exec_code + "\n\n" + function
@@ -1449,12 +1449,12 @@ def patch_linear_scaling(
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
- model_name="llama",
- rope_module=None,
- scaled_rope_module=None,
- extended_rope_module=None,
- attention_module=None,
- longrope_module=None,
+ model_name = "llama",
+ rope_module = None,
+ scaled_rope_module = None,
+ extended_rope_module = None,
+ attention_module = None,
+ longrope_module = None,
):
assert (
rope_module is not None
@@ -1528,17 +1528,17 @@ def patch_llama_rope_scaling(
"""
fix_rope_function = fix_rope_function.format(
- rope_function=rope_module.__name__,
- scaled_rope_function=scaled_rope_module.__name__,
- extended_rope_function=extended_rope_module.__name__,
- longrope_rope_function=(
+ rope_function = rope_module.__name__,
+ scaled_rope_function = scaled_rope_module.__name__,
+ extended_rope_function = extended_rope_module.__name__,
+ longrope_rope_function = (
longrope_module if longrope_module is not None else rope_module
).__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
- flags=re.DOTALL | re.MULTILINE,
+ flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, function
@@ -1548,15 +1548,15 @@ def patch_llama_rope_scaling(
return init_name, function
-def create_boolean_mask(n=4096, sliding_window=2048):
+def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
- mask = torch.ones(n, n, dtype=torch.bool)
+ mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
- return torch.triu(mask, diagonal=1, out=mask)
- torch.triu(mask, diagonal=0, out=mask)
- torch.triu(mask.T, diagonal=-sliding_window, out=mask.T)
+ return torch.triu(mask, diagonal = 1, out = mask)
+ torch.triu(mask, diagonal = 0, out = mask)
+ torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
- torch.logical_not(mask, out=mask)
+ torch.logical_not(mask, out = mask)
return mask
@@ -1567,37 +1567,37 @@ def test_mask_creation():
for s in range(1, 23):
correct_mask = (
AttentionMaskConverter(
- is_causal=True,
- sliding_window=s,
+ is_causal = True,
+ sliding_window = s,
)
.to_causal_4d(
1,
n,
n,
- dtype=torch.float16,
+ dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
- our_mask = create_boolean_mask(n=n, sliding_window=s)
+ our_mask = create_boolean_mask(n = n, sliding_window = s)
assert torch.all(correct_mask == our_mask)
correct_mask = (
AttentionMaskConverter(
- is_causal=True,
- sliding_window=None,
+ is_causal = True,
+ sliding_window = None,
)
.to_causal_4d(
1,
n,
n,
- dtype=torch.float16,
+ dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
- our_mask = create_boolean_mask(n=n, sliding_window=0)
+ our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert torch.all(correct_mask == our_mask)
@@ -1812,34 +1812,34 @@ def unsloth_compile_transformers(
dtype,
model_name,
model_types,
- token=None,
- revision=None,
- trust_remote_code=False,
- sdpa_dynamic_mask=True,
- sdpa_bool_masks=True,
- sdpa_gqa_replace=True,
- sdpa_dynamic_compile=True,
- compile_attention=True,
- disable_causal_masks=True,
- compile_torch_modules=True,
- compile_custom_modules=True,
- compile_function_calls=True,
- fuse_lm_head=True,
- gradient_checkpointing=True,
- manual_replacements=True,
- fast_lora_forwards=True,
- fast_residual_stream=True,
- accurate_accumulation=True,
- epilogue_fusion=True,
- max_autotune=False,
- shape_padding=True,
- cudagraphs=False,
- debug=False,
- fullgraph=True,
- import_from_cache=False,
- disable=False,
- return_logits=False,
- unsloth_force_compile=False,
+ token = None,
+ revision = None,
+ trust_remote_code = False,
+ sdpa_dynamic_mask = True,
+ sdpa_bool_masks = True,
+ sdpa_gqa_replace = True,
+ sdpa_dynamic_compile = True,
+ compile_attention = True,
+ disable_causal_masks = True,
+ compile_torch_modules = True,
+ compile_custom_modules = True,
+ compile_function_calls = True,
+ fuse_lm_head = True,
+ gradient_checkpointing = True,
+ manual_replacements = True,
+ fast_lora_forwards = True,
+ fast_residual_stream = True,
+ accurate_accumulation = True,
+ epilogue_fusion = True,
+ max_autotune = False,
+ shape_padding = True,
+ cudagraphs = False,
+ debug = False,
+ fullgraph = True,
+ import_from_cache = False,
+ disable = False,
+ return_logits = False,
+ unsloth_force_compile = False,
):
if Version(torch_version) < Version("2.4.0"):
print(
@@ -1864,31 +1864,31 @@ def unsloth_compile_transformers(
for model_type in model_types:
_unsloth_compile_transformers(
model_type,
- sdpa_dynamic_mask=sdpa_dynamic_mask,
- sdpa_bool_masks=sdpa_bool_masks,
- sdpa_gqa_replace=sdpa_gqa_replace,
- sdpa_dynamic_compile=sdpa_dynamic_compile,
- compile_attention=compile_attention,
- disable_causal_masks=disable_causal_masks,
- compile_torch_modules=compile_torch_modules,
- compile_custom_modules=compile_custom_modules,
- compile_function_calls=compile_function_calls,
- fuse_lm_head=fuse_lm_head,
- gradient_checkpointing=gradient_checkpointing,
- manual_replacements=manual_replacements,
- fast_lora_forwards=fast_lora_forwards,
- fast_residual_stream=fast_residual_stream,
- accurate_accumulation=accurate_accumulation,
- epilogue_fusion=epilogue_fusion,
- max_autotune=max_autotune,
- shape_padding=shape_padding,
- cudagraphs=cudagraphs,
- debug=debug,
- fullgraph=fullgraph,
- import_from_cache=import_from_cache,
- disable=disable,
- return_logits=return_logits,
- supports_sdpa=supports_sdpa,
+ sdpa_dynamic_mask = sdpa_dynamic_mask,
+ sdpa_bool_masks = sdpa_bool_masks,
+ sdpa_gqa_replace = sdpa_gqa_replace,
+ sdpa_dynamic_compile = sdpa_dynamic_compile,
+ compile_attention = compile_attention,
+ disable_causal_masks = disable_causal_masks,
+ compile_torch_modules = compile_torch_modules,
+ compile_custom_modules = compile_custom_modules,
+ compile_function_calls = compile_function_calls,
+ fuse_lm_head = fuse_lm_head,
+ gradient_checkpointing = gradient_checkpointing,
+ manual_replacements = manual_replacements,
+ fast_lora_forwards = fast_lora_forwards,
+ fast_residual_stream = fast_residual_stream,
+ accurate_accumulation = accurate_accumulation,
+ epilogue_fusion = epilogue_fusion,
+ max_autotune = max_autotune,
+ shape_padding = shape_padding,
+ cudagraphs = cudagraphs,
+ debug = debug,
+ fullgraph = fullgraph,
+ import_from_cache = import_from_cache,
+ disable = disable,
+ return_logits = return_logits,
+ supports_sdpa = supports_sdpa,
)
# Redo patches which override compiler
for temporary_patch in TEMPORARY_PATCHES:
@@ -1993,7 +1993,7 @@ def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, m
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
- loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
+ loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
if hasattr(model.config, "quantization_config"):
raise ValueError(
@@ -2069,9 +2069,9 @@ class TorchAOConfig:
base_config_and_filter_fns: List[
Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
] = field(
- default_factory=lambda: [
+ default_factory = lambda: [
(
- Int4WeightOnlyConfig(group_size=128),
+ Int4WeightOnlyConfig(group_size = 128),
lambda m, _: isinstance(m, torch.nn.Linear)
and getattr(m, "in_features", 0) >= 128,
),
@@ -2143,7 +2143,7 @@ def _convert_torchao_model(model):
module_to_fqn_dict = {}
for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
- quantize_(model, QATConfig(base_config, step="convert"), filter_fn=filter_fn)
+ quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
# Default filter function used for quantize_
if filter_fn is None:
@@ -2166,7 +2166,7 @@ def _convert_torchao_model(model):
kwargs["modules_to_not_convert"] = []
quant_config = ModuleFqnToConfig(module_to_fqn_dict)
- quantization_config = TorchAoConfig(quant_type=quant_config, **kwargs)
+ quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
model.config.quantization_config = quantization_config
@@ -2199,17 +2199,17 @@ def _prepare_model_for_qat(
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
- qat_scheme=qat_scheme,
- base_config_and_filter_fns=[(base_config, filter_fn)],
+ qat_scheme = qat_scheme,
+ base_config_and_filter_fns = [(base_config, filter_fn)],
)
elif qat_scheme == "fp8-fp8":
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
base_config = Float8DynamicActivationFloat8WeightConfig(
- granularity=PerRow()
+ granularity = PerRow()
)
torchao_config = TorchAOConfig(
- qat_scheme=qat_scheme, base_config_and_filter_fns=[(base_config, None)]
+ qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
)
elif qat_scheme == "int8-int4":
from torchao.quantization import (
@@ -2218,35 +2218,35 @@ def _prepare_model_for_qat(
)
torchao_config = TorchAOConfig(
- qat_scheme=qat_scheme,
- base_config_and_filter_fns=[
+ qat_scheme = qat_scheme,
+ base_config_and_filter_fns = [
(
IntxWeightOnlyConfig(
- weight_dtype=torch.int8, granularity=PerAxis(0)
+ weight_dtype = torch.int8, granularity = PerAxis(0)
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
),
(
Int8DynamicActivationIntxWeightConfig(
- weight_dtype=torch.int4, weight_granularity=PerGroup(32)
+ weight_dtype = torch.int4, weight_granularity = PerGroup(32)
),
None,
),
],
- prequantization_transform=_untie_input_output_embeddings,
+ prequantization_transform = _untie_input_output_embeddings,
)
elif qat_scheme == "int4":
from torchao.quantization import Int4WeightOnlyConfig
group_size = 128
- base_config = Int4WeightOnlyConfig(group_size=group_size)
+ base_config = Int4WeightOnlyConfig(group_size = group_size)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
- qat_scheme=qat_scheme,
- base_config_and_filter_fns=[(base_config, filter_fn)],
+ qat_scheme = qat_scheme,
+ base_config_and_filter_fns = [(base_config, filter_fn)],
)
else:
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
@@ -2264,7 +2264,7 @@ def _prepare_model_for_qat(
if torchao_config.prequantization_transform is not None:
torchao_config.prequantization_transform(model)
for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
- quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=filter_fn)
+ quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
return model
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 9f2d3cfc0..234fc2100 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -113,8 +113,8 @@ def Gemma2Attention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim=2)
- V = torch.cat([past_key_value[1], V], dim=2)
+ K = torch.cat([past_key_value[0], K], dim = 2)
+ V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Only enable if the attention_mask is True
@@ -139,10 +139,10 @@ def Gemma2Attention_fast_forward(
Q,
K,
V,
- causal=True,
- softcap=self.config.attn_logit_softcapping,
- softmax_scale=self._flash_attention_softmax_scale,
- window_size=window,
+ causal = True,
+ softcap = self.config.attn_logit_softcapping,
+ softmax_scale = self._flash_attention_softmax_scale,
+ window_size = window,
)
A = A.reshape(bsz, q_len, n_heads * head_dim)
else:
@@ -174,7 +174,7 @@ def Gemma2DecoderLayer_fast_forward(
self, "_flag_for_generation"
): # past_key_value is not None:
out_weight = torch.empty(
- self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
+ self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
# Self Attention
@@ -183,15 +183,15 @@ def Gemma2DecoderLayer_fast_forward(
self.input_layernorm, hidden_states, out_weight
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- _flag_for_generation=self._flag_for_generation,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ _flag_for_generation = self._flag_for_generation,
)
hidden_states = fast_rms_layernorm_inference_gemma(
self.post_attention_layernorm, hidden_states, out_weight
@@ -211,31 +211,31 @@ def Gemma2DecoderLayer_fast_forward(
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(
- self.input_layernorm, hidden_states, gemma=True
+ self.input_layernorm, hidden_states, gemma = True
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
)
hidden_states = fast_rms_layernorm(
- self.post_attention_layernorm, hidden_states, gemma=True
+ self.post_attention_layernorm, hidden_states, gemma = True
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(
- self.pre_feedforward_layernorm, hidden_states, gemma=True
+ self.pre_feedforward_layernorm, hidden_states, gemma = True
)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm(
- self.post_feedforward_layernorm, hidden_states, gemma=True
+ self.post_feedforward_layernorm, hidden_states, gemma = True
)
hidden_states = residual + hidden_states
@@ -260,9 +260,9 @@ def Gemma2Attention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
- do_prefill=False,
- attention_mask=None,
- use_sliding_window=False,
+ do_prefill = False,
+ attention_mask = None,
+ use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
@@ -286,24 +286,24 @@ def Gemma2Attention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype=dtype, device=device
+ (2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
+ (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Only for Gemma2
- self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
+ (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
@@ -331,9 +331,9 @@ def Gemma2Attention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
- Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
- Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@@ -348,7 +348,7 @@ def Gemma2Attention_fast_forward_inference(
RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
+ torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
@@ -357,7 +357,7 @@ def Gemma2Attention_fast_forward_inference(
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
+ torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
@@ -400,21 +400,21 @@ def Gemma2Attention_fast_forward_inference(
self.scalar
) # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A *= self.reciprocal_t
- torch_tanh(A, out=A)
+ torch_tanh(A, out = A)
A *= self.t # Logit softcapping
- A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
- A = torch_matmul(A, Vnn, out=Qn)
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
+ A = torch_matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@@ -425,13 +425,13 @@ def Gemma2Model_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
- attention_mask=None,
+ attention_mask = None,
):
out_weights = tuple(
torch.empty_like(
self.model.layers[0].input_layernorm.weight,
- dtype=torch.float32,
- device=torch.device(x),
+ dtype = torch.float32,
+ device = torch.device(x),
)
for x in range(DEVICE_COUNT)
)
@@ -441,7 +441,7 @@ def Gemma2Model_fast_forward_inference(
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(
- math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype
+ math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
)
bsz, q_len, hd = hidden_states.shape
@@ -456,7 +456,7 @@ def Gemma2Model_fast_forward_inference(
(bsz, q_len),
hidden_states,
seq_len,
- sliding_window=self.config.sliding_window,
+ sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
@@ -484,12 +484,12 @@ def Gemma2Model_fast_forward_inference(
)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
- hidden_states=hidden_states,
- past_key_value=past_key_values[idx],
- position_ids=position_ids,
- attention_mask=SWA if use_sliding_window else GA,
- do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
- use_sliding_window=use_sliding_window,
+ hidden_states = hidden_states,
+ past_key_value = past_key_values[idx],
+ position_ids = position_ids,
+ attention_mask = SWA if use_sliding_window else GA,
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+ use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(
decoder_layer.post_attention_layernorm,
@@ -518,10 +518,10 @@ def Gemma2Model_fast_forward_inference(
)
return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_decoder_cache,
- hidden_states=[],
- attentions=[],
+ last_hidden_state = hidden_states,
+ past_key_values = next_decoder_cache,
+ hidden_states = [],
+ attentions = [],
)
@@ -529,10 +529,10 @@ class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name="gemma2",
- rope_module=GemmaFixedRotaryEmbedding,
- scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding,
- attention_module=Gemma2Attention,
+ model_name = "gemma2",
+ rope_module = GemmaFixedRotaryEmbedding,
+ scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
+ attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
@@ -564,7 +564,7 @@ class FastGemma2Model(FastLlamaModel):
def post_patch(model, tokenizer):
# Gemma does not downcast RoPE
model, tokenizer = patch_model_and_tokenizer(
- model, tokenizer, downcast_rope=False
+ model, tokenizer, downcast_rope = False
)
# Add 1 to weight
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index 7fd6bbd6f..0a816e399 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -109,8 +109,8 @@ def GraniteAttention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim=2)
- V = torch.cat([past_key_value[1], V], dim=2)
+ K = torch.cat([past_key_value[0], K], dim = 2)
+ V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@@ -135,7 +135,7 @@ def GraniteAttention_fast_forward(
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
A = xformers_attention(
- Q, K, V, attn_bias=causal_mask, scale=self.scaling, p=dropout_p
+ Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p
)
A = A.view(bsz, q_len, n_heads, head_dim)
@@ -148,10 +148,10 @@ def GraniteAttention_fast_forward(
Q,
K,
V,
- causal=True,
- window_size=window,
- softmax_scale=self.scaling,
- dropout_p=dropout_p,
+ causal = True,
+ window_size = window,
+ softmax_scale = self.scaling,
+ dropout_p = dropout_p,
)
else:
# Grouped query attention
@@ -170,10 +170,10 @@ def GraniteAttention_fast_forward(
Q,
K,
V,
- attn_mask=attention_mask,
- scale=self.scaling,
- is_causal=False,
- dropout_p=dropout_p,
+ attn_mask = attention_mask,
+ scale = self.scaling,
+ is_causal = False,
+ dropout_p = dropout_p,
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@@ -212,18 +212,18 @@ def GraniteDecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
- _flag_for_generation=self._flag_for_generation,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
+ _flag_for_generation = self._flag_for_generation,
)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
@@ -231,28 +231,28 @@ def GraniteDecoderLayer_fast_forward(
self.post_attention_layernorm, hidden_states
)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
outputs = (hidden_states,)
if output_attentions:
@@ -275,9 +275,9 @@ def GraniteAttention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
- do_prefill=False,
- attention_mask=None,
- use_sliding_window=False,
+ do_prefill = False,
+ attention_mask = None,
+ use_sliding_window = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
assert (
@@ -306,24 +306,24 @@ def GraniteAttention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype=dtype, device=device
+ (2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
+ (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Only for Gemma2
- self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
+ (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.half_head_dim = head_dim // 2
@@ -343,9 +343,9 @@ def GraniteAttention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
- Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
- Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@@ -359,7 +359,7 @@ def GraniteAttention_fast_forward_inference(
RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
+ torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
@@ -368,7 +368,7 @@ def GraniteAttention_fast_forward_inference(
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
+ torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
@@ -396,18 +396,18 @@ def GraniteAttention_fast_forward_inference(
# pass
Qn *= self.scaling
- A = torch_matmul(Qn, Kn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
+ A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
- A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
- A = torch_matmul(A, Vn, out=Qn)
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
+ A = torch_matmul(A, Vn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@@ -418,7 +418,7 @@ def GraniteModel_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
- attention_mask=None,
+ attention_mask = None,
):
input_ids = input_ids[:, : self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
@@ -459,37 +459,37 @@ def GraniteModel_fast_forward_inference(
)
hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
decoder_layer.self_attn,
- hidden_states=hidden_states,
- past_key_value=past_key_values[idx],
- position_ids=position_ids,
- attention_mask=attention_mask,
- do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
- position_embeddings=position_embeddings,
+ hidden_states = hidden_states,
+ past_key_value = past_key_values[idx],
+ position_ids = position_ids,
+ attention_mask = attention_mask,
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+ position_embeddings = position_embeddings,
)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(
decoder_layer.post_attention_layernorm, hidden_states
)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
next_decoder_cache.append(present_key_value)
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_decoder_cache,
- hidden_states=[],
- attentions=[],
+ last_hidden_state = hidden_states,
+ past_key_values = next_decoder_cache,
+ hidden_states = [],
+ attentions = [],
)
class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, config):
- super().__init__(config=config)
+ super().__init__(config = config)
def patched_init(original_init):
@@ -510,10 +510,10 @@ class FastGraniteModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name="granite",
- rope_module=GraniteRotaryEmbedding,
- scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
- attention_module=GraniteAttention,
+ model_name = "granite",
+ rope_module = GraniteRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ attention_module = GraniteAttention,
)
if init_name is not None:
exec(function, globals())
@@ -548,7 +548,7 @@ class FastGraniteModel(FastLlamaModel):
model.config.update({"unsloth_version": __version__})
# We also do this for the lm_head
- lm_head = torch.nn.Linear(1, 1, bias=None)
+ lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
@@ -560,7 +560,7 @@ class FastGraniteModel(FastLlamaModel):
model.model.embed_tokens.weight.data_ptr()
!= model.lm_head.weight.data_ptr()
):
- lm_head = torch.nn.Linear(1, 1, bias=None)
+ lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index b280c4978..1fe4fbffb 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -143,7 +143,7 @@ SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__
def _fast_prepare_inputs_for_generation(
self,
input_ids,
- attention_mask=None,
+ attention_mask = None,
**kwargs,
):
past_key_values = kwargs.get("past_key_values", None)
@@ -185,7 +185,7 @@ def _fast_prepare_inputs_for_generation(
"target_length": cache_length,
"dtype": self.dtype,
"cache_position": torch.arange(
- cache_length, cache_length + 1, device=input_ids.device
+ cache_length, cache_length + 1, device = input_ids.device
),
"batch_size": bs,
"config": self.config,
@@ -240,8 +240,8 @@ def LlamaAttention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
- do_prefill=False,
- attention_mask=None,
+ do_prefill = False,
+ attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
@@ -293,29 +293,29 @@ def LlamaAttention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype=dtype, device=device
+ (2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
+ (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
- self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
+ (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -335,9 +335,9 @@ def LlamaAttention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
- Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
- Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@@ -414,30 +414,30 @@ def LlamaAttention_fast_forward_inference(
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(
- Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
+ Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
)
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(
- A, dim=-1, dtype=torch.float32
+ A, dim = -1, dtype = torch.float32
) # .to(A.dtype)
- A = torch_matmul(A, Vnn, out=Qn)
+ A = torch_matmul(A, Vnn, out = Qn)
else:
if SDPA_HAS_GQA:
A = scaled_dot_product_attention(
Qn,
Knn,
Vnn,
- attn_mask=attention_mask,
- is_causal=is_causal,
- enable_gqa=True,
+ attn_mask = attention_mask,
+ is_causal = is_causal,
+ enable_gqa = True,
)
else:
A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal
+ Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@@ -445,7 +445,7 @@ torch_nn_functional_silu = torch.nn.functional.silu
def fast_swiglu_inference(
- self, X, temp_gate=None, temp_up=None, gate_multiplier=None, down_multiplier=None
+ self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None
):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
@@ -453,18 +453,18 @@ def fast_swiglu_inference(
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
- gate = fast_linear_forward(self.gate_proj, X, out=temp_gate)
+ gate = fast_linear_forward(self.gate_proj, X, out = temp_gate)
if gate_multiplier is not None:
gate *= gate_multiplier
- up = fast_linear_forward(self.up_proj, X, out=temp_up)
+ up = fast_linear_forward(self.up_proj, X, out = temp_up)
- gate = torch_nn_functional_silu(gate, inplace=True)
+ gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
- down = fast_linear_forward(self.down_proj, gate, out=up[:, :, :hd])
+ down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
if down_multiplier is not None:
down *= down_multiplier
@@ -476,14 +476,14 @@ torch_square = torch.square
torch_mean = torch.mean
-def fast_rms_layernorm_inference(self, X, XX=None, XX2=None, variance=None):
+def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None):
old_dtype = X.dtype
if XX is None:
XX = X.to(torch.float32)
- variance = XX.square().mean(-1, keepdim=True)
+ variance = XX.square().mean(-1, keepdim = True)
else:
XX.copy_(X)
- torch_mean(torch_square(XX, out=XX2), -1, keepdim=True, out=variance)
+ torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
@@ -496,9 +496,9 @@ def fast_rms_layernorm_inference(self, X, XX=None, XX2=None, variance=None):
return X
-def fast_rms_layernorm_inference_gemma(self, X, out_weight=None):
+def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
XX = X.to(torch.float32)
- variance = XX.square().mean(-1, keepdim=True)
+ variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
@@ -513,15 +513,15 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight=None):
# Normal layernorm with mean removal
-@torch.compile(fullgraph=False, dynamic=True, options=torch_compile_options)
+@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def fast_layernorm_compiled(layernorm, X):
old_dtype = X.dtype
X = X.float()
- mean = X.mean(-1, keepdim=True)
+ mean = X.mean(-1, keepdim = True)
Xbar = X - mean
X = (
Xbar
- * torch.rsqrt(Xbar.square().mean(-1, keepdim=True) + layernorm.variance_epsilon)
+ * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
* layernorm.weight.float()
)
return X.to(old_dtype)
@@ -574,7 +574,7 @@ def LlamaAttention_fast_forward(
else:
# Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
- rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
+ rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
# if position_ids is None:
# # Useful for LongRoPE
@@ -591,8 +591,8 @@ def LlamaAttention_fast_forward(
Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim=2)
- V = torch.cat([past_key_value[1], V], dim=2)
+ K = torch.cat([past_key_value[0], K], dim = 2)
+ V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@@ -614,14 +614,14 @@ def LlamaAttention_fast_forward(
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
- A = xformers_attention(Q, K, V, attn_bias=causal_mask)
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
- A = flash_attn_func(Q, K, V, causal=True)
+ A = flash_attn_func(Q, K, V, causal = True)
else:
# when qlen==vlen and attn_mask is None, we should use causal attention
Q_len = Q.shape[-2]
@@ -638,9 +638,9 @@ def LlamaAttention_fast_forward(
Q,
K,
V,
- attn_mask=attention_mask,
- is_causal=is_causal,
- enable_gqa=n_groups != 1,
+ attn_mask = attention_mask,
+ is_causal = is_causal,
+ enable_gqa = n_groups != 1,
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2) # .contiguous()
@@ -661,7 +661,7 @@ def LlamaAttention_fast_forward(
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_casual and attention_mask must not be both set!
A = scaled_dot_product_attention(
- Q, K, V, attn_mask=attention_mask, is_causal=is_causal
+ Q, K, V, attn_mask = attention_mask, is_causal = is_causal
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@@ -676,7 +676,7 @@ def LlamaAttention_fast_forward(
def LlamaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
- causal_mask=None,
+ causal_mask = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -706,15 +706,15 @@ def LlamaDecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
hidden_states += residual
@@ -729,15 +729,15 @@ def LlamaDecoderLayer_fast_forward(
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
hidden_states = residual + hidden_states
@@ -839,8 +839,8 @@ def LlamaModel_fast_forward(
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
- dtype=torch.int32,
- device=f"{DEVICE_TYPE_TORCH}:0",
+ dtype = torch.int32,
+ device = f"{DEVICE_TYPE_TORCH}:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
@@ -873,7 +873,7 @@ def LlamaModel_fast_forward(
# Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
normalizer = torch.tensor(
- math_sqrt(self.config.hidden_size), dtype=inputs_embeds.dtype
+ math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype
)
if train_embed_tokens:
@@ -930,7 +930,7 @@ def LlamaModel_fast_forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
- sliding_window=getattr(self.config, "sliding_window", None),
+ sliding_window = getattr(self.config, "sliding_window", None),
)
# Must NOT convert to bool - weirdly this causes stuff to error out!
# if attention_mask is not None:
@@ -985,14 +985,14 @@ def LlamaModel_fast_forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
- sliding_window=self.config.sliding_window,
+ sliding_window = self.config.sliding_window,
)
dynamic_GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
- sliding_window=None,
+ sliding_window = None,
)
use_static_mask = False
@@ -1012,15 +1012,15 @@ def LlamaModel_fast_forward(
self.SWA_mask = (
AttentionMaskConverter(
- is_causal=True,
- sliding_window=self.config.sliding_window,
+ is_causal = True,
+ sliding_window = self.config.sliding_window,
)
.to_causal_4d(
1,
n,
n,
- dtype=inputs_embeds.dtype,
- device=DEVICE_TYPE_TORCH,
+ dtype = inputs_embeds.dtype,
+ device = DEVICE_TYPE_TORCH,
)
.squeeze(0)
.squeeze(0)
@@ -1028,14 +1028,14 @@ def LlamaModel_fast_forward(
self.GA_mask = (
AttentionMaskConverter(
- is_causal=True,
+ is_causal = True,
)
.to_causal_4d(
1,
n,
n,
- dtype=inputs_embeds.dtype,
- device=DEVICE_TYPE_TORCH,
+ dtype = inputs_embeds.dtype,
+ device = DEVICE_TYPE_TORCH,
)
.squeeze(0)
.squeeze(0)
@@ -1087,8 +1087,8 @@ def LlamaModel_fast_forward(
*inputs,
past_key_value,
output_attentions,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
return custom_forward
@@ -1099,22 +1099,22 @@ def LlamaModel_fast_forward(
mask,
attention_mask,
position_ids,
- use_reentrant=True,
- preserve_rng_state=False,
+ use_reentrant = True,
+ preserve_rng_state = False,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = decoder_layer(
hidden_states,
- causal_mask=mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ causal_mask = mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
hidden_states = layer_outputs[0]
@@ -1139,10 +1139,10 @@ def LlamaModel_fast_forward(
hidden_states = self.norm(hidden_states)
elif IS_FALCON_H1:
hidden_states = fast_rms_layernorm(
- self.final_layernorm, hidden_states, gemma=IS_GEMMA
+ self.final_layernorm, hidden_states, gemma = IS_GEMMA
)
else:
- hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma=IS_GEMMA)
+ hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -1155,17 +1155,17 @@ def LlamaModel_fast_forward(
if v is not None
)
return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
+ last_hidden_state = hidden_states,
+ past_key_values = next_cache,
+ hidden_states = all_hidden_states,
+ attentions = all_self_attns,
)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def _LlamaModel_fast_forward_inference(
- attention_fast_forward_inference=LlamaAttention_fast_forward_inference,
- mlp_fast_forward_inference=fast_swiglu_inference,
+ attention_fast_forward_inference = LlamaAttention_fast_forward_inference,
+ mlp_fast_forward_inference = fast_swiglu_inference,
):
# This makes the attention and MLP customisable.
# Now for models like qwen3 or cohere which use custom attention operations, we can use this function
@@ -1174,7 +1174,7 @@ def _LlamaModel_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
- attention_mask=None,
+ attention_mask = None,
):
input_ids = input_ids[:, : self.max_seq_length]
bsz, q_len = input_ids.shape
@@ -1187,17 +1187,17 @@ def _LlamaModel_fast_forward_inference(
assert q_len == 1
# Get saved buffers to reduce memory movement
residual = torch.empty(
- (bsz, q_len, hd), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0"
+ (bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
)
_XX = torch.empty(
- (2, bsz, q_len, hd), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0"
+ (2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
)
XX, XX2 = _XX[0], _XX[1]
variance = torch.empty(
- (bsz, q_len, 1), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0"
+ (bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
)
temp_mlp = torch.empty(
- (2, bsz, 1, mlp_size), dtype=X.dtype, device=f"{DEVICE_TYPE_TORCH}:0"
+ (2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0"
)
temp_gates, temp_ups = (
tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)),
@@ -1211,7 +1211,7 @@ def _LlamaModel_fast_forward_inference(
(bsz, q_len),
X,
seq_len,
- sliding_window=getattr(self.config, "sliding_window", None),
+ sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
@@ -1227,17 +1227,17 @@ def _LlamaModel_fast_forward_inference(
X = fast_rms_layernorm_inference(
decoder_layer.input_layernorm,
X,
- XX=XX,
- XX2=XX2,
- variance=variance,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
)
X, present_key_value = attention_fast_forward_inference(
decoder_layer.self_attn,
- hidden_states=X,
- past_key_value=past_key_values[idx],
- position_ids=position_ids,
- attention_mask=attention_mask,
- do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
+ hidden_states = X,
+ past_key_value = past_key_values[idx],
+ position_ids = position_ids,
+ attention_mask = attention_mask,
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
X += residual
@@ -1245,15 +1245,15 @@ def _LlamaModel_fast_forward_inference(
X = fast_rms_layernorm_inference(
decoder_layer.post_attention_layernorm,
X,
- XX=XX,
- XX2=XX2,
- variance=variance,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
)
X = mlp_fast_forward_inference(
decoder_layer.mlp,
X,
- temp_gate=temp_gates[device_index],
- temp_up=temp_ups[device_index],
+ temp_gate = temp_gates[device_index],
+ temp_up = temp_ups[device_index],
)
X += residual
@@ -1261,16 +1261,16 @@ def _LlamaModel_fast_forward_inference(
X = fast_rms_layernorm_inference(
self.model.norm,
X,
- XX=XX,
- XX2=XX2,
- variance=variance,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
)
return BaseModelOutputWithPast(
- last_hidden_state=X,
- past_key_values=next_decoder_cache,
- hidden_states=[],
- attentions=[],
+ last_hidden_state = X,
+ past_key_values = next_decoder_cache,
+ hidden_states = [],
+ attentions = [],
)
return LlamaModel_fast_forward_inference_custom
@@ -1304,8 +1304,8 @@ def CausalLM_fast_forward(fast_forward_inference):
self,
input_ids,
past_key_values,
- position_ids=position_ids,
- attention_mask=attention_mask,
+ position_ids = position_ids,
+ attention_mask = attention_mask,
)
else:
causal_mask = (
@@ -1328,16 +1328,16 @@ def CausalLM_fast_forward(fast_forward_inference):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
- input_ids=input_ids,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ input_ids = input_ids,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_values = past_key_values,
+ inputs_embeds = inputs_embeds,
+ use_cache = use_cache,
+ output_attentions = output_attentions,
+ output_hidden_states = output_hidden_states,
+ return_dict = return_dict,
)
hidden_states = outputs[0]
@@ -1360,11 +1360,11 @@ def CausalLM_fast_forward(fast_forward_inference):
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(
- loss=None,
- logits=hidden_states,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss = None,
+ logits = hidden_states,
+ past_key_values = outputs.past_key_values,
+ hidden_states = outputs.hidden_states,
+ attentions = outputs.attentions,
)
if bsz == 1 and q_len == 1:
@@ -1397,28 +1397,28 @@ def CausalLM_fast_forward(fast_forward_inference):
# logit_softcapping = logit_softcapping,
# )
loss = unsloth_fused_ce_loss(
- trainer=None,
- hidden_states=hidden_states,
- lm_head_weight=lm_head,
- lm_head_bias=None,
- labels=labels,
- mask=None,
- n_items=n_items,
- scaling=getattr(self, "accelerator_scaler", None),
- target_gb=None,
- torch_compile=True,
- logit_softcapping=logit_softcapping,
+ trainer = None,
+ hidden_states = hidden_states,
+ lm_head_weight = lm_head,
+ lm_head_bias = None,
+ labels = labels,
+ mask = None,
+ n_items = n_items,
+ scaling = getattr(self, "accelerator_scaler", None),
+ target_gb = None,
+ torch_compile = True,
+ logit_softcapping = logit_softcapping,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
output = CausalLMOutputWithPast(
- loss=loss,
- logits=EMPTY_LOGITS,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss = loss,
+ logits = EMPTY_LOGITS,
+ past_key_values = outputs.past_key_values,
+ hidden_states = outputs.hidden_states,
+ attentions = outputs.attentions,
)
return output
pass
@@ -1451,11 +1451,11 @@ def CausalLM_fast_forward(fast_forward_inference):
if n_items is None:
n_items = kwargs.get("n_items", None)
loss = fast_cross_entropy_loss(
- logits=shift_logits,
- labels=shift_labels,
- logit_softcapping=logit_softcapping,
- logit_scaling=logit_scaling,
- n_items=n_items,
+ logits = shift_logits,
+ labels = shift_labels,
+ logit_softcapping = logit_softcapping,
+ logit_scaling = logit_scaling,
+ n_items = n_items,
)
else:
if logit_scaling != 0:
@@ -1470,18 +1470,18 @@ def CausalLM_fast_forward(fast_forward_inference):
logits = logit_softcapping * logits
else:
logits *= 1.0 / logit_softcapping
- torch.tanh(logits, out=logits)
+ torch.tanh(logits, out = logits)
logits *= logit_softcapping
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss = loss,
+ logits = logits,
+ past_key_values = outputs.past_key_values,
+ hidden_states = outputs.hidden_states,
+ attentions = outputs.attentions,
)
return _CausalLM_fast_forward
@@ -1490,43 +1490,43 @@ def CausalLM_fast_forward(fast_forward_inference):
@torch._disable_dynamo
def PeftModel_fast_forward(
self,
- input_ids=None,
- causal_mask=None,
- attention_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- task_ids=None,
- num_logits_to_keep=0,
- logits_to_keep=0,
+ input_ids = None,
+ causal_mask = None,
+ attention_mask = None,
+ inputs_embeds = None,
+ labels = None,
+ output_attentions = None,
+ output_hidden_states = None,
+ return_dict = None,
+ task_ids = None,
+ num_logits_to_keep = 0,
+ logits_to_keep = 0,
**kwargs,
):
is_classification = "Classification" in str(type(self.base_model.model))
if is_classification:
return self.base_model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ input_ids = input_ids,
+ attention_mask = attention_mask,
+ inputs_embeds = inputs_embeds,
+ labels = labels,
+ output_attentions = output_attentions,
+ output_hidden_states = output_hidden_states,
+ return_dict = return_dict,
**kwargs,
)
else:
return self.base_model(
- input_ids=input_ids,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- labels=labels,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- num_logits_to_keep=num_logits_to_keep,
- logits_to_keep=logits_to_keep,
+ input_ids = input_ids,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ inputs_embeds = inputs_embeds,
+ labels = labels,
+ output_attentions = output_attentions,
+ output_hidden_states = output_hidden_states,
+ return_dict = return_dict,
+ num_logits_to_keep = num_logits_to_keep,
+ logits_to_keep = logits_to_keep,
**kwargs,
)
@@ -1542,11 +1542,11 @@ class LlamaRotaryEmbedding(torch.nn.Module):
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(
self,
- dim=None,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- config=None, # [TODO] Hack to pass in config - need to remove later
+ dim = None,
+ max_position_embeddings = 2048,
+ base = 10000,
+ device = None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
@@ -1578,17 +1578,17 @@ class LlamaRotaryEmbedding(torch.nn.Module):
# Build here to make `torch.jit.trace` work.
for device_idx in range(DEVICE_COUNT):
self._set_cos_sin_cache(
- seq_len=self.current_rope_size,
- device=torch.device(device_idx),
- dtype=torch.get_default_dtype(),
+ seq_len = self.current_rope_size,
+ device = torch.device(device_idx),
+ dtype = torch.get_default_dtype(),
)
# dummy so that patch_utils doesn't fail for now
self.cos_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
self.sin_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1598,27 +1598,27 @@ class LlamaRotaryEmbedding(torch.nn.Module):
inv_freq = 1.0 / (
self.base
** (
- torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float()
+ torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
/ self.dim
)
)
t = torch.arange(
- self.current_rope_size, device="cpu", dtype=torch.int64
+ self.current_rope_size, device = "cpu", dtype = torch.int64
).float()
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
- sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+ emb = torch.cat((freqs, freqs), dim = -1)
+ cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+ sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
- def forward(self, x, position_ids=None, seq_len=None):
+ def forward(self, x, position_ids = None, seq_len = None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+ self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
device_index = x.device.index
return (
@@ -1626,7 +1626,7 @@ class LlamaRotaryEmbedding(torch.nn.Module):
self.multi_gpu_sin_cached[device_index][:seq_len],
)
- def get_cached(self, seq_len=None, device_index=None):
+ def get_cached(self, seq_len = None, device_index = None):
if device_index is None:
device_index = get_current_device()
return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
@@ -1640,7 +1640,7 @@ class LlamaRotaryEmbedding(torch.nn.Module):
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(DEVICE_COUNT):
self._set_cos_sin_cache(
- self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype
+ self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
)
@@ -1652,20 +1652,20 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(
self,
- dim=None,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- config=None, # [TODO] Hack to pass in config - need to remove later
+ dim = None,
+ max_position_embeddings = 2048,
+ base = 10000,
+ device = None,
+ scaling_factor = 1.0,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(
- dim=dim,
- max_position_embeddings=max_position_embeddings,
- base=base,
- device=device,
- config=config,
+ dim = dim,
+ max_position_embeddings = max_position_embeddings,
+ base = base,
+ device = device,
+ config = config,
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1673,20 +1673,20 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
inv_freq = 1.0 / (
self.base
** (
- torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float()
+ torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
/ self.dim
)
)
t = torch.arange(
- self.current_rope_size, device="cpu", dtype=torch.int64
+ self.current_rope_size, device = "cpu", dtype = torch.int64
).float()
t = t / self.scaling_factor
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
- sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+ emb = torch.cat((freqs, freqs), dim = -1)
+ cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+ sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
@@ -1697,11 +1697,11 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaExtendedRotaryEmbedding(torch.nn.Module):
def __init__(
self,
- dim=None,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- config=None, # [TODO] Hack to pass in config - need to remove later
+ dim = None,
+ max_position_embeddings = 2048,
+ base = 10000,
+ device = None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
@@ -1728,27 +1728,27 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module):
inv_freq = 1.0 / (
self.base
** (
- torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float()
+ torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
/ self.dim
)
)
inv_freq = self.apply_scaling(inv_freq)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
for device_idx in range(DEVICE_COUNT):
self._set_cos_sin_cache(
- seq_len=self.current_rope_size,
- device=torch.device(device_idx),
- dtype=torch.get_default_dtype(),
+ seq_len = self.current_rope_size,
+ device = torch.device(device_idx),
+ dtype = torch.get_default_dtype(),
)
# dummy so that patch_utils doesn't fail for now
self.cos_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
self.sin_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1757,29 +1757,29 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module):
self.current_rope_size = seq_len
t = torch.arange(
- self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64
+ self.current_rope_size, device = self.inv_freq.device, dtype = torch.int64
).float()
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
- sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+ emb = torch.cat((freqs, freqs), dim = -1)
+ cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+ sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
- def forward(self, x, position_ids=None, seq_len=None):
+ def forward(self, x, position_ids = None, seq_len = None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+ self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
device_index = x.device.index
return (
self.multi_gpu_cos_cached[device_index][:seq_len],
self.multi_gpu_sin_cached[device_index][:seq_len],
)
- def get_cached(self, seq_len=None, device_index=None):
+ def get_cached(self, seq_len = None, device_index = None):
if device_index is None:
device_index = get_current_device()
return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
@@ -1793,7 +1793,7 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module):
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(DEVICE_COUNT):
self._set_cos_sin_cache(
- self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype
+ self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
)
# From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
@@ -1819,21 +1819,21 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module):
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
- return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+ return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)
class LongRopeRotaryEmbedding(torch.nn.Module):
# For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py
def __init__(
self,
- dim=None,
- max_position_embeddings=131072,
- original_max_position_embeddings=4096,
- base=10000,
- short_factor=None,
- long_factor=None,
- device=None,
- config=None, # [TODO] Hack to pass in config - need to remove later
+ dim = None,
+ max_position_embeddings = 131072,
+ original_max_position_embeddings = 4096,
+ base = 10000,
+ short_factor = None,
+ long_factor = None,
+ device = None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
assert short_factor is not None
@@ -1868,11 +1868,11 @@ class LongRopeRotaryEmbedding(torch.nn.Module):
# Long RoPE similar to RoPE except short sequences have 1 cos / sin
# and long sequences have another cos / sin
inv_freq_shape = (
- torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float()
+ torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
/ self.dim
)
- short_factor = torch.tensor(short_factor, device="cpu", dtype=torch.float32)
- long_factor = torch.tensor(long_factor, device="cpu", dtype=torch.float32)
+ short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
+ long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)
long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)
@@ -1887,43 +1887,43 @@ class LongRopeRotaryEmbedding(torch.nn.Module):
self.scaling_factor = scaling_factor
# Short and long inv_freq
- self.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
- self.register_buffer("long_inv_freq", long_inv_freq, persistent=False)
+ self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
+ self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
# Initialize short sequences cache for all devices
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
t = torch.arange(
original_max_position_embeddings,
- device=self.short_inv_freq.device,
- dtype=torch.int64,
+ device = self.short_inv_freq.device,
+ dtype = torch.int64,
).float()
freqs = torch.outer(t, self.short_inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
+ emb = torch.cat((freqs, freqs), dim = -1)
for device_idx in range(DEVICE_COUNT):
device_obj = torch.device(device_idx)
cos_cached = (emb.cos() * self.scaling_factor).to(
- dtype=dtype, device=device_obj, non_blocking=True
+ dtype = dtype, device = device_obj, non_blocking = True
)
sin_cached = (emb.sin() * self.scaling_factor).to(
- dtype=dtype, device=device_obj, non_blocking=True
+ dtype = dtype, device = device_obj, non_blocking = True
)
self.multi_gpu_short_cos_cached[device_idx] = cos_cached
self.multi_gpu_short_sin_cached[device_idx] = sin_cached
# dummy so that patch_utils doesn't fail for now
self.short_cos_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
self.short_sin_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
self.long_cos_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
self.long_sin_cached = torch.empty(
- 1, device=get_current_device(), dtype=torch.get_default_dtype()
+ 1, device = get_current_device(), dtype = torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1932,25 +1932,25 @@ class LongRopeRotaryEmbedding(torch.nn.Module):
self.current_rope_size = seq_len
t = torch.arange(
- self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64
+ self.current_rope_size, device = self.long_inv_freq.device, dtype = torch.int64
).float()
# Long sequences
freqs = torch.outer(t, self.long_inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
+ emb = torch.cat((freqs, freqs), dim = -1)
cos_cached = (emb.cos() * self.scaling_factor).to(
- dtype=dtype, device=device, non_blocking=True
+ dtype = dtype, device = device, non_blocking = True
)
sin_cached = (emb.sin() * self.scaling_factor).to(
- dtype=dtype, device=device, non_blocking=True
+ dtype = dtype, device = device, non_blocking = True
)
self.multi_gpu_long_cos_cached[device.index] = cos_cached
self.multi_gpu_long_sin_cached[device.index] = sin_cached
return cos_cached, sin_cached
- def forward(self, x, position_ids=None, seq_len=None):
+ def forward(self, x, position_ids = None, seq_len = None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+ self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
device_index = x.device.index
@@ -1965,7 +1965,7 @@ class LongRopeRotaryEmbedding(torch.nn.Module):
self.multi_gpu_long_sin_cached[device_index][:seq_len],
)
- def get_cached(self, seq_len=None, device_index=None):
+ def get_cached(self, seq_len = None, device_index = None):
if device_index is None:
device_index = get_current_device()
if seq_len is not None and seq_len < self.original_max_position_embeddings:
@@ -1983,7 +1983,7 @@ class LongRopeRotaryEmbedding(torch.nn.Module):
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(DEVICE_COUNT):
self._set_cos_sin_cache(
- self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype
+ self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
)
@@ -2041,7 +2041,7 @@ def unsloth_fast_generate(
# Mixed precision autocast
with (
_get_inference_mode_context_manager(self),
- torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=dtype),
+ torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),
):
output = self._old_generate(*args, **kwargs)
@@ -2065,12 +2065,12 @@ class FastLlamaModel:
@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
- model_name="llama",
- rope_module=LlamaRotaryEmbedding,
- scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
- extended_rope_module=LlamaExtendedRotaryEmbedding,
- attention_module=LlamaAttention,
- longrope_module=LongRopeRotaryEmbedding,
+ model_name = "llama",
+ rope_module = LlamaRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ extended_rope_module = LlamaExtendedRotaryEmbedding,
+ attention_module = LlamaAttention,
+ longrope_module = LongRopeRotaryEmbedding,
)
if init_name is not None:
exec(function, globals())
@@ -2103,27 +2103,27 @@ class FastLlamaModel:
@staticmethod
def from_pretrained(
- model_name="unsloth/llama-3-8b-bnb-4bit",
- max_seq_length=None,
- dtype=None,
- load_in_4bit=True,
- token=None,
- device_map="sequential",
- rope_scaling=None,
- fix_tokenizer=True,
- model_patcher=None,
- tokenizer_name=None,
- trust_remote_code=False,
- revision=None,
- fast_inference=False, # uses vLLM
- gpu_memory_utilization=0.5,
- float8_kv_cache=False,
- random_state=3407,
- max_lora_rank=16,
- disable_log_stats=False,
- unsloth_vllm_standby=False,
- num_labels=None,
- qat_scheme=None,
+ model_name = "unsloth/llama-3-8b-bnb-4bit",
+ max_seq_length = None,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
+ trust_remote_code = False,
+ revision = None,
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 16,
+ disable_log_stats = False,
+ unsloth_vllm_standby = False,
+ num_labels = None,
+ qat_scheme = None,
**kwargs,
):
os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
@@ -2249,8 +2249,8 @@ class FastLlamaModel:
# RoPE Scaling
model_config = AutoConfig.from_pretrained(
model_name,
- token=token,
- attn_implementation="sdpa",
+ token = token,
+ attn_implementation = "sdpa",
)
model_config.model_name = model_name
model_max_seq_length = model_config.max_position_embeddings
@@ -2263,7 +2263,7 @@ class FastLlamaModel:
has_rope_scaling = False
try:
- with open(inspect.getfile(model_function), "r", encoding="utf-8") as file:
+ with open(inspect.getfile(model_function), "r", encoding = "utf-8") as file:
has_rope_scaling = "self.config.rope_scaling" in file.read()
except:
pass
@@ -2310,11 +2310,11 @@ class FastLlamaModel:
# we cannot quantize out_proj layer due to mamba kernels: https://github.com/tiiuae/Falcon-H1/issues/13#issuecomment-2918671274
llm_int8_skip_modules.append("out_proj")
bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=dtype,
- llm_int8_skip_modules=llm_int8_skip_modules,
+ load_in_4bit = True,
+ bnb_4bit_use_double_quant = True,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_compute_dtype = dtype,
+ llm_int8_skip_modules = llm_int8_skip_modules,
)
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
@@ -2332,26 +2332,26 @@ class FastLlamaModel:
if num_labels is not None:
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
- device_map=device_map,
+ device_map = device_map,
# torch_dtype = dtype, # transformers changed torch_dtype to dtype
- num_labels=num_labels,
+ num_labels = num_labels,
# quantization_config = bnb_config,
- token=token,
- max_position_embeddings=max_position_embeddings,
- trust_remote_code=trust_remote_code,
- attn_implementation="eager",
+ token = token,
+ max_position_embeddings = max_position_embeddings,
+ trust_remote_code = trust_remote_code,
+ attn_implementation = "eager",
**kwargs,
)
elif not fast_inference:
model = AutoModelForCausalLM.from_pretrained(
model_name,
- device_map=device_map,
+ device_map = device_map,
# torch_dtype = dtype, # transformers changed torch_dtype to dtype
# quantization_config = bnb_config,
- token=token,
- max_position_embeddings=max_position_embeddings,
- trust_remote_code=trust_remote_code,
- attn_implementation="eager",
+ token = token,
+ max_position_embeddings = max_position_embeddings,
+ trust_remote_code = trust_remote_code,
+ attn_implementation = "eager",
**kwargs,
)
model.fast_generate = model.generate
@@ -2366,17 +2366,17 @@ class FastLlamaModel:
allowed_args = inspect.getfullargspec(load_vllm).args
load_vllm_kwargs = dict(
- model_name=model_name,
- config=model_config,
- gpu_memory_utilization=gpu_memory_utilization,
- max_seq_length=max_seq_length,
- dtype=dtype,
- float8_kv_cache=float8_kv_cache,
- enable_lora=True,
- max_lora_rank=max_lora_rank,
- disable_log_stats=disable_log_stats,
- use_bitsandbytes=load_in_4bit,
- unsloth_vllm_standby=unsloth_vllm_standby,
+ model_name = model_name,
+ config = model_config,
+ gpu_memory_utilization = gpu_memory_utilization,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ float8_kv_cache = float8_kv_cache,
+ enable_lora = True,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
+ use_bitsandbytes = load_in_4bit,
+ unsloth_vllm_standby = unsloth_vllm_standby,
)
for allowed_arg in allowed_args:
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
@@ -2387,7 +2387,7 @@ class FastLlamaModel:
llm = load_vllm(**load_vllm_kwargs)
# Convert to HF format
- _, quant_state_dict = get_vllm_state_dict(llm, config=model_config)
+ _, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
model = convert_vllm_to_huggingface(
quant_state_dict, model_config, dtype, bnb_config
)
@@ -2403,12 +2403,12 @@ class FastLlamaModel:
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = load_correct_tokenizer(
- tokenizer_name=tokenizer_name,
- model_max_length=max_position_embeddings,
- padding_side="right",
- token=token,
- trust_remote_code=trust_remote_code,
- fix_tokenizer=fix_tokenizer,
+ tokenizer_name = tokenizer_name,
+ model_max_length = max_position_embeddings,
+ padding_side = "right",
+ token = token,
+ trust_remote_code = trust_remote_code,
+ fix_tokenizer = fix_tokenizer,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
@@ -2489,7 +2489,7 @@ class FastLlamaModel:
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
inner_training_loop = re.sub(
- r"^" + front_spaces, "", inner_training_loop, flags=re.MULTILINE
+ r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE
)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
@@ -2521,12 +2521,12 @@ class FastLlamaModel:
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
- model=model,
- tokenizer=tokenizer,
- model_name=model_name,
- model_max_length=max_position_embeddings,
- padding_side="right",
- token=token,
+ model = model,
+ tokenizer = tokenizer,
+ model_name = model_name,
+ model_max_length = max_position_embeddings,
+ padding_side = "right",
+ token = token,
)
patch_saving_functions(tokenizer)
@@ -2597,15 +2597,15 @@ class FastLlamaModel:
@staticmethod
def post_patch(model, tokenizer):
model, tokenizer = patch_model_and_tokenizer(
- model, tokenizer, downcast_rope=True
+ model, tokenizer, downcast_rope = True
)
return model, tokenizer
@staticmethod
def get_peft_model(
model,
- r=16,
- target_modules=[
+ r = 16,
+ target_modules = [
"q_proj",
"k_proj",
"v_proj",
@@ -2614,20 +2614,20 @@ class FastLlamaModel:
"up_proj",
"down_proj",
],
- lora_alpha=16,
- lora_dropout=0.0,
- bias="none",
- layers_to_transform=None,
- layers_pattern=None,
- use_gradient_checkpointing="unsloth",
- random_state=3407,
- max_seq_length=2048, # not used anymore
- use_rslora=False,
- modules_to_save=None,
- init_lora_weights=True,
- loftq_config={},
- temporary_location="_unsloth_temporary_saved_buffers",
- qat_scheme=None,
+ lora_alpha = 16,
+ lora_dropout = 0.0,
+ bias = "none",
+ layers_to_transform = None,
+ layers_pattern = None,
+ use_gradient_checkpointing = "unsloth",
+ random_state = 3407,
+ max_seq_length = 2048, # not used anymore
+ use_rslora = False,
+ modules_to_save = None,
+ init_lora_weights = True,
+ loftq_config = {},
+ temporary_location = "_unsloth_temporary_saved_buffers",
+ qat_scheme = None,
**kwargs,
):
if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
@@ -2641,22 +2641,22 @@ class FastLlamaModel:
if peft_arg not in kwargs:
kwargs[peft_arg] = flag
return FastBaseModel.get_peft_model(
- model=model,
- r=r,
- target_modules=target_modules,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- bias=bias,
- layers_to_transform=layers_to_transform,
- layers_pattern=layers_pattern,
- use_gradient_checkpointing=use_gradient_checkpointing,
- random_state=random_state,
- max_seq_length=max_seq_length,
- use_rslora=use_rslora,
- modules_to_save=modules_to_save,
- init_lora_weights=init_lora_weights,
- loftq_config=loftq_config,
- temporary_location=temporary_location,
+ model = model,
+ r = r,
+ target_modules = target_modules,
+ lora_alpha = lora_alpha,
+ lora_dropout = lora_dropout,
+ bias = bias,
+ layers_to_transform = layers_to_transform,
+ layers_pattern = layers_pattern,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ random_state = random_state,
+ max_seq_length = max_seq_length,
+ use_rslora = use_rslora,
+ modules_to_save = modules_to_save,
+ init_lora_weights = init_lora_weights,
+ loftq_config = loftq_config,
+ temporary_location = temporary_location,
**kwargs,
)
if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
@@ -2668,7 +2668,7 @@ class FastLlamaModel:
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(
- dtype=model.get_input_embeddings().weight.dtype
+ dtype = model.get_input_embeddings().weight.dtype
)
if type(r) is not int:
@@ -2744,7 +2744,7 @@ class FastLlamaModel:
new_dtype = torch.float32
model.get_input_embeddings().modules_to_save.default.to(
- device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True
+ device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
)
model.get_input_embeddings().modules_to_save.default.requires_grad_(
True
@@ -2752,7 +2752,7 @@ class FastLlamaModel:
# [TODO] Move old embed_tokens to CPU - should be disk!
model.get_input_embeddings().original_module.to(
- device="cpu", non_blocking=True
+ device = "cpu", non_blocking = True
)
model.get_input_embeddings().original_module.requires_grad_(False)
@@ -2766,7 +2766,7 @@ class FastLlamaModel:
new_dtype = torch.float32
model.get_output_embeddings().modules_to_save.default.to(
- device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True
+ device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
)
model.get_output_embeddings().modules_to_save.default.requires_grad_(
True
@@ -2774,7 +2774,7 @@ class FastLlamaModel:
# [TODO] Move old lm_head to CPU - should be disk!
model.get_output_embeddings().original_module.to(
- device="cpu", non_blocking=True
+ device = "cpu", non_blocking = True
)
model.get_output_embeddings().original_module.requires_grad_(False)
@@ -2829,7 +2829,7 @@ class FastLlamaModel:
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
- loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
+ loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
if hasattr(model.config, "quantization_config"):
raise ValueError(
@@ -2969,17 +2969,17 @@ class FastLlamaModel:
is_classification = "Classification" in str(type(model))
arguments = dict(
- r=r,
- lora_alpha=lora_alpha,
- target_modules=final_modules,
- lora_dropout=lora_dropout,
- bias=bias,
- task_type=TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
- layers_to_transform=layers_to_transform,
- init_lora_weights=init_lora_weights,
- loftq_config=loftq_config,
- use_rslora=use_rslora,
- modules_to_save=modules_to_save,
+ r = r,
+ lora_alpha = lora_alpha,
+ target_modules = final_modules,
+ lora_dropout = lora_dropout,
+ bias = bias,
+ task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
+ layers_to_transform = layers_to_transform,
+ init_lora_weights = init_lora_weights,
+ loftq_config = loftq_config,
+ use_rslora = use_rslora,
+ modules_to_save = modules_to_save,
**kwargs,
)
if not SUPPORTS_LOFTQ:
@@ -3042,7 +3042,7 @@ class FastLlamaModel:
new_dtype = torch.float32
model.get_input_embeddings().modules_to_save.default.to(
- device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True
+ device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
)
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
@@ -3059,7 +3059,7 @@ class FastLlamaModel:
new_dtype = torch.float32
model.get_output_embeddings().modules_to_save.default.to(
- device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True
+ device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
)
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
@@ -3096,12 +3096,12 @@ class FastLlamaModel:
@staticmethod
def patch_peft_model(
model,
- use_gradient_checkpointing="unsloth",
+ use_gradient_checkpointing = "unsloth",
):
if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
return FastBaseModel.patch_peft_model(
- model=model,
- use_gradient_checkpointing=use_gradient_checkpointing,
+ model = model,
+ use_gradient_checkpointing = use_gradient_checkpointing,
)
if not isinstance(model, PeftModelForCausalLM) and not isinstance(
model, PeftModelForSequenceClassification
@@ -3138,8 +3138,8 @@ class FastLlamaModel:
model = prepare_model_for_kbit_training(
model,
- use_gradient_checkpointing=use_gradient_checkpointing,
- use_reentrant=True,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ use_reentrant = True,
)
# Fix up config for transformers uploading PEFT
@@ -3192,7 +3192,7 @@ class FastLlamaModel:
# We also do not inplace edit QKV for Cohere!
_apply_lora_mlp = (
- functools.partial(apply_lora_mlp, inplace=False)
+ functools.partial(apply_lora_mlp, inplace = False)
if model_type == "cohere"
else apply_lora_mlp
)
@@ -3375,7 +3375,7 @@ class FastLlamaModel:
return model
@staticmethod
- def for_training(model, use_gradient_checkpointing=True):
+ def for_training(model, use_gradient_checkpointing = True):
if not hasattr(model, "parameters"):
raise TypeError(
"Unsloth: I think you're passing a tokenizer, not the model to for_training!"
@@ -3424,4 +3424,4 @@ class FastLlamaModel:
from .rl import PatchFastRL
-PatchFastRL(FastLanguageModel=FastLlamaModel)
+PatchFastRL(FastLanguageModel = FastLlamaModel)
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 5389d84d2..e1c13315f 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -120,33 +120,33 @@ DISABLE_SDPA_MODEL_NAMES = [
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
- model_name="unsloth/Llama-3.2-1B-Instruct",
- max_seq_length=2048,
- dtype=None,
- load_in_4bit=True, # 4bit QLoRA
- load_in_8bit=False, # 8bit LoRA
- load_in_16bit=False, # 16bit LoRA
- full_finetuning=False,
- token=None,
- device_map="sequential",
- rope_scaling=None,
- fix_tokenizer=True,
- trust_remote_code=False,
- use_gradient_checkpointing="unsloth",
- resize_model_vocab=None,
- revision=None,
- use_exact_model_name=False,
- offload_embedding=False,
- float32_mixed_precision=None, # Forces float32 mixed precision
- fast_inference=False, # uses vLLM
- gpu_memory_utilization=0.5,
- float8_kv_cache=False,
- random_state=3407,
- max_lora_rank=64,
- disable_log_stats=True,
- qat_scheme=None,
- load_in_fp8=False, # fp8 LoRA (True, False, 'block')
- unsloth_tiled_mlp=False,
+ model_name = "unsloth/Llama-3.2-1B-Instruct",
+ max_seq_length = 2048,
+ dtype = None,
+ load_in_4bit = True, # 4bit QLoRA
+ load_in_8bit = False, # 8bit LoRA
+ load_in_16bit = False, # 16bit LoRA
+ full_finetuning = False,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ trust_remote_code = False,
+ use_gradient_checkpointing = "unsloth",
+ resize_model_vocab = None,
+ revision = None,
+ use_exact_model_name = False,
+ offload_embedding = False,
+ float32_mixed_precision = None, # Forces float32 mixed precision
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 64,
+ disable_log_stats = True,
+ qat_scheme = None,
+ load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
+ unsloth_tiled_mlp = False,
*args,
**kwargs,
):
@@ -157,40 +157,40 @@ class FastLanguageModel(FastLlamaModel):
try:
from huggingface_hub import login
- login(token=token)
+ login(token = token)
except:
pass
if load_in_8bit or full_finetuning or qat_scheme is not None:
return FastModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=dtype,
- load_in_4bit=load_in_4bit,
- load_in_8bit=load_in_8bit,
- load_in_16bit=load_in_16bit,
- full_finetuning=full_finetuning,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling, # [TODO] No effect
- fix_tokenizer=fix_tokenizer, # [TODO] No effect
- trust_remote_code=trust_remote_code,
- use_gradient_checkpointing=use_gradient_checkpointing,
- resize_model_vocab=resize_model_vocab, # [TODO] No effect
- revision=revision,
- return_logits=False, # Return logits
- fullgraph=True, # No graph breaks
- use_exact_model_name=use_exact_model_name,
- offload_embedding=offload_embedding,
- float32_mixed_precision=float32_mixed_precision,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ load_in_16bit = load_in_16bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling, # [TODO] No effect
+ fix_tokenizer = fix_tokenizer, # [TODO] No effect
+ trust_remote_code = trust_remote_code,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ resize_model_vocab = resize_model_vocab, # [TODO] No effect
+ revision = revision,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = use_exact_model_name,
+ offload_embedding = offload_embedding,
+ float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
- fast_inference=fast_inference,
- gpu_memory_utilization=gpu_memory_utilization,
- float8_kv_cache=float8_kv_cache,
- random_state=random_state,
- max_lora_rank=max_lora_rank,
- disable_log_stats=disable_log_stats,
- qat_scheme=qat_scheme,
- load_in_fp8=load_in_fp8,
+ fast_inference = fast_inference,
+ gpu_memory_utilization = gpu_memory_utilization,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
+ qat_scheme = qat_scheme,
+ load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
@@ -231,7 +231,7 @@ class FastLanguageModel(FastLlamaModel):
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
- model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
+ model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
@@ -284,9 +284,9 @@ class FastLanguageModel(FastLlamaModel):
try:
model_config = AutoConfig.from_pretrained(
model_name,
- token=token,
- revision=revision,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
@@ -300,9 +300,9 @@ class FastLanguageModel(FastLlamaModel):
try:
peft_config = PeftConfig.from_pretrained(
model_name,
- token=token,
- revision=revision,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
@@ -345,7 +345,7 @@ class FastLanguageModel(FastLlamaModel):
both_exist = exist_adapter_config and exist_config
else:
# Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
- files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
+ files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
files = list(os.path.split(x)[-1] for x in files)
if (
sum(x == "adapter_config.json" or x == "config.json" for x in files)
@@ -393,8 +393,8 @@ class FastLanguageModel(FastLlamaModel):
model_config = AutoConfig.from_pretrained(
model_name,
- token=token,
- trust_remote_code=trust_remote_code,
+ token = token,
+ trust_remote_code = trust_remote_code,
)
if not was_disabled:
@@ -484,42 +484,42 @@ class FastLanguageModel(FastLlamaModel):
# dispatch_model = FastGraniteModel
else:
return FastModel.from_pretrained(
- model_name=old_model_name,
- max_seq_length=max_seq_length,
- dtype=dtype,
- load_in_4bit=load_in_4bit,
- load_in_8bit=load_in_8bit,
- load_in_16bit=load_in_16bit,
- full_finetuning=full_finetuning,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling, # [TODO] No effect
- fix_tokenizer=fix_tokenizer, # [TODO] No effect
- trust_remote_code=trust_remote_code,
- use_gradient_checkpointing=use_gradient_checkpointing,
- resize_model_vocab=resize_model_vocab, # [TODO] No effect
- revision=revision,
- return_logits=False, # Return logits
- fullgraph=True, # No graph breaks
- use_exact_model_name=use_exact_model_name,
- offload_embedding=offload_embedding,
- float32_mixed_precision=float32_mixed_precision,
+ model_name = old_model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ load_in_16bit = load_in_16bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling, # [TODO] No effect
+ fix_tokenizer = fix_tokenizer, # [TODO] No effect
+ trust_remote_code = trust_remote_code,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ resize_model_vocab = resize_model_vocab, # [TODO] No effect
+ revision = revision,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = use_exact_model_name,
+ offload_embedding = offload_embedding,
+ float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
- fast_inference=fast_inference,
- gpu_memory_utilization=gpu_memory_utilization,
- float8_kv_cache=float8_kv_cache,
- random_state=random_state,
- max_lora_rank=max_lora_rank,
- disable_log_stats=disable_log_stats,
- qat_scheme=qat_scheme,
- load_in_fp8=load_in_fp8,
- unsloth_tiled_mlp=unsloth_tiled_mlp,
+ fast_inference = fast_inference,
+ gpu_memory_utilization = gpu_memory_utilization,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
+ qat_scheme = qat_scheme,
+ load_in_fp8 = load_in_fp8,
+ unsloth_tiled_mlp = unsloth_tiled_mlp,
*args,
**kwargs,
)
if use_gradient_checkpointing == "unsloth":
- patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
+ patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
# Check if this is local model since the tokenizer gets overwritten
if (
@@ -535,24 +535,24 @@ class FastLanguageModel(FastLlamaModel):
fast_inference, model_name = fast_inference_setup(model_name, model_config)
model, tokenizer = dispatch_model.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=_get_dtype(dtype),
- load_in_4bit=load_in_4bit,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling,
- fix_tokenizer=fix_tokenizer,
- model_patcher=dispatch_model,
- tokenizer_name=tokenizer_name,
- trust_remote_code=trust_remote_code,
- revision=revision if not is_peft else None,
- fast_inference=fast_inference,
- gpu_memory_utilization=gpu_memory_utilization,
- float8_kv_cache=float8_kv_cache,
- random_state=random_state,
- max_lora_rank=max_lora_rank,
- disable_log_stats=disable_log_stats,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = _get_dtype(dtype),
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = dispatch_model,
+ tokenizer_name = tokenizer_name,
+ trust_remote_code = trust_remote_code,
+ revision = revision if not is_peft else None,
+ fast_inference = fast_inference,
+ gpu_memory_utilization = gpu_memory_utilization,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
*args,
**kwargs,
)
@@ -601,10 +601,10 @@ class FastLanguageModel(FastLlamaModel):
model = PeftModel.from_pretrained(
model,
old_model_name,
- token=token,
- revision=revision,
- is_trainable=True,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ is_trainable = True,
+ trust_remote_code = trust_remote_code,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
@@ -615,7 +615,7 @@ class FastLanguageModel(FastLlamaModel):
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
)
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
- patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
+ patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
return model, tokenizer
@@ -645,40 +645,40 @@ class FastModel(FastBaseModel):
@staticmethod
def from_pretrained(
- model_name="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
- max_seq_length=2048,
- dtype=None,
- load_in_4bit=True, # 4bit QLoRA
- load_in_8bit=False, # 8bit LoRA
- load_in_16bit=False, # 16bit LoRA
- full_finetuning=False,
- token=None,
- device_map="sequential",
- rope_scaling=None, # [TODO] No effect
- fix_tokenizer=True, # [TODO] No effect
- trust_remote_code=False,
- use_gradient_checkpointing="unsloth",
- resize_model_vocab=None, # [TODO] No effect
- revision=None,
- return_logits=False, # Return logits
- fullgraph=True, # No graph breaks
- use_exact_model_name=False,
- auto_model=None,
- whisper_language=None,
- whisper_task=None,
- unsloth_force_compile=False,
- offload_embedding=False,
- float32_mixed_precision=None, # Forces float32 mixed precision
+ model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
+ max_seq_length = 2048,
+ dtype = None,
+ load_in_4bit = True, # 4bit QLoRA
+ load_in_8bit = False, # 8bit LoRA
+ load_in_16bit = False, # 16bit LoRA
+ full_finetuning = False,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None, # [TODO] No effect
+ fix_tokenizer = True, # [TODO] No effect
+ trust_remote_code = False,
+ use_gradient_checkpointing = "unsloth",
+ resize_model_vocab = None, # [TODO] No effect
+ revision = None,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = False,
+ auto_model = None,
+ whisper_language = None,
+ whisper_task = None,
+ unsloth_force_compile = False,
+ offload_embedding = False,
+ float32_mixed_precision = None, # Forces float32 mixed precision
# Add the missing vLLM/inference parameters
- fast_inference=False, # uses vLLM
- gpu_memory_utilization=0.5,
- float8_kv_cache=False,
- random_state=3407,
- max_lora_rank=64,
- disable_log_stats=True,
- qat_scheme=None,
- load_in_fp8=False, # fp8 LoRA (True, False, 'block')
- unsloth_tiled_mlp=False,
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 64,
+ disable_log_stats = True,
+ qat_scheme = None,
+ load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
+ unsloth_tiled_mlp = False,
*args,
**kwargs,
):
@@ -689,7 +689,7 @@ class FastModel(FastBaseModel):
try:
from huggingface_hub import login
- login(token=token)
+ login(token = token)
except:
pass
if whisper_language is not None:
@@ -765,7 +765,7 @@ class FastModel(FastBaseModel):
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
- model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
+ model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
@@ -819,9 +819,9 @@ class FastModel(FastBaseModel):
try:
model_config = AutoConfig.from_pretrained(
model_name,
- token=token,
- revision=revision,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
@@ -835,9 +835,9 @@ class FastModel(FastBaseModel):
try:
peft_config = PeftConfig.from_pretrained(
model_name,
- token=token,
- revision=revision,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
@@ -1011,7 +1011,7 @@ class FastModel(FastBaseModel):
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
- files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
+ files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
files = list(os.path.split(x)[-1] for x in files)
if (
sum(x == "adapter_config.json" or x == "config.json" for x in files)
@@ -1059,8 +1059,8 @@ class FastModel(FastBaseModel):
model_config = AutoConfig.from_pretrained(
model_name,
- token=token,
- trust_remote_code=trust_remote_code,
+ token = token,
+ trust_remote_code = trust_remote_code,
)
if not was_disabled:
@@ -1092,40 +1092,40 @@ class FastModel(FastBaseModel):
break
# Patch gradient checkpointing
if use_gradient_checkpointing == "unsloth":
- patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
+ patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
with redirector:
- patch_loss_functions(torch_compile=False)
+ patch_loss_functions(torch_compile = False)
model_types, supports_sdpa = unsloth_compile_transformers(
- dtype=dtype,
- model_name=model_name,
- model_types=model_types,
- token=token,
- sdpa_dynamic_mask=True,
- sdpa_bool_masks=True,
- sdpa_gqa_replace=True,
- sdpa_dynamic_compile=True,
- compile_attention=True,
- disable_causal_masks=True,
- compile_torch_modules=True,
- compile_custom_modules=True,
- compile_function_calls=True,
- fuse_lm_head=True,
- gradient_checkpointing=True,
- manual_replacements=True,
- fast_lora_forwards=True,
- fast_residual_stream=False,
- accurate_accumulation=True,
- epilogue_fusion=True,
- max_autotune=False,
- shape_padding=True,
- cudagraphs=False,
- debug=False,
- fullgraph=fullgraph,
- import_from_cache=False,
- disable=False,
- return_logits=return_logits,
- trust_remote_code=trust_remote_code,
- unsloth_force_compile=unsloth_force_compile,
+ dtype = dtype,
+ model_name = model_name,
+ model_types = model_types,
+ token = token,
+ sdpa_dynamic_mask = True,
+ sdpa_bool_masks = True,
+ sdpa_gqa_replace = True,
+ sdpa_dynamic_compile = True,
+ compile_attention = True,
+ disable_causal_masks = True,
+ compile_torch_modules = True,
+ compile_custom_modules = True,
+ compile_function_calls = True,
+ fuse_lm_head = True,
+ gradient_checkpointing = True,
+ manual_replacements = True,
+ fast_lora_forwards = True,
+ fast_residual_stream = False,
+ accurate_accumulation = True,
+ epilogue_fusion = True,
+ max_autotune = False,
+ shape_padding = True,
+ cudagraphs = False,
+ debug = False,
+ fullgraph = fullgraph,
+ import_from_cache = False,
+ disable = False,
+ return_logits = return_logits,
+ trust_remote_code = trust_remote_code,
+ unsloth_force_compile = unsloth_force_compile,
)
# Fix SDPA issues
for model_type in DISABLE_SDPA_MODEL_NAMES:
@@ -1152,34 +1152,34 @@ class FastModel(FastBaseModel):
auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
model, tokenizer = FastBaseModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=_get_dtype(dtype),
- load_in_4bit=load_in_4bit,
- load_in_8bit=load_in_8bit,
- load_in_16bit=load_in_16bit,
- full_finetuning=full_finetuning,
- token=token,
- device_map=device_map,
- trust_remote_code=trust_remote_code,
- revision=revision if not is_peft else None,
- model_types=model_types,
- tokenizer_name=tokenizer_name,
- auto_model=auto_model,
- use_gradient_checkpointing=use_gradient_checkpointing,
- supports_sdpa=supports_sdpa,
- whisper_language=whisper_language,
- whisper_task=whisper_task,
- auto_config=model_config,
- offload_embedding=offload_embedding,
- float32_mixed_precision=float32_mixed_precision,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = _get_dtype(dtype),
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ load_in_16bit = load_in_16bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
+ trust_remote_code = trust_remote_code,
+ revision = revision if not is_peft else None,
+ model_types = model_types,
+ tokenizer_name = tokenizer_name,
+ auto_model = auto_model,
+ use_gradient_checkpointing = use_gradient_checkpointing,
+ supports_sdpa = supports_sdpa,
+ whisper_language = whisper_language,
+ whisper_task = whisper_task,
+ auto_config = model_config,
+ offload_embedding = offload_embedding,
+ float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
- fast_inference=fast_inference,
- gpu_memory_utilization=gpu_memory_utilization,
- float8_kv_cache=float8_kv_cache,
- random_state=random_state,
- max_lora_rank=max_lora_rank,
- disable_log_stats=disable_log_stats,
+ fast_inference = fast_inference,
+ gpu_memory_utilization = gpu_memory_utilization,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
*args,
**kwargs,
)
@@ -1228,14 +1228,14 @@ class FastModel(FastBaseModel):
model = PeftModel.from_pretrained(
model,
old_model_name,
- token=token,
- revision=revision,
- is_trainable=True,
- trust_remote_code=trust_remote_code,
+ token = token,
+ revision = revision,
+ is_trainable = True,
+ trust_remote_code = trust_remote_code,
)
# Patch it as well!
model = FastBaseModel.post_patch_model(
- model, use_gradient_checkpointing, trust_remote_code=trust_remote_code
+ model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
)
# Apply QAT if specified
@@ -1249,7 +1249,7 @@ class FastModel(FastBaseModel):
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
)
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
- patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
+ patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
return model, tokenizer
diff --git a/unsloth/models/qwen2.py b/unsloth/models/qwen2.py
index 5207e9708..3f819d6dc 100644
--- a/unsloth/models/qwen2.py
+++ b/unsloth/models/qwen2.py
@@ -39,10 +39,10 @@ class FastQwen2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name="qwen2",
- rope_module=LlamaRotaryEmbedding,
- scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
- attention_module=Qwen2Attention,
+ model_name = "qwen2",
+ rope_module = LlamaRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ attention_module = Qwen2Attention,
)
if init_name is not None:
exec(function, globals())
@@ -72,30 +72,30 @@ class FastQwen2Model(FastLlamaModel):
@staticmethod
def from_pretrained(
- model_name="Qwen/Qwen2-7B",
- max_seq_length=4096,
- dtype=None,
- load_in_4bit=True,
- token=None,
- device_map="sequential",
- rope_scaling=None, # Qwen2 does not support RoPE scaling
- fix_tokenizer=True,
- model_patcher=None,
- tokenizer_name=None,
- trust_remote_code=False,
+ model_name = "Qwen/Qwen2-7B",
+ max_seq_length = 4096,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None, # Qwen2 does not support RoPE scaling
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
+ trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=dtype,
- load_in_4bit=load_in_4bit,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling,
- fix_tokenizer=fix_tokenizer,
- model_patcher=FastQwen2Model,
- tokenizer_name=tokenizer_name,
- trust_remote_code=trust_remote_code,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = FastQwen2Model,
+ tokenizer_name = tokenizer_name,
+ trust_remote_code = trust_remote_code,
**kwargs,
)
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
index 5c935203d..6f41579d8 100644
--- a/unsloth/models/qwen3.py
+++ b/unsloth/models/qwen3.py
@@ -115,7 +115,7 @@ def Qwen3Attention_fast_forward(
else:
# Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
- rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
+ rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
device_index = Q.device.index
if position_ids is None:
@@ -126,8 +126,8 @@ def Qwen3Attention_fast_forward(
Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim=2)
- V = torch.cat([past_key_value[1], V], dim=2)
+ K = torch.cat([past_key_value[0], K], dim = 2)
+ V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@@ -163,7 +163,7 @@ def Qwen3Attention_fast_forward(
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
- A = xformers_attention(Q, K, V, attn_bias=causal_mask)
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
@@ -172,7 +172,7 @@ def Qwen3Attention_fast_forward(
V = V.transpose(1, 2)
sw = kv_seq_len
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
- A = flash_attn_func(Q, K, V, causal=True, window_size=window)
+ A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
@@ -195,7 +195,7 @@ def Qwen3Attention_fast_forward(
is_causal = False
A = scaled_dot_product_attention(
- Q, K, V, attn_mask=attention_mask, is_causal=is_causal
+ Q, K, V, attn_mask = attention_mask, is_causal = is_causal
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@@ -214,8 +214,8 @@ def Qwen3Attention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
- do_prefill=False,
- attention_mask=None,
+ do_prefill = False,
+ attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
@@ -267,29 +267,29 @@ def Qwen3Attention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype=dtype,
- device=device,
+ dtype = dtype,
+ device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype=dtype, device=device
+ (2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
+ (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
- self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
+ (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -309,9 +309,9 @@ def Qwen3Attention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
- Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
- Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(
bsz, 1, n_heads, head_dim
) # .transpose(1, 2) # we will transpose after normalisation
@@ -399,30 +399,30 @@ def Qwen3Attention_fast_forward_inference(
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(
- Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
+ Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
)
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(
- A, dim=-1, dtype=torch.float32
+ A, dim = -1, dtype = torch.float32
) # .to(A.dtype)
- A = torch_matmul(A, Vnn, out=Qn)
+ A = torch_matmul(A, Vnn, out = Qn)
else:
if SDPA_HAS_GQA:
A = scaled_dot_product_attention(
Qn,
Knn,
Vnn,
- attn_mask=attention_mask,
- is_causal=is_causal,
- enable_gqa=True,
+ attn_mask = attention_mask,
+ is_causal = is_causal,
+ enable_gqa = True,
)
else:
A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal
+ Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@@ -430,10 +430,10 @@ class FastQwen3Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name="Qwen3",
- rope_module=LlamaRotaryEmbedding,
- scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
- attention_module=Qwen3Attention,
+ model_name = "Qwen3",
+ rope_module = LlamaRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ attention_module = Qwen3Attention,
)
if init_name is not None:
exec(function, globals())
@@ -463,30 +463,30 @@ class FastQwen3Model(FastLlamaModel):
@staticmethod
def from_pretrained( # TODO: Change after release
- model_name="Qwen/Qwen3-7B",
- max_seq_length=4096,
- dtype=None,
- load_in_4bit=True,
- token=None,
- device_map="sequential",
- rope_scaling=None,
- fix_tokenizer=True,
- model_patcher=None,
- tokenizer_name=None,
- trust_remote_code=False,
+ model_name = "Qwen/Qwen3-7B",
+ max_seq_length = 4096,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
+ trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=dtype,
- load_in_4bit=load_in_4bit,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling,
- fix_tokenizer=fix_tokenizer,
- model_patcher=FastQwen3Model,
- tokenizer_name=tokenizer_name,
- trust_remote_code=trust_remote_code,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = FastQwen3Model,
+ tokenizer_name = tokenizer_name,
+ trust_remote_code = trust_remote_code,
**kwargs,
)
diff --git a/unsloth/models/qwen3_moe.py b/unsloth/models/qwen3_moe.py
index ae0bf9dee..bec3fa7b0 100644
--- a/unsloth/models/qwen3_moe.py
+++ b/unsloth/models/qwen3_moe.py
@@ -49,29 +49,29 @@ from unsloth_zoo.utils import Version, _get_dtype
torch_nn_functional_softmax = torch.nn.functional.softmax
-def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate=None, temp_up=None):
+def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):
# adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370
bsz, seq_len, hd = X.shape
X = X.view(-1, hd)
router_logits = fast_linear_forward(
- self.gate_proj, X, out=temp_gate
+ self.gate_proj, X, out = temp_gate
) # pretty much the only change from transformers implementation.
routing_weights = torch_nn_functional_softmax(
- router_logits, dim=-1, dtype=torch.float32
+ router_logits, dim = -1, dtype = torch.float32
)
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
+ routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
# we cast back to the input dtype
routing_weights = routing_weights.to(X.dtype)
- final_X = torch.zeros((bsz * seq_len, hd), dtype=torch.float32, device=X.device)
+ final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(
- selected_experts, num_classes=self.num_experts
+ selected_experts, num_classes = self.num_experts
).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
@@ -119,16 +119,16 @@ def Qwen3MoeDecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
- _flag_for_generation=self._flag_for_generation,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
+ _flag_for_generation = self._flag_for_generation,
)
hidden_states += residual
@@ -145,15 +145,15 @@ def Qwen3MoeDecoderLayer_fast_forward(
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- causal_mask=causal_mask,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- padding_mask=padding_mask,
- position_embeddings=position_embeddings,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
+ position_embeddings = position_embeddings,
)
hidden_states = residual + hidden_states
@@ -177,10 +177,10 @@ class FastQwen3MoeModel(FastQwen3Model):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name="Qwen3Moe",
- rope_module=LlamaRotaryEmbedding,
- scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
- attention_module=Qwen3MoeAttention,
+ model_name = "Qwen3Moe",
+ rope_module = LlamaRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ attention_module = Qwen3MoeAttention,
)
if init_name is not None:
exec(function, globals())
@@ -214,30 +214,30 @@ class FastQwen3MoeModel(FastQwen3Model):
@staticmethod
def from_pretrained( # TODO: Change after release
- model_name="Qwen/Qwen3-7B",
- max_seq_length=4096,
- dtype=None,
- load_in_4bit=True,
- token=None,
- device_map="sequential",
- rope_scaling=None,
- fix_tokenizer=True,
- model_patcher=None,
- tokenizer_name=None,
- trust_remote_code=False,
+ model_name = "Qwen/Qwen3-7B",
+ max_seq_length = 4096,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
+ trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
- model_name=model_name,
- max_seq_length=max_seq_length,
- dtype=dtype,
- load_in_4bit=load_in_4bit,
- token=token,
- device_map=device_map,
- rope_scaling=rope_scaling,
- fix_tokenizer=fix_tokenizer,
- model_patcher=FastQwen3Model,
- tokenizer_name=tokenizer_name,
- trust_remote_code=trust_remote_code,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = FastQwen3Model,
+ tokenizer_name = tokenizer_name,
+ trust_remote_code = trust_remote_code,
**kwargs,
)
diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py
index bb6642456..4bbc852cd 100644
--- a/unsloth/registry/_deepseek.py
+++ b/unsloth/registry/_deepseek.py
@@ -30,70 +30,70 @@ class DeepseekR1ModelInfo(ModelInfo):
# Deepseek V3 Model Meta
DeepseekV3Meta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek",
- instruct_tags=[None],
- model_version="3",
- model_sizes=[""],
- model_info_cls=DeepseekV3ModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.BF16],
+ org = "deepseek-ai",
+ base_name = "DeepSeek",
+ instruct_tags = [None],
+ model_version = "3",
+ model_sizes = [""],
+ model_info_cls = DeepseekV3ModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.BF16],
)
DeepseekV3_0324Meta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek",
- instruct_tags=[None],
- model_version="3-0324",
- model_sizes=[""],
- model_info_cls=DeepseekV3ModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.GGUF],
+ org = "deepseek-ai",
+ base_name = "DeepSeek",
+ instruct_tags = [None],
+ model_version = "3-0324",
+ model_sizes = [""],
+ model_info_cls = DeepseekV3ModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.GGUF],
)
DeepseekR1Meta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek-R1",
- instruct_tags=[None],
- model_version="",
- model_sizes=[""],
- model_info_cls=DeepseekR1ModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF],
+ org = "deepseek-ai",
+ base_name = "DeepSeek-R1",
+ instruct_tags = [None],
+ model_version = "",
+ model_sizes = [""],
+ model_info_cls = DeepseekR1ModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],
)
DeepseekR1ZeroMeta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek-R1",
- instruct_tags=[None],
- model_version="Zero",
- model_sizes=[""],
- model_info_cls=DeepseekR1ModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.GGUF],
+ org = "deepseek-ai",
+ base_name = "DeepSeek-R1",
+ instruct_tags = [None],
+ model_version = "Zero",
+ model_sizes = [""],
+ model_info_cls = DeepseekR1ModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.GGUF],
)
DeepseekR1DistillLlamaMeta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek-R1-Distill",
- instruct_tags=[None],
- model_version="Llama",
- model_sizes=["8", "70"],
- model_info_cls=DeepseekR1ModelInfo,
- is_multimodal=False,
- quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
+ org = "deepseek-ai",
+ base_name = "DeepSeek-R1-Distill",
+ instruct_tags = [None],
+ model_version = "Llama",
+ model_sizes = ["8", "70"],
+ model_info_cls = DeepseekR1ModelInfo,
+ is_multimodal = False,
+ quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
)
# Deepseek R1 Distill Qwen Model Meta
DeepseekR1DistillQwenMeta = ModelMeta(
- org="deepseek-ai",
- base_name="DeepSeek-R1-Distill",
- instruct_tags=[None],
- model_version="Qwen",
- model_sizes=["1.5", "7", "14", "32"],
- model_info_cls=DeepseekR1ModelInfo,
- is_multimodal=False,
- quant_types={
+ org = "deepseek-ai",
+ base_name = "DeepSeek-R1-Distill",
+ instruct_tags = [None],
+ model_version = "Qwen",
+ model_sizes = ["1.5", "7", "14", "32"],
+ model_info_cls = DeepseekR1ModelInfo,
+ is_multimodal = False,
+ quant_types = {
"1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
"7": [QuantType.UNSLOTH, QuantType.BNB],
"14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
@@ -106,7 +106,7 @@ def register_deepseek_v3_models(include_original_model: bool = False):
global _IS_DEEPSEEK_V3_REGISTERED
if _IS_DEEPSEEK_V3_REGISTERED:
return
- _register_models(DeepseekV3Meta, include_original_model=include_original_model)
+ _register_models(DeepseekV3Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_V3_REGISTERED = True
@@ -114,7 +114,7 @@ def register_deepseek_v3_0324_models(include_original_model: bool = False):
global _IS_DEEPSEEK_V3_0324_REGISTERED
if _IS_DEEPSEEK_V3_0324_REGISTERED:
return
- _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model)
+ _register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_V3_0324_REGISTERED = True
@@ -122,7 +122,7 @@ def register_deepseek_r1_models(include_original_model: bool = False):
global _IS_DEEPSEEK_R1_REGISTERED
if _IS_DEEPSEEK_R1_REGISTERED:
return
- _register_models(DeepseekR1Meta, include_original_model=include_original_model)
+ _register_models(DeepseekR1Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_R1_REGISTERED = True
@@ -130,7 +130,7 @@ def register_deepseek_r1_zero_models(include_original_model: bool = False):
global _IS_DEEPSEEK_R1_ZERO_REGISTERED
if _IS_DEEPSEEK_R1_ZERO_REGISTERED:
return
- _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model)
+ _register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)
_IS_DEEPSEEK_R1_ZERO_REGISTERED = True
@@ -139,7 +139,7 @@ def register_deepseek_r1_distill_llama_models(include_original_model: bool = Fal
if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:
return
_register_models(
- DeepseekR1DistillLlamaMeta, include_original_model=include_original_model
+ DeepseekR1DistillLlamaMeta, include_original_model = include_original_model
)
_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True
@@ -149,21 +149,21 @@ def register_deepseek_r1_distill_qwen_models(include_original_model: bool = Fals
if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:
return
_register_models(
- DeepseekR1DistillQwenMeta, include_original_model=include_original_model
+ DeepseekR1DistillQwenMeta, include_original_model = include_original_model
)
_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True
def register_deepseek_models(include_original_model: bool = False):
- register_deepseek_v3_models(include_original_model=include_original_model)
- register_deepseek_v3_0324_models(include_original_model=include_original_model)
- register_deepseek_r1_models(include_original_model=include_original_model)
- register_deepseek_r1_zero_models(include_original_model=include_original_model)
+ register_deepseek_v3_models(include_original_model = include_original_model)
+ register_deepseek_v3_0324_models(include_original_model = include_original_model)
+ register_deepseek_r1_models(include_original_model = include_original_model)
+ register_deepseek_r1_zero_models(include_original_model = include_original_model)
register_deepseek_r1_distill_llama_models(
- include_original_model=include_original_model
+ include_original_model = include_original_model
)
register_deepseek_r1_distill_qwen_models(
- include_original_model=include_original_model
+ include_original_model = include_original_model
)
@@ -172,7 +172,7 @@ def _list_deepseek_r1_distill_models():
from unsloth.utils.hf_hub import list_models
models: list[HfModelInfo] = list_models(
- author="unsloth", search="Distill", limit=1000
+ author = "unsloth", search = "Distill", limit = 1000
)
distill_models = []
for model in models:
@@ -185,14 +185,14 @@ def _list_deepseek_r1_distill_models():
return distill_models
-register_deepseek_models(include_original_model=True)
+register_deepseek_models(include_original_model = True)
if __name__ == "__main__":
from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
MODEL_REGISTRY.clear()
- register_deepseek_models(include_original_model=True)
+ register_deepseek_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)
diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py
index 775b4b762..c338128bc 100644
--- a/unsloth/registry/_gemma.py
+++ b/unsloth/registry/_gemma.py
@@ -15,26 +15,26 @@ class GemmaModelInfo(ModelInfo):
# Gemma3 Base Model Meta
GemmaMeta3Base = ModelMeta(
- org="google",
- base_name="gemma",
- instruct_tags=["pt"], # pt = base
- model_version="3",
- model_sizes=["1", "4", "12", "27"],
- model_info_cls=GemmaModelInfo,
- is_multimodal=True,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+ org = "google",
+ base_name = "gemma",
+ instruct_tags = ["pt"], # pt = base
+ model_version = "3",
+ model_sizes = ["1", "4", "12", "27"],
+ model_info_cls = GemmaModelInfo,
+ is_multimodal = True,
+ quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Gemma3 Instruct Model Meta
GemmaMeta3Instruct = ModelMeta(
- org="google",
- base_name="gemma",
- instruct_tags=["it"], # it = instruction tuned
- model_version="3",
- model_sizes=["1", "4", "12", "27"],
- model_info_cls=GemmaModelInfo,
- is_multimodal=True,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+ org = "google",
+ base_name = "gemma",
+ instruct_tags = ["it"], # it = instruction tuned
+ model_version = "3",
+ model_sizes = ["1", "4", "12", "27"],
+ model_info_cls = GemmaModelInfo,
+ is_multimodal = True,
+ quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
@@ -42,7 +42,7 @@ def register_gemma_3_base_models(include_original_model: bool = False):
global _IS_GEMMA_3_BASE_REGISTERED
if _IS_GEMMA_3_BASE_REGISTERED:
return
- _register_models(GemmaMeta3Base, include_original_model=include_original_model)
+ _register_models(GemmaMeta3Base, include_original_model = include_original_model)
_IS_GEMMA_3_BASE_REGISTERED = True
@@ -50,13 +50,13 @@ def register_gemma_3_instruct_models(include_original_model: bool = False):
global _IS_GEMMA_3_INSTRUCT_REGISTERED
if _IS_GEMMA_3_INSTRUCT_REGISTERED:
return
- _register_models(GemmaMeta3Instruct, include_original_model=include_original_model)
+ _register_models(GemmaMeta3Instruct, include_original_model = include_original_model)
_IS_GEMMA_3_INSTRUCT_REGISTERED = True
def register_gemma_models(include_original_model: bool = False):
- register_gemma_3_base_models(include_original_model=include_original_model)
- register_gemma_3_instruct_models(include_original_model=include_original_model)
+ register_gemma_3_base_models(include_original_model = include_original_model)
+ register_gemma_3_instruct_models(include_original_model = include_original_model)
if __name__ == "__main__":
@@ -64,7 +64,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
- register_gemma_models(include_original_model=True)
+ register_gemma_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)
diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py
index 817fc1d5c..173d6cfde 100644
--- a/unsloth/registry/_mistral.py
+++ b/unsloth/registry/_mistral.py
@@ -23,14 +23,14 @@ class MistralSmallModelInfo(ModelInfo):
MistralSmall_2503_Base_Meta = ModelMeta(
- org="mistralai",
- base_name="Mistral-Small",
- instruct_tags=["Base"],
- model_version=_MISTRAL_SMALL_03_25_VERSION,
- model_sizes=["24"],
- model_info_cls=MistralSmallModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
+ org = "mistralai",
+ base_name = "Mistral-Small",
+ instruct_tags = ["Base"],
+ model_version = _MISTRAL_SMALL_03_25_VERSION,
+ model_sizes = ["24"],
+ model_info_cls = MistralSmallModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
)
MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
@@ -54,23 +54,23 @@ def register_mistral_small_models(include_original_model: bool = False):
if _IS_MISTRAL_SMALL_REGISTERED:
return
_register_models(
- MistralSmall_2503_Base_Meta, include_original_model=include_original_model
+ MistralSmall_2503_Base_Meta, include_original_model = include_original_model
)
_register_models(
- MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model
+ MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model
)
_register_models(
- MistralSmall_2501_Base_Meta, include_original_model=include_original_model
+ MistralSmall_2501_Base_Meta, include_original_model = include_original_model
)
_register_models(
- MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model
+ MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model
)
_IS_MISTRAL_SMALL_REGISTERED = True
def register_mistral_models(include_original_model: bool = False):
- register_mistral_small_models(include_original_model=include_original_model)
+ register_mistral_small_models(include_original_model = include_original_model)
if __name__ == "__main__":
@@ -78,7 +78,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
- register_mistral_models(include_original_model=True)
+ register_mistral_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)
diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py
index fa4042bbe..f852cb876 100644
--- a/unsloth/registry/_qwen.py
+++ b/unsloth/registry/_qwen.py
@@ -43,50 +43,50 @@ class QwenQVQPreviewModelInfo(ModelInfo):
# Qwen2.5 Model Meta
Qwen_2_5_Meta = ModelMeta(
- org="Qwen",
- base_name="Qwen",
- instruct_tags=[None, "Instruct"],
- model_version="2.5",
- model_sizes=["3", "7"],
- model_info_cls=QwenModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+ org = "Qwen",
+ base_name = "Qwen",
+ instruct_tags = [None, "Instruct"],
+ model_version = "2.5",
+ model_sizes = ["3", "7"],
+ model_info_cls = QwenModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Qwen2.5 VL Model Meta
Qwen_2_5_VLMeta = ModelMeta(
- org="Qwen",
- base_name="Qwen",
- instruct_tags=["Instruct"], # No base, only instruction tuned
- model_version="2.5",
- model_sizes=["3", "7", "32", "72"],
- model_info_cls=QwenVLModelInfo,
- is_multimodal=True,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+ org = "Qwen",
+ base_name = "Qwen",
+ instruct_tags = ["Instruct"], # No base, only instruction tuned
+ model_version = "2.5",
+ model_sizes = ["3", "7", "32", "72"],
+ model_info_cls = QwenVLModelInfo,
+ is_multimodal = True,
+ quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Qwen QwQ Model Meta
QwenQwQMeta = ModelMeta(
- org="Qwen",
- base_name="QwQ",
- instruct_tags=[None],
- model_version="",
- model_sizes=["32"],
- model_info_cls=QwenQwQModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+ org = "Qwen",
+ base_name = "QwQ",
+ instruct_tags = [None],
+ model_version = "",
+ model_sizes = ["32"],
+ model_info_cls = QwenQwQModelInfo,
+ is_multimodal = False,
+ quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
# Qwen QVQ Preview Model Meta
QwenQVQPreviewMeta = ModelMeta(
- org="Qwen",
- base_name="QVQ",
- instruct_tags=[None],
- model_version="",
- model_sizes=["72"],
- model_info_cls=QwenQVQPreviewModelInfo,
- is_multimodal=True,
- quant_types=[QuantType.NONE, QuantType.BNB],
+ org = "Qwen",
+ base_name = "QVQ",
+ instruct_tags = [None],
+ model_version = "",
+ model_sizes = ["72"],
+ model_info_cls = QwenQVQPreviewModelInfo,
+ is_multimodal = True,
+ quant_types = [QuantType.NONE, QuantType.BNB],
)
@@ -94,7 +94,7 @@ def register_qwen_2_5_models(include_original_model: bool = False):
global _IS_QWEN_2_5_REGISTERED
if _IS_QWEN_2_5_REGISTERED:
return
- _register_models(Qwen_2_5_Meta, include_original_model=include_original_model)
+ _register_models(Qwen_2_5_Meta, include_original_model = include_original_model)
_IS_QWEN_2_5_REGISTERED = True
@@ -102,7 +102,7 @@ def register_qwen_2_5_vl_models(include_original_model: bool = False):
global _IS_QWEN_2_5_VL_REGISTERED
if _IS_QWEN_2_5_VL_REGISTERED:
return
- _register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model)
+ _register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)
_IS_QWEN_2_5_VL_REGISTERED = True
@@ -110,15 +110,15 @@ def register_qwen_qwq_models(include_original_model: bool = False):
global _IS_QWEN_QWQ_REGISTERED
if _IS_QWEN_QWQ_REGISTERED:
return
- _register_models(QwenQwQMeta, include_original_model=include_original_model)
- _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model)
+ _register_models(QwenQwQMeta, include_original_model = include_original_model)
+ _register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)
_IS_QWEN_QWQ_REGISTERED = True
def register_qwen_models(include_original_model: bool = False):
- register_qwen_2_5_models(include_original_model=include_original_model)
- register_qwen_2_5_vl_models(include_original_model=include_original_model)
- register_qwen_qwq_models(include_original_model=include_original_model)
+ register_qwen_2_5_models(include_original_model = include_original_model)
+ register_qwen_2_5_vl_models(include_original_model = include_original_model)
+ register_qwen_qwq_models(include_original_model = include_original_model)
if __name__ == "__main__":
@@ -126,7 +126,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
- register_qwen_models(include_original_model=True)
+ register_qwen_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)
diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py
index 880c9fb33..945301420 100644
--- a/unsloth/registry/registry.py
+++ b/unsloth/registry/registry.py
@@ -62,7 +62,7 @@ class ModelInfo:
@classmethod
def construct_model_name(
- cls, base_name, version, size, quant_type, instruct_tag, key=""
+ cls, base_name, version, size, quant_type, instruct_tag, key = ""
):
key = cls.append_instruct_tag(key, instruct_tag)
key = cls.append_quant_type(key, quant_type)
@@ -81,10 +81,10 @@ class ModelMeta:
base_name: str
model_version: str
model_info_cls: type[ModelInfo]
- model_sizes: list[str] = field(default_factory=list)
- instruct_tags: list[str] = field(default_factory=list)
+ model_sizes: list[str] = field(default_factory = list)
+ instruct_tags: list[str] = field(default_factory = list)
quant_types: list[QuantType] | dict[str, list[QuantType]] = field(
- default_factory=list
+ default_factory = list
)
is_multimodal: bool = False
@@ -104,11 +104,11 @@ def register_model(
name: str = None,
):
name = name or model_info_cls.construct_model_name(
- base_name=base_name,
- version=version,
- size=size,
- quant_type=quant_type,
- instruct_tag=instruct_tag,
+ base_name = base_name,
+ version = version,
+ size = size,
+ quant_type = quant_type,
+ instruct_tag = instruct_tag,
)
key = f"{org}/{name}"
@@ -118,14 +118,14 @@ def register_model(
)
MODEL_REGISTRY[key] = model_info_cls(
- org=org,
- base_name=base_name,
- version=version,
- size=size,
- is_multimodal=is_multimodal,
- instruct_tag=instruct_tag,
- quant_type=quant_type,
- name=name,
+ org = org,
+ base_name = base_name,
+ version = version,
+ size = size,
+ is_multimodal = is_multimodal,
+ instruct_tag = instruct_tag,
+ quant_type = quant_type,
+ name = name,
)
@@ -137,7 +137,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]):
api = HfApi()
try:
- model_info: HfModelInfo = api.model_info(model_id, expand=properties)
+ model_info: HfModelInfo = api.model_info(model_id, expand = properties)
except Exception as e:
if isinstance(e, RepositoryNotFoundError):
warnings.warn(f"{model_id} not found on Hugging Face")
@@ -168,24 +168,24 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False
# NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
_org = "unsloth" # unsloth models -- these are all quantized versions of the original model
register_model(
- model_info_cls=model_info_cls,
- org=_org,
- base_name=base_name,
- version=model_version,
- size=size,
- instruct_tag=instruct_tag,
- quant_type=quant_type,
- is_multimodal=is_multimodal,
+ model_info_cls = model_info_cls,
+ org = _org,
+ base_name = base_name,
+ version = model_version,
+ size = size,
+ instruct_tag = instruct_tag,
+ quant_type = quant_type,
+ is_multimodal = is_multimodal,
)
# include original model from releasing organization
if include_original_model:
register_model(
- model_info_cls=model_info_cls,
- org=org,
- base_name=base_name,
- version=model_version,
- size=size,
- instruct_tag=instruct_tag,
- quant_type=QuantType.NONE,
- is_multimodal=is_multimodal,
+ model_info_cls = model_info_cls,
+ org = org,
+ base_name = base_name,
+ version = model_version,
+ size = size,
+ instruct_tag = instruct_tag,
+ quant_type = QuantType.NONE,
+ is_multimodal = is_multimodal,
)
diff --git a/unsloth/trainer.py b/unsloth/trainer.py
index 196508ebb..6bdb5604d 100644
--- a/unsloth/trainer.py
+++ b/unsloth/trainer.py
@@ -79,7 +79,7 @@ def _create_unsloth_optimizer(
model,
optimizer_cls,
optimizer_kwargs,
- embedding_lr=5e-5,
+ embedding_lr = 5e-5,
):
lr = optimizer_kwargs["lr"]
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py
index 30255b863..75df00fbf 100644
--- a/unsloth/utils/hf_hub.py
+++ b/unsloth/utils/hf_hub.py
@@ -36,7 +36,7 @@ def get_model_info(
if _HFAPI is None:
_HFAPI = HfApi()
try:
- model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties)
+ model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)
except Exception as e:
print(f"Error getting model info for {model_id}: {e}")
model_info = None
@@ -68,11 +68,11 @@ def list_models(
properties = None
models: list[ModelInfo] = _HFAPI.list_models(
- author=author,
- search=search,
- sort=sort,
- limit=limit,
- expand=properties,
- full=full,
+ author = author,
+ search = search,
+ sort = sort,
+ limit = limit,
+ expand = properties,
+ full = full,
)
return models