diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py index 797d94018..ae975b026 100644 --- a/tests/qlora/test_hf_qlora_train_and_merge.py +++ b/tests/qlora/test_hf_qlora_train_and_merge.py @@ -54,37 +54,37 @@ if __name__ == "__main__": seed = 42 batch_size = 5 num_generations = 5 - tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer]) + tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer]) temperature = 0.8 max_new_tokens = 20 - peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear") - model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config) + peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear") + model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config) prompt = tokenizer.apply_chat_template( - [USER_MESSAGE], tokenize=False, add_generation_prompt=True + [USER_MESSAGE], tokenize = False, add_generation_prompt = True ) with header_footer_context("Test Prompt and Answer"): print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") dataset: Dataset = create_dataset( - tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES ) with header_footer_context("Dataset"): print(f"Dataset: {next(iter(dataset))}") training_args = SFTConfig( - output_dir=output_dir, - max_steps=max_steps, - per_device_train_batch_size=batch_size, - log_level="info", - report_to="none", - num_train_epochs=1, - logging_steps=1, - seed=seed, - bf16=dtype == torch.bfloat16, - fp16=dtype == torch.float16, - save_strategy="no", + output_dir = output_dir, + max_steps = max_steps, + per_device_train_batch_size = batch_size, + log_level = "info", + report_to = "none", + num_train_epochs = 1, + logging_steps = 1, + seed = seed, + bf16 = dtype == torch.bfloat16, + fp16 = dtype == torch.float16, + save_strategy = "no", ) with header_footer_context("Train Args"): @@ -92,7 +92,7 @@ if __name__ == "__main__": print(peft_config) trainer = setup_trainer( - model, tokenizer, dataset, training_args, peft_config=peft_config + model, tokenizer, dataset, training_args, peft_config = peft_config ) with header_footer_context("Model"): @@ -108,11 +108,11 @@ if __name__ == "__main__": responses = sample_responses( model, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses before training"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) with header_footer_context("Peft Weights before training"): for name, stats in itertools.islice(describe_peft_weights(model), 2): @@ -129,11 +129,11 @@ if __name__ == "__main__": responses = sample_responses( model, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses after training"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) model_copy = deepcopy(model) @@ -142,18 +142,18 @@ if __name__ == "__main__": responses = sample_responses( merged_model, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses after custom merging to 16bit"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) merged_model_peft = model_copy.merge_and_unload() responses = sample_responses( merged_model_peft, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses after peft merge_and_unload"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py index 59fa813fa..9040ad793 100644 --- a/tests/qlora/test_unsloth_qlora_train_and_merge.py +++ b/tests/qlora/test_unsloth_qlora_train_and_merge.py @@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer( dtype: torch.dtype = torch.bfloat16, ): return FastLanguageModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - load_in_4bit=load_in_4bit, - fast_inference=fast_inference, - max_lora_rank=max_lora_rank, - gpu_memory_utilization=gpu_memory_utilization, - dtype=dtype, + model_name = model_name, + max_seq_length = max_seq_length, + load_in_4bit = load_in_4bit, + fast_inference = fast_inference, + max_lora_rank = max_lora_rank, + gpu_memory_utilization = gpu_memory_utilization, + dtype = dtype, ) @@ -69,11 +69,11 @@ def get_unsloth_peft_model( ): return FastLanguageModel.get_peft_model( model, - r=lora_rank, - target_modules=target_modules, - lora_alpha=lora_rank, - use_gradient_checkpointing=use_gradient_checkpointing, - random_state=random_state, + r = lora_rank, + target_modules = target_modules, + lora_alpha = lora_rank, + use_gradient_checkpointing = use_gradient_checkpointing, + random_state = random_state, ) @@ -101,48 +101,48 @@ if __name__ == "__main__": model, tokenizer = get_unsloth_model_and_tokenizer( model_name, - max_seq_length=512, - load_in_4bit=True, - fast_inference=False, - max_lora_rank=lora_rank, - dtype=dtype, + max_seq_length = 512, + load_in_4bit = True, + fast_inference = False, + max_lora_rank = lora_rank, + dtype = dtype, ) temperature = 0.8 max_new_tokens = 20 model = get_unsloth_peft_model( model, - lora_rank=lora_rank, - target_modules=target_modules, - use_gradient_checkpointing=gradient_checkpointing, - random_state=seed, + lora_rank = lora_rank, + target_modules = target_modules, + use_gradient_checkpointing = gradient_checkpointing, + random_state = seed, ) prompt = tokenizer.apply_chat_template( - [USER_MESSAGE], tokenize=False, add_generation_prompt=True + [USER_MESSAGE], tokenize = False, add_generation_prompt = True ) with header_footer_context("Test Prompt and Answer"): print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") dataset: Dataset = create_dataset( - tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES ) with header_footer_context("Dataset"): print(f"Dataset: {next(iter(dataset))}") training_args = SFTConfig( - output_dir=output_dir, - max_steps=max_steps, - per_device_train_batch_size=batch_size, - log_level="info", - report_to="none", - num_train_epochs=1, - logging_steps=1, - seed=seed, - bf16=dtype == torch.bfloat16, - fp16=dtype == torch.float16, - save_strategy="no", + output_dir = output_dir, + max_steps = max_steps, + per_device_train_batch_size = batch_size, + log_level = "info", + report_to = "none", + num_train_epochs = 1, + logging_steps = 1, + seed = seed, + bf16 = dtype == torch.bfloat16, + fp16 = dtype == torch.float16, + save_strategy = "no", ) with header_footer_context("Train Args"): @@ -163,11 +163,11 @@ if __name__ == "__main__": responses = sample_responses( model, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses before training"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) with header_footer_context("Peft Weights before training"): for name, stats in itertools.islice(describe_peft_weights(model), 2): print(f"{name}:\n{stats}") @@ -183,29 +183,29 @@ if __name__ == "__main__": responses = sample_responses( model, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses after training"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) model.save_pretrained_merged( unsloth_merged_path, tokenizer, - save_method="merged_16bit", + save_method = "merged_16bit", ) merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer( unsloth_merged_path, - max_seq_length=512, - load_in_4bit=False, - fast_inference=False, - dtype=dtype, + max_seq_length = 512, + load_in_4bit = False, + fast_inference = False, + dtype = dtype, ) responses = sample_responses( merged_model_unsloth, tokenizer, - prompt=prompt, + prompt = prompt, **generation_args, ) with header_footer_context("Responses after unsloth merge to 16bit"): - check_responses(responses, answer=ANSWER, prompt=prompt) + check_responses(responses, answer = ANSWER, prompt = prompt) diff --git a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py index 622e72d25..d63bb9fe0 100644 --- a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py +++ b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py @@ -78,17 +78,17 @@ def formatting_prompts_func(examples): } -def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False): +def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False): """Load model and compute perplexity in subprocess""" from unsloth import FastLanguageModel from tests.utils.perplexity_eval import ppl_model # Load model merged_model, merged_tokenizer = FastLanguageModel.from_pretrained( - model_name="./unsloth_out/merged_qwen_text_model", - max_seq_length=2048, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, + model_name = "./unsloth_out/merged_qwen_text_model", + max_seq_length = 2048, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, ) # Set up tokenizer # merged_tokenizer = get_chat_template( @@ -98,7 +98,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal # Load dataset fresh in subprocess dataset_ppl = load_dataset( - "allenai/openassistant-guanaco-reformatted", split="eval" + "allenai/openassistant-guanaco-reformatted", split = "eval" ) alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. @@ -146,7 +146,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal "text": texts, } - dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True) + dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True) # Compute perplexity using the passed dataset ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl) @@ -172,7 +172,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal # Main execution code should be wrapped in this guard if __name__ == "__main__": - mp.set_start_method("spawn", force=True) + mp.set_start_method("spawn", force = True) if torch.cuda.is_bf16_supported(): compute_dtype = torch.bfloat16 @@ -182,31 +182,31 @@ if __name__ == "__main__": attn_implementation = "sdpa" model, tokenizer = FastLanguageModel.from_pretrained( - model_name="unsloth/Qwen2.5-7B-Instruct", - max_seq_length=2048, - dtype=compute_dtype, - load_in_4bit=True, - load_in_8bit=False, - full_finetuning=False, - attn_implementation=attn_implementation, + model_name = "unsloth/Qwen2.5-7B-Instruct", + max_seq_length = 2048, + dtype = compute_dtype, + load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, + attn_implementation = attn_implementation, ) dataset_train = load_dataset( - "allenai/openassistant-guanaco-reformatted", split="train" + "allenai/openassistant-guanaco-reformatted", split = "train" ) dataset_ppl = load_dataset( - "allenai/openassistant-guanaco-reformatted", split="eval" + "allenai/openassistant-guanaco-reformatted", split = "eval" ) - dataset_train = dataset_train.map(formatting_prompts_func, batched=True) - dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True) + dataset_train = dataset_train.map(formatting_prompts_func, batched = True) + dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True) add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl)) model = FastLanguageModel.get_peft_model( model, - r=16, - target_modules=[ + r = 16, + target_modules = [ "k_proj", "q_proj", "v_proj", @@ -215,40 +215,40 @@ if __name__ == "__main__": "down_proj", "up_proj", ], - lora_alpha=16, - lora_dropout=0, - bias="none", - use_gradient_checkpointing="unsloth", - random_state=3407, - use_rslora=False, - loftq_config=None, + lora_alpha = 16, + lora_dropout = 0, + bias = "none", + use_gradient_checkpointing = "unsloth", + random_state = 3407, + use_rslora = False, + loftq_config = None, ) from unsloth import is_bfloat16_supported trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset_train, - dataset_text_field="text", - max_seq_length=2048, - data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), - dataset_num_proc=2, - packing=False, - args=TrainingArguments( - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - warmup_ratio=0.1, - max_steps=200, - learning_rate=2e-4, - fp16=not is_bfloat16_supported(), - bf16=is_bfloat16_supported(), - logging_steps=50, - optim="adamw_8bit", - lr_scheduler_type="linear", - seed=3407, - output_dir="outputs", - report_to="none", + model = model, + tokenizer = tokenizer, + train_dataset = dataset_train, + dataset_text_field = "text", + max_seq_length = 2048, + data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer), + dataset_num_proc = 2, + packing = False, + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_ratio = 0.1, + max_steps = 200, + learning_rate = 2e-4, + fp16 = not is_bfloat16_supported(), + bf16 = is_bfloat16_supported(), + logging_steps = 50, + optim = "adamw_8bit", + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", ), ) @@ -260,7 +260,7 @@ if __name__ == "__main__": # saving and merging the model to local disk print("merge and save to local disk") model.save_pretrained_merged( - save_directory="./unsloth_out/merged_qwen_text_model", tokenizer=tokenizer + save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer ) # print("cleaning") @@ -272,10 +272,10 @@ if __name__ == "__main__": # load model from local disk and test print("Loading merged model in 4 bit for perplexity test") merged_model, merged_tokenizer = FastLanguageModel.from_pretrained( - model_name="./unsloth_out/merged_qwen_text_model", - max_seq_length=2048, - load_in_4bit=True, - load_in_8bit=False, + model_name = "./unsloth_out/merged_qwen_text_model", + max_seq_length = 2048, + load_in_4bit = True, + load_in_8bit = False, ) add_to_comparison( @@ -284,7 +284,7 @@ if __name__ == "__main__": print("Computing 8-bit model perplexity in subprocess...") result_queue = mp.Queue() - p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True)) + p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True)) p.start() p.join() @@ -293,10 +293,10 @@ if __name__ == "__main__": print("Loading merged model in 16 bit for perplexity test") merged_model, merged_tokenizer = FastLanguageModel.from_pretrained( - model_name="./unsloth_out/merged_qwen_text_model", - max_seq_length=2048, - load_in_4bit=False, - load_in_8bit=False, + model_name = "./unsloth_out/merged_qwen_text_model", + max_seq_length = 2048, + load_in_4bit = False, + load_in_8bit = False, ) add_to_comparison( diff --git a/tests/saving/language_models/test_push_to_hub_merged.py b/tests/saving/language_models/test_push_to_hub_merged.py index b92842b90..58d589305 100644 --- a/tests/saving/language_models/test_push_to_hub_merged.py +++ b/tests/saving/language_models/test_push_to_hub_merged.py @@ -37,7 +37,7 @@ def formatting_prompts_func(examples): convos = examples["messages"] texts = [ tokenizer.apply_chat_template( - convo, tokenize=False, add_generation_prompt=False + convo, tokenize = False, add_generation_prompt = False ) for convo in convos ] @@ -52,34 +52,34 @@ else: attn_implementation = "sdpa" model, tokenizer = FastLanguageModel.from_pretrained( - model_name="unsloth/Llama-3.2-1B-Instruct", - max_seq_length=2048, - dtype=compute_dtype, - load_in_4bit=True, - load_in_8bit=False, - full_finetuning=False, - attn_implementation=attn_implementation, + model_name = "unsloth/Llama-3.2-1B-Instruct", + max_seq_length = 2048, + dtype = compute_dtype, + load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, + attn_implementation = attn_implementation, ) tokenizer = get_chat_template( tokenizer, - chat_template="llama-3.1", + chat_template = "llama-3.1", ) from unsloth.chat_templates import standardize_sharegpt -dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train") -dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval") +dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train") +dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval") -dataset_train = dataset_train.map(formatting_prompts_func, batched=True) -dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True) +dataset_train = dataset_train.map(formatting_prompts_func, batched = True) +dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True) add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl)) model = FastLanguageModel.get_peft_model( model, - r=16, - target_modules=[ + r = 16, + target_modules = [ "k_proj", "q_proj", "v_proj", @@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model( "down_proj", "up_proj", ], - lora_alpha=16, - lora_dropout=0, - bias="none", - use_gradient_checkpointing="unsloth", - random_state=3407, - use_rslora=False, - loftq_config=None, + lora_alpha = 16, + lora_dropout = 0, + bias = "none", + use_gradient_checkpointing = "unsloth", + random_state = 3407, + use_rslora = False, + loftq_config = None, ) from unsloth import is_bfloat16_supported trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset_train, - dataset_text_field="text", - max_seq_length=2048, - data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), - dataset_num_proc=2, - packing=False, - args=TrainingArguments( - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - warmup_ratio=0.1, - max_steps=30, - learning_rate=2e-4, - fp16=not is_bfloat16_supported(), - bf16=is_bfloat16_supported(), - logging_steps=50, - optim="adamw_8bit", - lr_scheduler_type="linear", - seed=3407, - output_dir="outputs", - report_to="none", + model = model, + tokenizer = tokenizer, + train_dataset = dataset_train, + dataset_text_field = "text", + max_seq_length = 2048, + data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer), + dataset_num_proc = 2, + packing = False, + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_ratio = 0.1, + max_steps = 30, + learning_rate = 2e-4, + fp16 = not is_bfloat16_supported(), + bf16 = is_bfloat16_supported(), + logging_steps = 50, + optim = "adamw_8bit", + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", ), ) @@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only trainer = train_on_responses_only( trainer, - instruction_part="<|start_header_id|>user<|end_header_id|>\n\n", - response_part="<|start_header_id|>assistant<|end_header_id|>\n\n", + instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n", + response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n", ) # run training @@ -160,7 +160,7 @@ try: print("\n" + "=" * 80) print("=== UPLOADING MODEL TO HUB ===".center(80)) print("=" * 80 + "\n") - model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token) + model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token) success["upload"] = True print("✅ Model uploaded successfully!") except Exception as e: diff --git a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py index 1f07ce643..038565d17 100644 --- a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py +++ b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py @@ -37,7 +37,7 @@ def formatting_prompts_func(examples): convos = examples["messages"] texts = [ tokenizer.apply_chat_template( - convo, tokenize=False, add_generation_prompt=False + convo, tokenize = False, add_generation_prompt = False ) for convo in convos ] @@ -52,34 +52,34 @@ else: attn_implementation = "sdpa" model, tokenizer = FastLanguageModel.from_pretrained( - model_name="unsloth/Llama-3.1-8B-Instruct", - max_seq_length=2048, - dtype=compute_dtype, - load_in_4bit=True, - load_in_8bit=False, - full_finetuning=False, - attn_implementation=attn_implementation, + model_name = "unsloth/Llama-3.1-8B-Instruct", + max_seq_length = 2048, + dtype = compute_dtype, + load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, + attn_implementation = attn_implementation, ) tokenizer = get_chat_template( tokenizer, - chat_template="llama-3.1", + chat_template = "llama-3.1", ) from unsloth.chat_templates import standardize_sharegpt -dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train") -dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval") +dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train") +dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval") -dataset_train = dataset_train.map(formatting_prompts_func, batched=True) -dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True) +dataset_train = dataset_train.map(formatting_prompts_func, batched = True) +dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True) add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl)) model = FastLanguageModel.get_peft_model( model, - r=16, - target_modules=[ + r = 16, + target_modules = [ "k_proj", "q_proj", "v_proj", @@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model( "down_proj", "up_proj", ], - lora_alpha=16, - lora_dropout=0, - bias="none", - use_gradient_checkpointing="unsloth", - random_state=3407, - use_rslora=False, - loftq_config=None, + lora_alpha = 16, + lora_dropout = 0, + bias = "none", + use_gradient_checkpointing = "unsloth", + random_state = 3407, + use_rslora = False, + loftq_config = None, ) from unsloth import is_bfloat16_supported trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset_train, - dataset_text_field="text", - max_seq_length=2048, - data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), - dataset_num_proc=2, - packing=False, - args=TrainingArguments( - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - warmup_ratio=0.1, - max_steps=30, - learning_rate=2e-4, - fp16=not is_bfloat16_supported(), - bf16=is_bfloat16_supported(), - logging_steps=50, - optim="adamw_8bit", - lr_scheduler_type="linear", - seed=3407, - output_dir="outputs", - report_to="none", + model = model, + tokenizer = tokenizer, + train_dataset = dataset_train, + dataset_text_field = "text", + max_seq_length = 2048, + data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer), + dataset_num_proc = 2, + packing = False, + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_ratio = 0.1, + max_steps = 30, + learning_rate = 2e-4, + fp16 = not is_bfloat16_supported(), + bf16 = is_bfloat16_supported(), + logging_steps = 50, + optim = "adamw_8bit", + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", ), ) @@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only trainer = train_on_responses_only( trainer, - instruction_part="<|start_header_id|>user<|end_header_id|>\n\n", - response_part="<|start_header_id|>assistant<|end_header_id|>\n\n", + instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n", + response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n", ) # run training @@ -161,7 +161,7 @@ try: print("\n" + "=" * 80) print("=== UPLOADING MODEL TO HUB ===".center(80)) print("=" * 80 + "\n") - model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token) + model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token) success["upload"] = True print("✅ Model uploaded successfully!") except Exception as e: @@ -173,8 +173,8 @@ try: print("\n" + "=" * 80) print("=== VERIFYING REPO CONTENTS ===".center(80)) print("=" * 80 + "\n") - fs = HfFileSystem(token=hf_token) - file_list = fs.ls(repo_name, detail=True) + fs = HfFileSystem(token = hf_token) + file_list = fs.ls(repo_name, detail = True) safetensors_found = any( file["name"].endswith("model.safetensors.index.json") for file in file_list ) diff --git a/tests/saving/language_models/test_save_merged_grpo_model.py b/tests/saving/language_models/test_save_merged_grpo_model.py index e93b8add9..67b649305 100644 --- a/tests/saving/language_models/test_save_merged_grpo_model.py +++ b/tests/saving/language_models/test_save_merged_grpo_model.py @@ -24,7 +24,7 @@ max_seq_length = 2048 # Can increase for longer reasoning traces lora_rank = 64 # Larger rank = smarter, but slower -def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False): +def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False): from unsloth import FastLanguageModel from tests.utils.aime_eval import evaluate_model_aime @@ -32,12 +32,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False): lora_rank = 64 # Larger rank = smarter, but slower model, tokenizer = FastLanguageModel.from_pretrained( - model_name="./final_merged_model", - max_seq_length=max_seq_length, - load_in_4bit=True, # False for LoRA 16bit - fast_inference=True, # Enable vLLM fast inference - max_lora_rank=lora_rank, - gpu_memory_utilization=0.8, # Reduce if out of memory + model_name = "./final_merged_model", + max_seq_length = max_seq_length, + load_in_4bit = True, # False for LoRA 16bit + fast_inference = True, # Enable vLLM fast inference + max_lora_rank = lora_rank, + gpu_memory_utilization = 0.8, # Reduce if out of memory ) print(f"\n{'='*60}") @@ -53,14 +53,14 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False): print(f"{'='*60}") evaluate_model_aime( - model=model, - tokenizer=tokenizer, - model_type=model_type, - temperature=0.3, - n_sampling=8, - max_tokens=32768, - top_p=0.95, - seed=0, + model = model, + tokenizer = tokenizer, + model_type = model_type, + temperature = 0.3, + n_sampling = 8, + max_tokens = 32768, + top_p = 0.95, + seed = 0, ) result_queue.put(results) @@ -74,12 +74,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False): # Main execution code should be wrapped in this guard def training_run(result_queue): model, tokenizer = FastLanguageModel.from_pretrained( - model_name="meta-llama/Llama-3.2-3B-Instruct", - max_seq_length=max_seq_length, - load_in_4bit=False, # False for LoRA 16bit - fast_inference=True, # Enable vLLM fast inference - max_lora_rank=lora_rank, - gpu_memory_utilization=0.8, # Reduce if out of memory + model_name = "meta-llama/Llama-3.2-3B-Instruct", + max_seq_length = max_seq_length, + load_in_4bit = False, # False for LoRA 16bit + fast_inference = True, # Enable vLLM fast inference + max_lora_rank = lora_rank, + gpu_memory_utilization = 0.8, # Reduce if out of memory ) """### Helper Functions @@ -166,10 +166,10 @@ def training_run(result_queue): lengths = dataset.map( lambda x: { "tokens": tokenizer.apply_chat_template( - x["prompt"], add_generation_prompt=True, tokenize=True + x["prompt"], add_generation_prompt = True, tokenize = True ) }, - batched=True, + batched = True, ).map(lambda x: {"length": len(x["tokens"])})["length"] max_length = max(lengths) @@ -181,7 +181,7 @@ def training_run(result_queue): ) return max_length, avg_length - def extract_unsloth_answer(text, start_tag="", end_tag=""): + def extract_unsloth_answer(text, start_tag = "", end_tag = ""): """Extract answer from Unsloth SOLUTION tags""" pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) matches = re.findall(pattern, text, re.DOTALL) @@ -213,10 +213,10 @@ def training_run(result_queue): """Count tokens in text""" if not text: return 0 - encoding = tokenizer_instance(text, return_tensors="pt") + encoding = tokenizer_instance(text, return_tensors = "pt") return len(encoding["input_ids"][0]) - def check_format_compliance(text, format_type="unsloth"): + def check_format_compliance(text, format_type = "unsloth"): """Check if response follows expected format""" if format_type == "unsloth": reasoning_start = "" @@ -419,11 +419,11 @@ def training_run(result_queue): # Save comparison comparison_data = { "summary": all_results, - "best_model": max(all_results, key=lambda x: x["exact_match_pct"]), + "best_model": max(all_results, key = lambda x: x["exact_match_pct"]), } with open("model_comparison_comprehensive.json", "w") as f: - json.dump(comparison_data, f, indent=4) + json.dump(comparison_data, f, indent = 4) print( f"\nBest performing model: {comparison_data['best_model']['model_type']} " @@ -449,10 +449,10 @@ def training_run(result_queue): from datasets import load_dataset # Load GSM8K - gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train") + gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train") # Load LIMO (adjust this based on your access method) - limo_train = load_dataset("GAIR/LIMO", split="train") + limo_train = load_dataset("GAIR/LIMO", split = "train") # Prepare datasets gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset) @@ -466,28 +466,28 @@ def training_run(result_queue): # Single temperature evaluation on combined dataset results = evaluate_model_aime( - model=model, - tokenizer=tokenizer, - model_type="base", - temperature=0.3, - n_sampling=8, - max_tokens=32768, - top_p=0.95, - seed=0, + model = model, + tokenizer = tokenizer, + model_type = "base", + temperature = 0.3, + n_sampling = 8, + max_tokens = 32768, + top_p = 0.95, + seed = 0, ) from unsloth.chat_templates import get_chat_template tokenizer = get_chat_template( tokenizer, - chat_template="llama-3.1", + chat_template = "llama-3.1", ) def formatting_prompts_func(examples): convos = examples["prompt"] texts = [ tokenizer.apply_chat_template( - convo, tokenize=False, add_generation_prompt=False + convo, tokenize = False, add_generation_prompt = False ) for convo in convos ] @@ -497,7 +497,7 @@ def training_run(result_queue): limo_train = limo_train.map( formatting_prompts_func, - batched=True, + batched = True, ) from trl import SFTTrainer @@ -510,8 +510,8 @@ def training_run(result_queue): model = FastLanguageModel.get_peft_model( model, - r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - target_modules=[ + r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules = [ "q_proj", "k_proj", "v_proj", @@ -520,37 +520,37 @@ def training_run(result_queue): "up_proj", "down_proj", ], # Remove QKVO if out of memory - lora_alpha=lora_rank, - use_gradient_checkpointing="unsloth", # Enable long context finetuning - random_state=3407, + lora_alpha = lora_rank, + use_gradient_checkpointing = "unsloth", # Enable long context finetuning + random_state = 3407, ) if limo_train is not None: trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=limo_train, - dataset_text_field="text", - max_seq_length=max_seq_length, - data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), - dataset_num_proc=2, - packing=False, # Can make training 5x faster for short sequences. - args=TrainingArguments( - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - warmup_steps=5, - num_train_epochs=1, # Set this for 1 full training run. + model = model, + tokenizer = tokenizer, + train_dataset = limo_train, + dataset_text_field = "text", + max_seq_length = max_seq_length, + data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer), + dataset_num_proc = 2, + packing = False, # Can make training 5x faster for short sequences. + args = TrainingArguments( + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + warmup_steps = 5, + num_train_epochs = 1, # Set this for 1 full training run. # max_steps = 60, - learning_rate=2e-4, - fp16=not is_bfloat16_supported(), - bf16=is_bfloat16_supported(), - logging_steps=1, - optim="adamw_8bit", - weight_decay=0.01, - lr_scheduler_type="linear", - seed=3407, - output_dir="outputs", - report_to="none", # Use this for WandB etc + learning_rate = 2e-4, + fp16 = not is_bfloat16_supported(), + bf16 = is_bfloat16_supported(), + logging_steps = 1, + optim = "adamw_8bit", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "outputs", + report_to = "none", # Use this for WandB etc ), ) @@ -558,8 +558,8 @@ def training_run(result_queue): trainer = train_on_responses_only( trainer, - instruction_part="<|start_header_id|>user<|end_header_id|>\n\n", - response_part="<|start_header_id|>assistant<|end_header_id|>\n\n", + instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n", + response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n", ) # Train @@ -588,7 +588,7 @@ def training_run(result_queue): PRINT_EVERY_STEPS = 5 match_numbers = re.compile( - solution_start + r".*?([\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL + solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL ) def check_numbers(prompts, completions, answer, **kwargs): @@ -642,37 +642,37 @@ def training_run(result_queue): from trl import GRPOConfig, GRPOTrainer training_args = GRPOConfig( - learning_rate=5e-6, - weight_decay=0.1, - warmup_ratio=0.1, - lr_scheduler_type="cosine", - optim="adamw_torch_fused", - logging_steps=1, - per_device_train_batch_size=1, - gradient_accumulation_steps=4, # Increase to 4 for smoother training - num_generations=8, # Decrease if out of memory - max_prompt_length=max_prompt_length, - max_completion_length=max_seq_length - max_prompt_length, + learning_rate = 5e-6, + weight_decay = 0.1, + warmup_ratio = 0.1, + lr_scheduler_type = "cosine", + optim = "adamw_torch_fused", + logging_steps = 1, + per_device_train_batch_size = 1, + gradient_accumulation_steps = 4, # Increase to 4 for smoother training + num_generations = 8, # Decrease if out of memory + max_prompt_length = max_prompt_length, + max_completion_length = max_seq_length - max_prompt_length, # num_train_epochs = 1, # Set to 1 for a full training run # max_steps = 250, - max_steps=1000, - save_steps=250, - max_grad_norm=0.1, - report_to="none", # Can use Weights & Biases - output_dir="outputs", + max_steps = 1000, + save_steps = 250, + max_grad_norm = 0.1, + report_to = "none", # Can use Weights & Biases + output_dir = "outputs", ) trainer = GRPOTrainer( - model=model, - processing_class=tokenizer, - reward_funcs=[ + model = model, + processing_class = tokenizer, + reward_funcs = [ match_format_exactly, match_format_approximately, check_answer_correctness, check_numbers, ], - args=training_args, - train_dataset=gsm8k_train, + args = training_args, + train_dataset = gsm8k_train, ) # Train @@ -696,14 +696,14 @@ def training_run(result_queue): print(f"{'='*60}") grpo_results = evaluate_model_aime( - model=model, - tokenizer=tokenizer, - model_type="grpo", - temperature=0.3, - n_sampling=8, - max_tokens=32768, - top_p=0.95, - seed=0, + model = model, + tokenizer = tokenizer, + model_type = "grpo", + temperature = 0.3, + n_sampling = 8, + max_tokens = 32768, + top_p = 0.95, + seed = 0, ) all_results.append(grpo_results) @@ -716,7 +716,7 @@ def training_run(result_queue): # Save as merged model try: model.save_pretrained_merged( - "final_merged_model", tokenizer, save_method="merged_16bit" + "final_merged_model", tokenizer, save_method = "merged_16bit" ) print("✅ Merged model saved to: final_merged_model/") except Exception as e: @@ -774,12 +774,12 @@ def training_run(result_queue): if __name__ == "__main__": - mp.set_start_method("spawn", force=True) + mp.set_start_method("spawn", force = True) result_queue = mp.Queue() all_results = [] # run main finetuning and grpo loop - p = mp.Process(target=training_run, args=(result_queue,)) + p = mp.Process(target = training_run, args = (result_queue,)) p.start() p.join() @@ -787,7 +787,7 @@ if __name__ == "__main__": all_results = results # evaluate merged model loaded 16bits - p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False)) + p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False)) p.start() p.join() @@ -796,7 +796,7 @@ if __name__ == "__main__": safe_remove_directory("./unsloth_compiled_cache") # Merged model load 8 bits model AIME eval - p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True)) + p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True)) p.start() p.join() @@ -806,7 +806,7 @@ if __name__ == "__main__": safe_remove_directory("./unsloth_compiled_cache") # Merged model load 4 bits model AIME eval - p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False)) + p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False)) p.start() p.join() diff --git a/tests/saving/test_unsloth_save.py b/tests/saving/test_unsloth_save.py index 5ec2a943e..35fdad6ba 100644 --- a/tests/saving/test_unsloth_save.py +++ b/tests/saving/test_unsloth_save.py @@ -43,31 +43,31 @@ tokenizer_files = [ ] -@pytest.fixture(scope="session", params=model_to_test) +@pytest.fixture(scope = "session", params = model_to_test) def loaded_model_tokenizer(request): model_name = request.param print("Loading model and tokenizer...") model, tokenizer = FastModel.from_pretrained( model_name, # use small model - max_seq_length=128, - dtype=None, - load_in_4bit=True, + max_seq_length = 128, + dtype = None, + load_in_4bit = True, ) # Apply LoRA model = FastModel.get_peft_model( model, - r=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], - lora_alpha=16, - use_gradient_checkpointing="unsloth", + r = 16, + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"], + lora_alpha = 16, + use_gradient_checkpointing = "unsloth", ) return model, tokenizer -@pytest.fixture(scope="session", params=torchao_models) +@pytest.fixture(scope = "session", params = torchao_models) def fp16_model_tokenizer(request): """Load model in FP16 for TorchAO quantization""" model_name = request.param @@ -75,29 +75,29 @@ def fp16_model_tokenizer(request): model, tokenizer = FastModel.from_pretrained( model_name, - max_seq_length=128, - dtype=None, - load_in_4bit=False, # No BnB quantization + max_seq_length = 128, + dtype = None, + load_in_4bit = False, # No BnB quantization ) # Apply LoRA model = FastModel.get_peft_model( model, - r=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], - lora_alpha=16, - use_gradient_checkpointing="unsloth", + r = 16, + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"], + lora_alpha = 16, + use_gradient_checkpointing = "unsloth", ) return model, tokenizer -@pytest.fixture(scope="session") +@pytest.fixture(scope = "session") def model(loaded_model_tokenizer): return loaded_model_tokenizer[0] -@pytest.fixture(scope="session") +@pytest.fixture(scope = "session") def tokenizer(loaded_model_tokenizer): return loaded_model_tokenizer[1] @@ -133,7 +133,7 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str): ) model.save_pretrained_merged( - save_path, tokenizer=tokenizer, save_method="merged_16bit" + save_path, tokenizer = tokenizer, save_method = "merged_16bit" ) # Check model files @@ -172,9 +172,9 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str): # Test loading the model from the saved path loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained( save_path, - max_seq_length=128, - dtype=None, - load_in_4bit=True, + max_seq_length = 128, + dtype = None, + load_in_4bit = True, ) @@ -186,7 +186,7 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str): ) model.save_pretrained_merged( - save_path, tokenizer=tokenizer, save_method="merged_4bit_forced" + save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced" ) # Check model files @@ -230,15 +230,15 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str): # Test loading the model from the saved path loaded_model, loaded_tokenizer = FastModel.from_pretrained( save_path, - max_seq_length=128, - dtype=None, - load_in_4bit=True, + max_seq_length = 128, + dtype = None, + load_in_4bit = True, ) @pytest.mark.skipif( importlib.util.find_spec("torchao") is None, - reason="require torchao to be installed", + reason = "require torchao to be installed", ) def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str): model, tokenizer = fp16_model_tokenizer @@ -251,9 +251,9 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str): torchao_config = Int8DynamicActivationInt8WeightConfig() model.save_pretrained_torchao( save_path, - tokenizer=tokenizer, - torchao_config=torchao_config, - push_to_hub=False, + tokenizer = tokenizer, + torchao_config = torchao_config, + push_to_hub = False, ) weight_files_16bit = [ @@ -316,15 +316,15 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str): with torch.serialization.safe_globals([getattr]): loaded_model, loaded_tokenizer = FastModel.from_pretrained( torchao_save_path, - max_seq_length=128, - dtype=None, - load_in_4bit=False, + max_seq_length = 128, + dtype = None, + load_in_4bit = False, ) @pytest.mark.skipif( importlib.util.find_spec("torchao") is None, - reason="require torchao to be installed", + reason = "require torchao to be installed", ) def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str): model, tokenizer = fp16_model_tokenizer @@ -343,9 +343,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str): # Save with TorchAO model.save_pretrained_torchao( save_path, - tokenizer=tokenizer, - torchao_config=torchao_config, - push_to_hub=False, + tokenizer = tokenizer, + torchao_config = torchao_config, + push_to_hub = False, ) torchao_save_path = save_path + "-torchao" @@ -361,9 +361,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str): with torch.serialization.safe_globals([getattr]): loaded_model, loaded_tokenizer = FastModel.from_pretrained( torchao_save_path, - max_seq_length=128, - dtype=None, - load_in_4bit=False, + max_seq_length = 128, + dtype = None, + load_in_4bit = False, ) FastModel.for_inference(loaded_model) # Enable native 2x faster inference @@ -376,24 +376,24 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str): ] inputs = loaded_tokenizer.apply_chat_template( messages, - tokenize=True, - add_generation_prompt=True, # Must add for generation - return_tensors="pt", + tokenize = True, + add_generation_prompt = True, # Must add for generation + return_tensors = "pt", ).to("cuda") outputs = loaded_model.generate( # ← Use loaded_model, not model - input_ids=inputs, - max_new_tokens=64, - use_cache=False, # Avoid cache issues - temperature=1.5, - min_p=0.1, - do_sample=True, - pad_token_id=loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id, + input_ids = inputs, + max_new_tokens = 64, + use_cache = False, # Avoid cache issues + temperature = 1.5, + min_p = 0.1, + do_sample = True, + pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id, ) # Decode with the LOADED tokenizer - generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True) - input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens=True) + generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True) + input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True) response_part = generated_text[len(input_text) :].strip() print(f"Input: {input_text}") diff --git a/tests/saving/text_to_speech_models/test_csm.py b/tests/saving/text_to_speech_models/test_csm.py index 7f82ca63a..c1a892a8d 100644 --- a/tests/saving/text_to_speech_models/test_csm.py +++ b/tests/saving/text_to_speech_models/test_csm.py @@ -26,11 +26,11 @@ print(f"{'='*80}") model, tokenizer = FastModel.from_pretrained( - model_name="unsloth/csm-1b", - max_seq_length=2048, # Choose any for long context! - dtype=None, # Leave as None for auto-detection - auto_model=CsmForConditionalGeneration, - load_in_4bit=False, # Select True for 4bit - reduces memory usage + model_name = "unsloth/csm-1b", + max_seq_length = 2048, # Choose any for long context! + dtype = None, # Leave as None for auto-detection + auto_model = CsmForConditionalGeneration, + load_in_4bit = False, # Select True for 4bit - reduces memory usage ) @@ -39,8 +39,8 @@ base_model_class = model.__class__.__name__ model = FastModel.get_peft_model( model, - r=32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - target_modules=[ + r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules = [ "q_proj", "k_proj", "v_proj", @@ -49,14 +49,14 @@ model = FastModel.get_peft_model( "up_proj", "down_proj", ], - lora_alpha=32, - lora_dropout=0, # Supports any, but = 0 is optimized - bias="none", # Supports any, but = "none" is optimized + lora_alpha = 32, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! - use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context - random_state=3407, - use_rslora=False, # We support rank stabilized LoRA - loftq_config=None, # And LoftQ + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ ) print("✅ Model and LoRA adapters loaded successfully!") @@ -110,11 +110,11 @@ print(f"{'='*80}") model, processor = FastModel.from_pretrained( - model_name="./csm", - max_seq_length=2048, # Choose any for long context! - dtype=None, # Leave as None for auto-detection - auto_model=CsmForConditionalGeneration, - load_in_4bit=False, # Select True for 4bit - reduces memory usage + model_name = "./csm", + max_seq_length = 2048, # Choose any for long context! + dtype = None, # Leave as None for auto-detection + auto_model = CsmForConditionalGeneration, + load_in_4bit = False, # Select True for 4bit - reduces memory usage ) from transformers import AutoProcessor @@ -138,19 +138,19 @@ try: "We just finished fine tuning a text to speech model... and it's pretty good!" ) speaker_id = 0 - inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda") + inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda") audio_values = model.generate( **inputs, - max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this + max_new_tokens = 125, # 125 tokens is 10 seconds of audio, for longer speech increase this # play with these parameters to get the best results - depth_decoder_temperature=0.6, - depth_decoder_top_k=0, - depth_decoder_top_p=0.9, - temperature=0.8, - top_k=50, - top_p=1.0, + depth_decoder_temperature = 0.6, + depth_decoder_top_k = 0, + depth_decoder_top_p = 0.9, + temperature = 0.8, + top_k = 50, + top_p = 1.0, ######################################################### - output_audio=True, + output_audio = True, ) audio = audio_values[0].to(torch.float32).cpu().numpy() sf.write("example_without_context.wav", audio, 24000) diff --git a/tests/saving/text_to_speech_models/test_lasa.py b/tests/saving/text_to_speech_models/test_lasa.py index e9e05af08..804ff512f 100644 --- a/tests/saving/text_to_speech_models/test_lasa.py +++ b/tests/saving/text_to_speech_models/test_lasa.py @@ -42,10 +42,10 @@ print(f"{'='*80}") max_seq_length = 2048 model, tokenizer = FastLanguageModel.from_pretrained( - model_name="unsloth/Llasa-1B", - max_seq_length=max_seq_length, - dtype=None, # Select None for auto detection - load_in_4bit=False, # Choose True for 4bit which reduces memory + model_name = "unsloth/Llasa-1B", + max_seq_length = max_seq_length, + dtype = None, # Select None for auto detection + load_in_4bit = False, # Choose True for 4bit which reduces memory # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf ) @@ -54,16 +54,16 @@ base_model_class = model.__class__.__name__ model = FastLanguageModel.get_peft_model( model, - r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - target_modules=["q_proj", "v_proj"], - lora_alpha=128, - lora_dropout=0, # Supports any, but = 0 is optimized - bias="none", # Supports any, but = "none" is optimized + r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules = ["q_proj", "v_proj"], + lora_alpha = 128, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! - use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context - random_state=3407, - use_rslora=False, # We support rank stabilized LoRA - loftq_config=None, # And LoftQ + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ ) print("✅ Model and LoRA adapters loaded successfully!") @@ -117,10 +117,10 @@ print(f"{'='*80}") model, tokenizer = FastLanguageModel.from_pretrained( - model_name="./lasa", - max_seq_length=max_seq_length, - dtype=None, # Select None for auto detection - load_in_4bit=False, # Choose True for 4bit which reduces memory + model_name = "./lasa", + max_seq_length = max_seq_length, + dtype = None, # Select None for auto detection + load_in_4bit = False, # Choose True for 4bit which reduces memory # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf ) @@ -166,7 +166,7 @@ def extract_speech_ids(speech_tokens_str): # TTS start! with torch.inference_mode(): - with torch.amp.autocast("cuda", dtype=model.dtype): + with torch.amp.autocast("cuda", dtype = model.dtype): formatted_text = ( f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>" ) @@ -178,7 +178,7 @@ with torch.inference_mode(): ] input_ids = tokenizer.apply_chat_template( - chat, tokenize=True, return_tensors="pt", continue_final_message=True + chat, tokenize = True, return_tensors = "pt", continue_final_message = True ) input_ids = input_ids.to("cuda") @@ -187,16 +187,16 @@ with torch.inference_mode(): # Generate the speech autoregressively outputs = model.generate( input_ids, - max_length=2048, # We trained our model with a max length of 2048 - eos_token_id=speech_end_id, - do_sample=True, - top_p=1.2, # Adjusts the diversity of generated content - temperature=1.2, # Controls randomness in output + max_length = 2048, # We trained our model with a max length of 2048 + eos_token_id = speech_end_id, + do_sample = True, + top_p = 1.2, # Adjusts the diversity of generated content + temperature = 1.2, # Controls randomness in output ) # Extract the speech tokens generated_ids = outputs[0][input_ids.shape[1] : -1] - speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True) # Convert token <|s_23456|> to int 23456 speech_tokens = extract_speech_ids(speech_tokens) diff --git a/tests/saving/text_to_speech_models/test_whisper.py b/tests/saving/text_to_speech_models/test_whisper.py index 002d136cb..55f6d98ca 100644 --- a/tests/saving/text_to_speech_models/test_whisper.py +++ b/tests/saving/text_to_speech_models/test_whisper.py @@ -28,12 +28,12 @@ print(f"{'='*80}") model, tokenizer = FastModel.from_pretrained( - model_name="unsloth/whisper-large-v3", - dtype=None, # Leave as None for auto detection - load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory - auto_model=WhisperForConditionalGeneration, - whisper_language="English", - whisper_task="transcribe", + model_name = "unsloth/whisper-large-v3", + dtype = None, # Leave as None for auto detection + load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory + auto_model = WhisperForConditionalGeneration, + whisper_language = "English", + whisper_task = "transcribe", # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf ) @@ -46,17 +46,17 @@ model.generation_config.forced_decoder_ids = None model = FastModel.get_peft_model( model, - r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - target_modules=["q_proj", "v_proj"], - lora_alpha=64, - lora_dropout=0, # Supports any, but = 0 is optimized - bias="none", # Supports any, but = "none" is optimized + r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules = ["q_proj", "v_proj"], + lora_alpha = 64, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! - use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context - random_state=3407, - use_rslora=False, # We support rank stabilized LoRA - loftq_config=None, # And LoftQ - task_type=None, # ** MUST set this for Whisper ** + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ + task_type = None, # ** MUST set this for Whisper ** ) print("✅ Model and LoRA adapters loaded successfully!") @@ -110,12 +110,12 @@ print(f"{'='*80}") model, tokenizer = FastModel.from_pretrained( - model_name="./whisper", - dtype=None, # Leave as None for auto detection - load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory - auto_model=WhisperForConditionalGeneration, - whisper_language="English", - whisper_task="transcribe", + model_name = "./whisper", + dtype = None, # Leave as None for auto detection + load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory + auto_model = WhisperForConditionalGeneration, + whisper_language = "English", + whisper_task = "transcribe", # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf ) @@ -135,7 +135,7 @@ try: headers = { "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - response = requests.get(audio_url, headers=headers) + response = requests.get(audio_url, headers = headers) response.raise_for_status() with open(audio_file, "wb") as f: f.write(response.content) @@ -156,12 +156,12 @@ model.eval() # Create pipeline without specifying the device whisper = pipeline( "automatic-speech-recognition", - model=model, - tokenizer=tokenizer.tokenizer, - feature_extractor=tokenizer.feature_extractor, - processor=tokenizer, - return_language=True, - torch_dtype=torch.float16, # Remove the device parameter + model = model, + tokenizer = tokenizer.tokenizer, + feature_extractor = tokenizer.feature_extractor, + processor = tokenizer, + return_language = True, + torch_dtype = torch.float16, # Remove the device parameter ) # Example usage audio_file = "Speech_12dB_s16.flac" diff --git a/tests/saving/vision_models/test_index_file_sharded_model.py b/tests/saving/vision_models/test_index_file_sharded_model.py index 39cadb433..f73716984 100644 --- a/tests/saving/vision_models/test_index_file_sharded_model.py +++ b/tests/saving/vision_models/test_index_file_sharded_model.py @@ -21,7 +21,7 @@ from tests.utils.cleanup_utils import safe_remove_directory ## Dataset Preparation""" print("\n📊 Loading and preparing dataset...") -dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train") +dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train") # To select the first 2000 examples train_dataset = dataset.select(range(2000)) @@ -81,11 +81,11 @@ print("🤖 Loading base vision model...") try: model, tokenizer = FastVisionModel.from_pretrained( # model_name = "unsloth/Qwen2-VL-7B-Instruct", - model_name="unsloth/Qwen2-VL-7B-Instruct", - max_seq_length=2048, # Choose any for long context! - load_in_4bit=True, # 4 bit quantization to reduce memory - load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory - full_finetuning=False, # [NEW!] We have full finetuning now! + model_name = "unsloth/Qwen2-VL-7B-Instruct", + max_seq_length = 2048, # Choose any for long context! + load_in_4bit = True, # 4 bit quantization to reduce memory + load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory + full_finetuning = False, # [NEW!] We have full finetuning now! ) except Exception as e: print(f"❌ Failed to load base model: {e}") @@ -96,18 +96,18 @@ print("\n🔧 Setting up LoRA configuration...") try: model = FastVisionModel.get_peft_model( model, - finetune_vision_layers=True, # Turn off for just text! - finetune_language_layers=True, # Should leave on! - finetune_attention_modules=True, # Attention good for GRPO - finetune_mlp_modules=True, # SHould leave on always! - r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - lora_alpha=32, - lora_dropout=0, # Supports any, but = 0 is optimized - bias="none", # Supports any, but = "none" is optimized - use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context - random_state=3407, - use_rslora=False, # We support rank stabilized LoRA - loftq_config=None, # And LoftQ + finetune_vision_layers = True, # Turn off for just text! + finetune_language_layers = True, # Should leave on! + finetune_attention_modules = True, # Attention good for GRPO + finetune_mlp_modules = True, # SHould leave on always! + r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + lora_alpha = 32, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ ) print("✅ LoRA configuration applied successfully!") print(f" 🎯 LoRA rank (r): 16") @@ -128,40 +128,40 @@ FastVisionModel.for_training(model) # Enable for training! try: trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - data_collator=UnslothVisionDataCollator(model, tokenizer), - train_dataset=train_dataset, - args=SFTConfig( + model = model, + tokenizer = tokenizer, + data_collator = UnslothVisionDataCollator(model, tokenizer), + train_dataset = train_dataset, + args = SFTConfig( # per_device_train_batch_size = 4, # gradient_accumulation_steps = 8, - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - gradient_checkpointing=True, - gradient_checkpointing_kwargs={ + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + gradient_checkpointing = True, + gradient_checkpointing_kwargs = { "use_reentrant": False }, # use reentrant checkpointing - max_grad_norm=0.3, # max gradient norm based on QLoRA paper - warmup_ratio=0.03, + max_grad_norm = 0.3, # max gradient norm based on QLoRA paper + warmup_ratio = 0.03, # num_train_epochs = 2, # Set this instead of max_steps for full training runs - max_steps=10, - learning_rate=2e-4, - fp16=not is_bf16_supported(), - bf16=is_bf16_supported(), - logging_steps=5, - save_strategy="epoch", - optim="adamw_torch_fused", - weight_decay=0.01, - lr_scheduler_type="linear", - seed=3407, - output_dir="checkpoints", - report_to="none", # For Weights and Biases + max_steps = 10, + learning_rate = 2e-4, + fp16 = not is_bf16_supported(), + bf16 = is_bf16_supported(), + logging_steps = 5, + save_strategy = "epoch", + optim = "adamw_torch_fused", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "checkpoints", + report_to = "none", # For Weights and Biases # You MUST put the below items for vision finetuning: - remove_unused_columns=False, - dataset_text_field="", - dataset_kwargs={"skip_prepare_dataset": True}, - dataset_num_proc=4, - max_seq_length=2048, + remove_unused_columns = False, + dataset_text_field = "", + dataset_kwargs = {"skip_prepare_dataset": True}, + dataset_num_proc = 4, + max_seq_length = 2048, ), ) print("✅ Trainer setup completed!") @@ -221,7 +221,7 @@ try: print("=== UPLOADING MODEL TO HUB ===".center(80)) print("=" * 80 + "\n") print(f"🚀 Uploading to repository: {repo_name}") - model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token) + model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token) success["upload"] = True print("✅ Model uploaded successfully!") except Exception as e: @@ -233,8 +233,8 @@ try: print("\n" + "=" * 80) print("=== VERIFYING REPO CONTENTS ===".center(80)) print("=" * 80 + "\n") - fs = HfFileSystem(token=hf_token) - file_list = fs.ls(repo_name, detail=True) + fs = HfFileSystem(token = hf_token) + file_list = fs.ls(repo_name, detail = True) safetensors_found = any( file["name"].endswith("model.safetensors.index.json") for file in file_list ) diff --git a/tests/saving/vision_models/test_push_to_hub_merged.py b/tests/saving/vision_models/test_push_to_hub_merged.py index 8cea6ad26..74fa05898 100644 --- a/tests/saving/vision_models/test_push_to_hub_merged.py +++ b/tests/saving/vision_models/test_push_to_hub_merged.py @@ -22,7 +22,7 @@ from tests.utils.cleanup_utils import safe_remove_directory ## Dataset Preparation""" print("\n📊 Loading and preparing dataset...") -dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train") +dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train") # To select the first 2000 examples train_dataset = dataset.select(range(2000)) @@ -82,11 +82,11 @@ print("🤖 Loading base vision model...") try: model, tokenizer = FastVisionModel.from_pretrained( # model_name = "unsloth/Qwen2-VL-7B-Instruct", - model_name="unsloth/Qwen2-VL-2B-Instruct", - max_seq_length=2048, # Choose any for long context! - load_in_4bit=True, # 4 bit quantization to reduce memory - load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory - full_finetuning=False, # [NEW!] We have full finetuning now! + model_name = "unsloth/Qwen2-VL-2B-Instruct", + max_seq_length = 2048, # Choose any for long context! + load_in_4bit = True, # 4 bit quantization to reduce memory + load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory + full_finetuning = False, # [NEW!] We have full finetuning now! ) except Exception as e: print(f"❌ Failed to load base model: {e}") @@ -97,18 +97,18 @@ print("\n🔧 Setting up LoRA configuration...") try: model = FastVisionModel.get_peft_model( model, - finetune_vision_layers=True, # Turn off for just text! - finetune_language_layers=True, # Should leave on! - finetune_attention_modules=True, # Attention good for GRPO - finetune_mlp_modules=True, # SHould leave on always! - r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 - lora_alpha=32, - lora_dropout=0, # Supports any, but = 0 is optimized - bias="none", # Supports any, but = "none" is optimized - use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context - random_state=3407, - use_rslora=False, # We support rank stabilized LoRA - loftq_config=None, # And LoftQ + finetune_vision_layers = True, # Turn off for just text! + finetune_language_layers = True, # Should leave on! + finetune_attention_modules = True, # Attention good for GRPO + finetune_mlp_modules = True, # SHould leave on always! + r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + lora_alpha = 32, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ ) print("✅ LoRA configuration applied successfully!") print(f" 🎯 LoRA rank (r): 16") @@ -129,40 +129,40 @@ FastVisionModel.for_training(model) # Enable for training! try: trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - data_collator=UnslothVisionDataCollator(model, tokenizer), - train_dataset=train_dataset, - args=SFTConfig( + model = model, + tokenizer = tokenizer, + data_collator = UnslothVisionDataCollator(model, tokenizer), + train_dataset = train_dataset, + args = SFTConfig( # per_device_train_batch_size = 4, # gradient_accumulation_steps = 8, - per_device_train_batch_size=2, - gradient_accumulation_steps=4, - gradient_checkpointing=True, - gradient_checkpointing_kwargs={ + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + gradient_checkpointing = True, + gradient_checkpointing_kwargs = { "use_reentrant": False }, # use reentrant checkpointing - max_grad_norm=0.3, # max gradient norm based on QLoRA paper - warmup_ratio=0.03, + max_grad_norm = 0.3, # max gradient norm based on QLoRA paper + warmup_ratio = 0.03, # num_train_epochs = 2, # Set this instead of max_steps for full training runs - max_steps=10, - learning_rate=2e-4, - fp16=not is_bf16_supported(), - bf16=is_bf16_supported(), - logging_steps=5, - save_strategy="epoch", - optim="adamw_torch_fused", - weight_decay=0.01, - lr_scheduler_type="linear", - seed=3407, - output_dir="checkpoints", - report_to="none", # For Weights and Biases + max_steps = 10, + learning_rate = 2e-4, + fp16 = not is_bf16_supported(), + bf16 = is_bf16_supported(), + logging_steps = 5, + save_strategy = "epoch", + optim = "adamw_torch_fused", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "checkpoints", + report_to = "none", # For Weights and Biases # You MUST put the below items for vision finetuning: - remove_unused_columns=False, - dataset_text_field="", - dataset_kwargs={"skip_prepare_dataset": True}, - dataset_num_proc=4, - max_seq_length=2048, + remove_unused_columns = False, + dataset_text_field = "", + dataset_kwargs = {"skip_prepare_dataset": True}, + dataset_num_proc = 4, + max_seq_length = 2048, ), ) print("✅ Trainer setup completed!") @@ -221,7 +221,7 @@ try: print("=== UPLOADING MODEL TO HUB ===".center(80)) print("=" * 80 + "\n") print(f"🚀 Uploading to repository: {repo_name}") - model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token) + model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token) success["upload"] = True print("✅ Model uploaded successfully!") except Exception as e: diff --git a/tests/utils/aime_eval.py b/tests/utils/aime_eval.py index 281950155..131da3e50 100644 --- a/tests/utils/aime_eval.py +++ b/tests/utils/aime_eval.py @@ -24,7 +24,7 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str: "test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl", } - os.makedirs(data_dir, exist_ok=True) + os.makedirs(data_dir, exist_ok = True) combined_filepath = os.path.join(data_dir, "aime.jsonl") # Check if combined file already exists @@ -67,9 +67,9 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str: # Write combined dataset if all_problems: - with open(combined_filepath, "w", encoding="utf-8") as f: + with open(combined_filepath, "w", encoding = "utf-8") as f: for problem in all_problems: - f.write(json.dumps(problem, ensure_ascii=False) + "\n") + f.write(json.dumps(problem, ensure_ascii = False) + "\n") print(f"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets") print(f" Saved to: {combined_filepath}") @@ -92,7 +92,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]: filepath = download_and_combine_aime_datasets(data_dir) examples = [] - with open(filepath, "r", encoding="utf-8") as f: + with open(filepath, "r", encoding = "utf-8") as f: for line_num, line in enumerate(f): line = line.strip() if line: @@ -188,20 +188,20 @@ def get_num_tokens(text, tokenizer_instance): """Count tokens in text""" if not text: return 0 - encoding = tokenizer_instance(text, return_tensors="pt") + encoding = tokenizer_instance(text, return_tensors = "pt") return len(encoding["input_ids"][0]) def evaluate_model_aime( model, tokenizer, - model_type="base", - lora_request=None, - temperature=0.3, - n_sampling=8, - max_tokens=32768, - top_p=0.95, - seed=0, + model_type = "base", + lora_request = None, + temperature = 0.3, + n_sampling = 8, + max_tokens = 32768, + top_p = 0.95, + seed = 0, ): """Evaluate model on combined AIME dataset with official configuration""" @@ -237,11 +237,11 @@ def evaluate_model_aime( # Setup sampling parameters (AIME configuration) sampling_params = SamplingParams( - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - n=n_sampling, # Multiple samples per question - seed=seed, + temperature = temperature, + top_p = top_p, + max_tokens = max_tokens, + n = n_sampling, # Multiple samples per question + seed = seed, ) print(f"\n🔧 Configuration:") @@ -272,13 +272,13 @@ def evaluate_model_aime( # Main evaluation loop with tqdm( - total=len(eval_dataset), desc="Processing AIME problems", unit="problem" + total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem" ) as pbar: for task_id, item in enumerate(eval_dataset): try: # Prepare prompt prompt_text = tokenizer.apply_chat_template( - item["prompt"], add_generation_prompt=True, tokenize=False + item["prompt"], add_generation_prompt = True, tokenize = False ) input_tokens.append(get_num_tokens(prompt_text, tokenizer)) @@ -286,9 +286,9 @@ def evaluate_model_aime( # Generate multiple responses outputs = model.fast_generate( [prompt_text], - sampling_params=sampling_params, - lora_request=lora_request, - use_tqdm=False, + sampling_params = sampling_params, + lora_request = lora_request, + use_tqdm = False, )[0].outputs # Process all generated responses @@ -413,8 +413,8 @@ def evaluate_model_aime( # Save results filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json" - with open(filename, "w", encoding="utf-8") as f: - json.dump({"results": results, "records": records}, f, indent=4) + with open(filename, "w", encoding = "utf-8") as f: + json.dump({"results": results, "records": records}, f, indent = 4) # Print comprehensive summary print(f"\n{'='*70}") @@ -517,27 +517,27 @@ def compare_aime_results(all_results): if all_results and "source_accuracies" in all_results[0]: datasets = list(all_results[0]["source_accuracies"].keys()) - print(f"{'Model':<15}", end="") + print(f"{'Model':<15}", end = "") for dataset in datasets: - print(f"{dataset:<15}", end="") + print(f"{dataset:<15}", end = "") print() print("-" * (15 + 15 * len(datasets))) for result in all_results: - print(f"{result['model_type']:<15}", end="") + print(f"{result['model_type']:<15}", end = "") for dataset in datasets: accuracy = result["source_accuracies"].get(dataset, 0) - print(f"{accuracy:<15.1f}", end="") + print(f"{accuracy:<15.1f}", end = "") print() # Save comparison comparison_data = { "summary": all_results, - "best_model": max(all_results, key=lambda x: x["accuracy"]), + "best_model": max(all_results, key = lambda x: x["accuracy"]), } with open("aime_model_comparison.json", "w") as f: - json.dump(comparison_data, f, indent=4) + json.dump(comparison_data, f, indent = 4) print( f"\nBest performing model: {comparison_data['best_model']['model_type']} " diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py index cc5edce02..8ad6d5ad0 100644 --- a/tests/utils/hf_utils.py +++ b/tests/utils/hf_utils.py @@ -79,27 +79,27 @@ def generate_responses( skip_special_tokens: bool = True, dtype: torch.dtype = None, ): - inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)] + inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)] keys = inputs[0].keys() batched_inputs = { - key: torch.cat([input[key] for input in inputs], dim=0).to(model.device) + key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device) for key in keys } if dtype is not None: - inference_context = torch.autocast(device_type="cuda", dtype=dtype) + inference_context = torch.autocast(device_type = "cuda", dtype = dtype) else: inference_context = nullcontext() with inference_context: outputs = model.generate( **batched_inputs, - max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temperature, + max_new_tokens = max_new_tokens, + do_sample = do_sample, + temperature = temperature, ) - responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens) + responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens) return responses @@ -117,11 +117,11 @@ def sample_responses( model, tokenizer, prompt, - temperature=temperature, - num_generations=num_generations, - max_new_tokens=max_new_tokens, - skip_special_tokens=skip_special_tokens, - dtype=dtype, + temperature = temperature, + num_generations = num_generations, + max_new_tokens = max_new_tokens, + skip_special_tokens = skip_special_tokens, + dtype = dtype, ) return responses @@ -136,32 +136,32 @@ def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []): def setup_model( model_name, quantize: bool = True, - dtype=torch.bfloat16, - peft_config=None, + dtype = torch.bfloat16, + peft_config = None, autocast_adapter: bool = True, ): if quantize: bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=dtype, + load_in_4bit = True, + bnb_4bit_use_double_quant = True, + bnb_4bit_quant_type = "nf4", + bnb_4bit_compute_dtype = dtype, ) else: bnb_config = None model = AutoModelForCausalLM.from_pretrained( model_name, - device_map="cuda:0", - attn_implementation="sdpa", - quantization_config=bnb_config, - torch_dtype=dtype, + device_map = "cuda:0", + attn_implementation = "sdpa", + quantization_config = bnb_config, + torch_dtype = dtype, ) model = prepare_model_for_kbit_training(model) if quantize else model if peft_config is not None: model = get_peft_model( - model, peft_config, autocast_adapter_dtype=autocast_adapter + model, peft_config, autocast_adapter_dtype = autocast_adapter ) return model @@ -169,19 +169,19 @@ def setup_model( def get_peft_config( lora_rank, - lora_alpha=None, - lora_dropout=0.0, - bias="none", - target_modules="all-linear", + lora_alpha = None, + lora_dropout = 0.0, + bias = "none", + target_modules = "all-linear", ): lora_alpha = lora_alpha or 2 * lora_rank peft_config = LoraConfig( - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - r=lora_rank, - bias=bias, - target_modules=target_modules, - task_type="CAUSAL_LM", + lora_alpha = lora_alpha, + lora_dropout = lora_dropout, + r = lora_rank, + bias = bias, + target_modules = target_modules, + task_type = "CAUSAL_LM", ) return peft_config @@ -191,18 +191,18 @@ def setup_trainer( tokenizer, dataset, train_args, - peft_config=None, - formatting_func=None, - collator=None, + peft_config = None, + formatting_func = None, + collator = None, ): return SFTTrainer( - model=model, - peft_config=peft_config, - train_dataset=dataset, - processing_class=tokenizer, - formatting_func=formatting_func, - data_collator=collator, - args=train_args, + model = model, + peft_config = peft_config, + train_dataset = dataset, + processing_class = tokenizer, + formatting_func = formatting_func, + data_collator = collator, + args = train_args, ) @@ -212,17 +212,17 @@ def setup_lora( dataset, peft_config, train_args, - formatting_func=None, - collator=None, + formatting_func = None, + collator = None, ): return LoraConfig( - model=model, - peft_config=peft_config, - train_dataset=dataset, - processing_class=tokenizer, - formatting_func=formatting_func, - data_collator=collator, - args=train_args, + model = model, + peft_config = peft_config, + train_dataset = dataset, + processing_class = tokenizer, + formatting_func = formatting_func, + data_collator = collator, + args = train_args, ) @@ -236,7 +236,7 @@ def convert_weights_back_to_dtype(model, dtype): param.data = param.data.to(dtype) -def fix_llama3_tokenizer(tokenizer, padding_side="right"): +def fix_llama3_tokenizer(tokenizer, padding_side = "right"): tokenizer.padding_side = padding_side added_vocab = tokenizer.get_added_vocab() pad_token = [w for w in added_vocab if "pad" in w] @@ -276,12 +276,12 @@ def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"): w_dq = w_dq.to(original_dtype) new_module = torch.nn.Linear( - w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None + w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None ) - new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False) + new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False) if module.lora_bias[adapter_name]: bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias - new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False) + new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False) return new_module diff --git a/unsloth/__init__.py b/unsloth/__init__.py index c9433faaa..340bcee5e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -39,7 +39,7 @@ if already_imported: f"to ensure all optimizations are applied. Your code may run slower or encounter " f"memory issues without these optimizations.\n\n" f"Please restructure your imports with 'import unsloth' at the top of your file.", - stacklevel=2, + stacklevel = 2, ) del already_imported, critical_modules @@ -141,7 +141,7 @@ if DEVICE_TYPE == "cuda": old_is_bf16_supported = torch.cuda.is_bf16_supported if "including_emulation" in str(inspect.signature(old_is_bf16_supported)): - def is_bf16_supported(including_emulation=False): + def is_bf16_supported(including_emulation = False): return old_is_bf16_supported(including_emulation) torch.cuda.is_bf16_supported = is_bf16_supported diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index a8384ca3b..60d0c318c 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -86,7 +86,7 @@ class LoRA_MLP(torch.autograd.Function): downS, _forward_function, _backward_function, - inplace=True, + inplace = True, ): dtype = X.dtype @@ -169,39 +169,39 @@ class LoRA_MLP(torch.autograd.Function): # d_downB = (downA.t() @ h.t()) @ dY # d_downA *= downS # d_downB *= downS - d_downA.addmm_(h.t(), dY @ downB.t(), alpha=downS, beta=0) - d_downB.addmm_(downA.t() @ h.t(), dY, alpha=downS, beta=0) + d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0) + d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0) # Up projection LoRA weights # d_upA = X.t() @ (df @ upB.t()) # d_upB = (upA.t() @ X.t()) @ df # d_upA *= upS # d_upB *= upS - d_upA.addmm_(X.t(), df @ upB.t(), alpha=upS, beta=0) - d_upB.addmm_(upA.t() @ X.t(), df, alpha=upS, beta=0) + d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0) + d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0) # Gate projection LoRA weights # d_gateA = X.t() @ (de @ gateB.t()) # d_gateB = (gateA.t() @ X.t()) @ de # d_gateA *= gateS # d_gateB *= gateS - d_gateA.addmm_(X.t(), de @ gateB.t(), alpha=gateS, beta=0) - d_gateB.addmm_(gateA.t() @ X.t(), de, alpha=gateS, beta=0) + d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0) + d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0) # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS) # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS) upW = fast_dequantize(upW.t(), upW_quant) - dX = torch.matmul(df, upW.t(), out=X if ctx.inplace else None) + dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None) del upW # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()) - dX.addmm_(df @ upB.t(), upA.t(), alpha=upS) + dX.addmm_(df @ upB.t(), upA.t(), alpha = upS) gateW = fast_dequantize(gateW.t(), gateW_quant) # dX += de @ gateW.t() dX.addmm_(de, gateW.t()) del gateW # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()) - dX.addmm_(de @ gateB.t(), gateA.t(), alpha=gateS) + dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS) # gateW, gateW_quant, gateA, gateB, gateS, # upW, upW_quant, upA, upB, upS, @@ -232,7 +232,7 @@ class LoRA_MLP(torch.autograd.Function): from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel -def apply_lora_mlp_swiglu(self, X, inplace=True): +def apply_lora_mlp_swiglu(self, X, inplace = True): X = _maybe_fake_quantize_activations(X, self.gate_proj) gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) @@ -264,7 +264,7 @@ def apply_lora_mlp_swiglu(self, X, inplace=True): from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel -def apply_lora_mlp_geglu_exact(self, X, inplace=True): +def apply_lora_mlp_geglu_exact(self, X, inplace = True): X = _maybe_fake_quantize_activations(X, self.gate_proj) gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) @@ -375,7 +375,7 @@ class LoRA_QKV(torch.autograd.Function): VA, VB, VS, - inplace=True, + inplace = True, ): dtype = X.dtype @@ -452,32 +452,32 @@ class LoRA_QKV(torch.autograd.Function): # d_QB = (QA.t() @ X.t()) @ dQ # d_QA *= QS # d_QB *= QS - d_QA.addmm_(X.t(), dQ @ QB.t(), alpha=QS, beta=0) - d_QB.addmm_(QA.t() @ X.t(), dQ, alpha=QS, beta=0) + d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0) + d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0) # K Projection # d_KA = X.t() @ (dK @ KB.t()) # d_KB = (KA.t() @ X.t()) @ dK # d_KA *= KS # d_KB *= KS - d_KA.addmm_(X.t(), dK @ KB.t(), alpha=KS, beta=0) - d_KB.addmm_(KA.t() @ X.t(), dK, alpha=KS, beta=0) + d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0) + d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0) # V Projection # d_VA = X.t() @ (dV @ VB.t()) # d_VB = (VA.t() @ X.t()) @ dV # d_VA *= VS # d_VB *= VS - d_VA.addmm_(X.t(), dV @ VB.t(), alpha=VS, beta=0) - d_VB.addmm_(VA.t() @ X.t(), dV, alpha=VS, beta=0) + d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0) + d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0) # Combine derivatives to find dX # dQ QW = fast_dequantize(QW.t(), QW_quant) - dX = torch.matmul(dQ, QW.t(), out=X if ctx.inplace else None) + dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None) del QW # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())) - dX.addmm_(dQ @ QB.t(), QA.t(), alpha=QS) + dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS) # dK KW = fast_dequantize(KW.t(), KW_quant) @@ -485,7 +485,7 @@ class LoRA_QKV(torch.autograd.Function): dX.addmm_(dK, KW.t()) del KW # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()) - dX.addmm_(dK @ KB.t(), KA.t(), alpha=KS) + dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS) # dV VW = fast_dequantize(VW.t(), VW_quant) @@ -493,7 +493,7 @@ class LoRA_QKV(torch.autograd.Function): dX.addmm_(dV, VW.t()) del VW # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()) - dX.addmm_(dV @ VB.t(), VA.t(), alpha=VS) + dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS) # QW, QW_quant, QA, QB, QS, # KW, KW_quant, KA, KB, KS, @@ -519,7 +519,7 @@ class LoRA_QKV(torch.autograd.Function): ) -def apply_lora_qkv(self, X, inplace=True): +def apply_lora_qkv(self, X, inplace = True): X = _maybe_fake_quantize_activations(X, self.q_proj) QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) @@ -611,15 +611,15 @@ class LoRA_W(torch.autograd.Function): # d_B = (A.t() @ X.t()) @ dY # d_A *= S # d_B *= S - d_A.addmm_(X.t(), dY @ B.t(), alpha=S, beta=0) - d_B.addmm_(A.t() @ X.t(), dY, alpha=S, beta=0) + d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0) + d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0) # Get derivative for dX W = fast_dequantize(W.t(), W_quant) dX = dY @ W.t() del W # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t()) - dX.addmm_(dY @ B.t(), A.t(), alpha=S) + dX.addmm_(dY @ B.t(), A.t(), alpha = S) # W, W_quant, A, B, S return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None @@ -649,7 +649,7 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: result = self._mixed_batch_forward( - x, *args, adapter_names=adapter_names, **kwargs + x, *args, adapter_names = adapter_names, **kwargs ) elif self.merged: result = self.base_layer(x, *args, **kwargs) @@ -705,11 +705,11 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: result = result + self.lora_magnitude_vector[active_adapter]( x, - lora_A=lora_A, - lora_B=lora_B, - scaling=scaling, - base_layer=self.get_base_layer(), - base_result=base_result, + lora_A = lora_A, + lora_B = lora_B, + scaling = scaling, + base_layer = self.get_base_layer(), + base_result = base_result, ) if requires_conversion: result = result.to(expected_dtype) diff --git a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py index 2fe2afa1e..074cc5a56 100644 --- a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py +++ b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py @@ -56,7 +56,7 @@ def run_benchmark_forward( hidden_size = config.hidden_size X = torch.randn( - bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True + bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True ) # Forward @@ -89,7 +89,7 @@ def run_benchmark_backward( config: AutoConfig, seqlen: int, dtype: torch.dtype, - bs=1, + bs = 1, ): torch.manual_seed( SEED @@ -98,7 +98,7 @@ def run_benchmark_backward( hidden_size = config.hidden_size X = torch.randn( - bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True + bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True ) X_test = X.detach().clone().requires_grad_(True) @@ -114,14 +114,14 @@ def run_benchmark_backward( # Bench grad_output = torch.randn_like(output) - bench_backward_ref = lambda: output.backward(grad_output, retain_graph=True) # noqa: E731 - bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph=True) # noqa: E731 + bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True) # noqa: E731 + bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True) # noqa: E731 ref_backward_time = do_bench( - bench_backward_ref, grad_to_none=[X, *ref_model.parameters()] + bench_backward_ref, grad_to_none = [X, *ref_model.parameters()] ) fused_backward_time = do_bench( - bench_backward_fused, grad_to_none=[X_test, *tt_model.parameters()] + bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()] ) print( f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x" @@ -138,10 +138,10 @@ def setup_model( kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX, - dX_only=False, - dW_only=False, - overlap_router_shared=False, - device="cuda", + dX_only = False, + dW_only = False, + overlap_router_shared = False, + device = "cuda", ): if isinstance(config, Qwen3MoeConfig): ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype) @@ -149,29 +149,29 @@ def setup_model( # Triton kernel grouped gemm version of MoE Block -- this is what we're testing tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf( ref_model, - permute_x=permute_x, - permute_y=permute_y, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, - dX_only=dX_only, - dW_only=dW_only, + permute_x = permute_x, + permute_y = permute_y, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, + dX_only = dX_only, + dW_only = dW_only, ).to(device, dtype) elif isinstance(config, Llama4TextConfig): ref_model = Llama4TextMoe(config).to(device, dtype) tt_model = Llama4TritonTextMoe( config, - overlap_router_shared=overlap_router_shared, - permute_x=permute_x, - permute_y=permute_y, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, - dX_only=dX_only, - dW_only=dW_only, + overlap_router_shared = overlap_router_shared, + permute_x = permute_x, + permute_y = permute_y, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, + dX_only = dX_only, + dW_only = dW_only, ).to(device, dtype) else: @@ -205,31 +205,31 @@ def run_benchmark( ref_model, tt_model = setup_model( model_config, - dtype=dtype, - permute_x=permute_x, - permute_y=permute_y, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, - dX_only=dX_only, - dW_only=dW_only, - overlap_router_shared=overlap_router_shared, + dtype = dtype, + permute_x = permute_x, + permute_y = permute_y, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, + dX_only = dX_only, + dW_only = dW_only, + overlap_router_shared = overlap_router_shared, ) if mode == "forward": ref_time, fused_time = run_benchmark_forward( ref_model, tt_model, - config=model_config, - seqlen=seqlen, - dtype=dtype, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, + config = model_config, + seqlen = seqlen, + dtype = dtype, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, ) else: ref_time, fused_time = run_benchmark_backward( - ref_model, tt_model, config=model_config, seqlen=seqlen, dtype=dtype + ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype ) if autotune: @@ -251,60 +251,60 @@ def run_benchmark( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--results_dir", type=str, default="benchmark_results") - parser.add_argument("--model", type=str, choices=["llama4", "qwen3"], required=True) - parser.add_argument("--seqlen", type=int, default=1024) + parser.add_argument("--results_dir", type = str, default = "benchmark_results") + parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True) + parser.add_argument("--seqlen", type = int, default = 1024) parser.add_argument( - "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16" + "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16" ) - parser.add_argument("--permute_x", action="store_true") - parser.add_argument("--permute_y", action="store_true") - parser.add_argument("--autotune", action="store_true") - parser.add_argument("--overlap_router_shared", action="store_true") + parser.add_argument("--permute_x", action = "store_true") + parser.add_argument("--permute_y", action = "store_true") + parser.add_argument("--autotune", action = "store_true") + parser.add_argument("--overlap_router_shared", action = "store_true") parser.add_argument( "--BLOCK_SIZE_M", - nargs=2, - type=int, - default=[DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]], + nargs = 2, + type = int, + default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]], ) parser.add_argument( "--BLOCK_SIZE_N", - nargs=2, - type=int, - default=[DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]], + nargs = 2, + type = int, + default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]], ) parser.add_argument( "--BLOCK_SIZE_K", - nargs=2, - type=int, - default=[DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]], + nargs = 2, + type = int, + default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]], ) parser.add_argument( "--num_warps", - nargs=2, - type=int, - default=[DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]], + nargs = 2, + type = int, + default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]], ) parser.add_argument( "--num_stages", - nargs=2, - type=int, - default=[DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]], + nargs = 2, + type = int, + default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]], ) parser.add_argument( - "--use_tma_load_w", action="store_true" + "--use_tma_load_w", action = "store_true" ) # No need to specify, will automatically parametrize these for each kernel config parser.add_argument( - "--use_tma_load_x", action="store_true" + "--use_tma_load_x", action = "store_true" ) # No need to specify, will automatically parametrize these for each kernel config parser.add_argument( - "--use_tma_load_dy", action="store_true" + "--use_tma_load_dy", action = "store_true" ) # No need to specify, will automatically parametrize these for each kernel config parser.add_argument( "--mode", - type=str, - choices=["forward", "backward", "dW", "dX"], - default="forward", + type = str, + choices = ["forward", "backward", "dW", "dX"], + default = "forward", ) args = parser.parse_args() args.dtype = getattr(torch, args.dtype) @@ -324,13 +324,13 @@ if __name__ == "__main__": ref_time, fused_time = run_benchmark( args.mode, model_config, - seqlen=args.seqlen, - dtype=args.dtype, - permute_x=args.permute_x, - permute_y=args.permute_y, - autotune=args.autotune, - overlap_router_shared=args.overlap_router_shared, - results_dir=args.results_dir, + seqlen = args.seqlen, + dtype = args.dtype, + permute_x = args.permute_x, + permute_y = args.permute_y, + autotune = args.autotune, + overlap_router_shared = args.overlap_router_shared, + results_dir = args.results_dir, ) end_time = time.time() print(f"Total time: {end_time - start_time:.4f} seconds") @@ -343,13 +343,13 @@ if __name__ == "__main__": kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y) print(f"Running {len(kernel_configs)} kernel configs") default_kernel_config_fwd = KernelConfigForward( - permute_x=args.permute_x, permute_y=args.permute_y + permute_x = args.permute_x, permute_y = args.permute_y ) default_kernel_config_bwd_dW = KernelConfigBackward_dW( - permute_x=args.permute_x, permute_y=args.permute_y + permute_x = args.permute_x, permute_y = args.permute_y ) default_kernel_config_bwd_dX = KernelConfigBackward_dX( - permute_x=args.permute_x, permute_y=args.permute_y + permute_x = args.permute_x, permute_y = args.permute_y ) results = [] for kernel_config in kernel_configs: @@ -374,21 +374,21 @@ if __name__ == "__main__": ref_time, fused_time = run_benchmark( args.mode, model_config, - seqlen=args.seqlen, - dtype=args.dtype, - permute_x=kernel_config.permute_x, - permute_y=kernel_config.permute_y, - autotune=False, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, + seqlen = args.seqlen, + dtype = args.dtype, + permute_x = kernel_config.permute_x, + permute_y = kernel_config.permute_y, + autotune = False, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, ) results.append( KernelResult( - torch_time=ref_time, - triton_time=fused_time, - speedup=ref_time / fused_time, - kernel_config=kernel_config, + torch_time = ref_time, + triton_time = fused_time, + speedup = ref_time / fused_time, + kernel_config = kernel_config, ) ) df = post_process_results( diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py index d57105cee..a185b5fd3 100644 --- a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py +++ b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py @@ -37,15 +37,15 @@ def convert_args_to_list(args): def get_forward_configs( - BLOCK_M=DEFAULT_M_BLOCK_SIZES, - BLOCK_N=DEFAULT_N_BLOCK_SIZES, - BLOCK_K=DEFAULT_K_BLOCK_SIZES, - TMA_LOAD_X=True, - TMA_LOAD_W=True, - TMA_STORE=False, # NOTE: TMA_STORE is disabled for now - num_warps=DEFAULT_NUM_WARPS, - num_stages=DEFAULT_NUM_STAGES, - num_ctas=DEFAULT_NUM_CTAS, + BLOCK_M = DEFAULT_M_BLOCK_SIZES, + BLOCK_N = DEFAULT_N_BLOCK_SIZES, + BLOCK_K = DEFAULT_K_BLOCK_SIZES, + TMA_LOAD_X = True, + TMA_LOAD_W = True, + TMA_STORE = False, # NOTE: TMA_STORE is disabled for now + num_warps = DEFAULT_NUM_WARPS, + num_stages = DEFAULT_NUM_STAGES, + num_ctas = DEFAULT_NUM_CTAS, ): ( BLOCK_M, @@ -95,16 +95,16 @@ def get_forward_configs( kernel_configs.append( triton.Config( dict( - BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - USE_TMA_LOAD_X=tma_load_x, - USE_TMA_LOAD_W=tma_load_w, - USE_TMA_STORE=tma_store, + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + USE_TMA_LOAD_X = tma_load_x, + USE_TMA_LOAD_W = tma_load_w, + USE_TMA_STORE = tma_store, ), - num_warps=w, - num_stages=s, - num_ctas=num_ctas, + num_warps = w, + num_stages = s, + num_ctas = num_ctas, ) ) @@ -112,15 +112,15 @@ def get_forward_configs( def get_dX_kernel_configs( - BLOCK_M=DEFAULT_M_BLOCK_SIZES, - BLOCK_N=DEFAULT_N_BLOCK_SIZES, - BLOCK_K=DEFAULT_K_BLOCK_SIZES, - TMA_LOAD_dY=True, - TMA_LOAD_W=True, - TMA_STORE=False, # NOTE: TMA_STORE is disabled for now - num_warps=DEFAULT_NUM_WARPS, - num_stages=DEFAULT_NUM_STAGES, - num_ctas=DEFAULT_NUM_CTAS, + BLOCK_M = DEFAULT_M_BLOCK_SIZES, + BLOCK_N = DEFAULT_N_BLOCK_SIZES, + BLOCK_K = DEFAULT_K_BLOCK_SIZES, + TMA_LOAD_dY = True, + TMA_LOAD_W = True, + TMA_STORE = False, # NOTE: TMA_STORE is disabled for now + num_warps = DEFAULT_NUM_WARPS, + num_stages = DEFAULT_NUM_STAGES, + num_ctas = DEFAULT_NUM_CTAS, ): ( BLOCK_M, @@ -170,16 +170,16 @@ def get_dX_kernel_configs( kernel_configs.append( triton.Config( dict( - BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - USE_TMA_LOAD_dY=tma_load_dy, - USE_TMA_LOAD_W=tma_load_w, - USE_TMA_STORE=tma_store, + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + USE_TMA_LOAD_dY = tma_load_dy, + USE_TMA_LOAD_W = tma_load_w, + USE_TMA_STORE = tma_store, ), - num_warps=w, - num_stages=s, - num_ctas=num_ctas, + num_warps = w, + num_stages = s, + num_ctas = num_ctas, ) ) @@ -187,15 +187,15 @@ def get_dX_kernel_configs( def get_dW_kernel_configs( - BLOCK_M=DEFAULT_M_BLOCK_SIZES, - BLOCK_N=DEFAULT_N_BLOCK_SIZES, - BLOCK_K=DEFAULT_K_BLOCK_SIZES, - num_warps=DEFAULT_NUM_WARPS, - num_stages=DEFAULT_NUM_STAGES, - num_ctas=DEFAULT_NUM_CTAS, - TMA_LOAD_dY=True, - TMA_LOAD_X=True, - TMA_STORE=False, + BLOCK_M = DEFAULT_M_BLOCK_SIZES, + BLOCK_N = DEFAULT_N_BLOCK_SIZES, + BLOCK_K = DEFAULT_K_BLOCK_SIZES, + num_warps = DEFAULT_NUM_WARPS, + num_stages = DEFAULT_NUM_STAGES, + num_ctas = DEFAULT_NUM_CTAS, + TMA_LOAD_dY = True, + TMA_LOAD_X = True, + TMA_STORE = False, ): ( BLOCK_M, @@ -245,16 +245,16 @@ def get_dW_kernel_configs( kernel_configs.append( triton.Config( dict( - BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - USE_TMA_LOAD_dY=tma_load_dy, - USE_TMA_LOAD_X=tma_load_x, - USE_TMA_STORE=tma_store, + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + USE_TMA_LOAD_dY = tma_load_dy, + USE_TMA_LOAD_X = tma_load_x, + USE_TMA_STORE = tma_store, ), - num_warps=w, - num_stages=s, - num_ctas=num_ctas, + num_warps = w, + num_stages = s, + num_ctas = num_ctas, ) ) diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py index a05fb4d5d..d8bdcb57e 100644 --- a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py +++ b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py @@ -84,18 +84,18 @@ def _grouped_gemm_dX_kernel( if USE_TMA_LOAD_dY: dY_desc = tl._experimental_make_tensor_descriptor( dY_ptr, - shape=[TOTAL_TOKENS, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + shape = [TOTAL_TOKENS, N], + strides = [N, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N], ) if USE_TMA_LOAD_W: expert_stride = N * K w_desc = tl._experimental_make_tensor_descriptor( w_ptr, - shape=[NUM_EXPERTS, N, K], - strides=[expert_stride, K, 1], - block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K], + shape = [NUM_EXPERTS, N, K], + strides = [expert_stride, K, 1], + block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K], ) m_end = 0 @@ -104,7 +104,7 @@ def _grouped_gemm_dX_kernel( n_block_range = tl.arange(0, BLOCK_SIZE_N) k_block_range = tl.arange(0, BLOCK_SIZE_K) - for expert_idx in range(NUM_EXPERTS, flatten=FLATTEN): + for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN): m_start = m_end m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32) m_end = m_start + m_size @@ -125,9 +125,9 @@ def _grouped_gemm_dX_kernel( ) dX_desc = tl._experimental_make_tensor_descriptor( dX_ptr, - shape=[m_end, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + shape = [m_end, K], + strides = [K, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K], ) # Lower bound and upper bound are defined relative to the total tiles processed so far @@ -152,7 +152,7 @@ def _grouped_gemm_dX_kernel( ) expert_token_idx = tl.load( gather_indices_ptr + indices_to_gather, - mask=indices_to_gather < TOTAL_TOKENS, + mask = indices_to_gather < TOTAL_TOKENS, ) expert_token_offsets = expert_token_idx[:, None] @@ -210,13 +210,13 @@ def _grouped_gemm_dX_kernel( # col_mask = offs_bk[None, :] < K store_mask = row_mask # & col_mask - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32) # GEMM main loop for n_offset in range(0, N, BLOCK_SIZE_N): # dY block [M, N] if not USE_TMA_LOAD_dY: - dY = tl.load(dY_ptrs, mask=row_mask) + dY = tl.load(dY_ptrs, mask = row_mask) else: dY = dY_desc.load( [m_start + tile_m_idx * BLOCK_SIZE_M, n_offset] @@ -253,7 +253,7 @@ def _grouped_gemm_dX_kernel( tl.store( dX_ptr + store_idx + offs_bk[None, :], dX, - mask=store_mask, + mask = store_mask, ) # Move to the next tile within this expert group @@ -264,9 +264,9 @@ def _grouped_gemm_dX_kernel( _autotuned_grouped_gemm_dX_kernel = triton.autotune( - configs=get_dX_kernel_configs(), - prune_configs_by={"early_config_prune": prune_dX_configs}, - key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"], + configs = get_dX_kernel_configs(), + prune_configs_by = {"early_config_prune": prune_dX_configs}, + key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"], )(_grouped_gemm_dX_kernel) """ @@ -324,17 +324,17 @@ def _grouped_gemm_dW_kernel( if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH: dY_desc = tl._experimental_make_tensor_descriptor( dY_ptr, - shape=[TOTAL_TOKENS, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + shape = [TOTAL_TOKENS, N], + strides = [N, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N], ) if USE_TMA_LOAD_X and not TMA_LOAD_BOTH: x_desc = tl._experimental_make_tensor_descriptor( x_ptr, - shape=[TOTAL_TOKENS, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + shape = [TOTAL_TOKENS, K], + strides = [K, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K], ) # Output tiles per expert, since each expert weight matrix is [N, K] num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) @@ -351,9 +351,9 @@ def _grouped_gemm_dW_kernel( tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K") dW_desc = tl._experimental_make_tensor_descriptor( dW_ptr, - shape=[NUM_EXPERTS, N, K], - strides=[N * K, K, 1], - block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K], + shape = [NUM_EXPERTS, N, K], + strides = [N * K, K, 1], + block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K], ) for tile_idx in range( @@ -377,7 +377,7 @@ def _grouped_gemm_dW_kernel( m_end = 0 for expert_idx in range(NUM_EXPERTS): # We need to instantiate a fresh accumulator for each expert - accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=acc_dtype) + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype) m_start = m_end # Need to figure out why this cast is needed, otherwise compiler complains about mismatching types @@ -392,16 +392,16 @@ def _grouped_gemm_dW_kernel( if TMA_LOAD_BOTH: dY_desc = tl._experimental_make_tensor_descriptor( dY_ptr, - shape=[m_end, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + shape = [m_end, N], + strides = [N, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N], ) x_desc = tl._experimental_make_tensor_descriptor( x_ptr, - shape=[m_end, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + shape = [m_end, K], + strides = [K, 1], + block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K], ) for tile_m_idx in range(0, m_size, BLOCK_SIZE_M): @@ -425,7 +425,7 @@ def _grouped_gemm_dW_kernel( # indices_to_gather = m_start + gather_offsets expert_token_idx = tl.load( gather_indices_ptr + indices_to_gather, - mask=indices_to_gather < TOTAL_TOKENS, + mask = indices_to_gather < TOTAL_TOKENS, ) expert_token_offsets = expert_token_idx[:, None] @@ -461,7 +461,7 @@ def _grouped_gemm_dW_kernel( x_ptr + x_row_load_idx + (k_offset + block_range_k)[None, :], - mask=mk_mask, + mask = mk_mask, ) if USE_TMA_LOAD_dY: @@ -471,7 +471,7 @@ def _grouped_gemm_dW_kernel( dY_ptr + dY_row_load_idx + (n_offset + block_range_n)[None, :], - mask=mn_mask, + mask = mn_mask, ) accumulator += tl.dot( @@ -491,12 +491,12 @@ def _grouped_gemm_dW_kernel( + store_row_offs[:, None] * K + (k_offset + block_range_k)[None, :], y, - mask=nk_mask, + mask = nk_mask, ) _autotuned_grouped_gemm_dW_kernel = triton.autotune( - configs=get_dW_kernel_configs(), - prune_configs_by={"early_config_prune": prune_kernel_configs_backward_dW}, - key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"], + configs = get_dW_kernel_configs(), + prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW}, + key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"], )(_grouped_gemm_dW_kernel) diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py index 1474c95bd..4010c77ce 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py @@ -51,9 +51,9 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): def __init__( self, config: Llama4TextConfig, - overlap_router_shared=False, - verbose=False, - debug=False, + overlap_router_shared = False, + verbose = False, + debug = False, ): super().__init__(config) self.overlap_router_shared = overlap_router_shared @@ -136,7 +136,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = self.router(hidden_states) routing_weights, selected_experts = torch.topk( - router_logits, self.top_k, dim=-1 + router_logits, self.top_k, dim = -1 ) routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype) @@ -195,7 +195,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): ) if self.top_k > 1: - hidden_states = hidden_states.sum(dim=1) + hidden_states = hidden_states.sum(dim = 1) hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim) # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order @@ -212,7 +212,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): # Start expert computation first_gemm = torch_grouped_gemm( - X=hidden_states, W=self.experts.gate_up_proj, m_sizes=token_counts_by_expert + X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert ) assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim) @@ -221,7 +221,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): # See comment above second_gemm = torch_grouped_gemm( - X=intermediate, W=self.experts.down_proj, m_sizes=token_counts_by_expert + X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert ) assert second_gemm.shape == (total_tokens, hidden_dim) @@ -234,17 +234,17 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe): result = ( Llama4MoeResult( - token_counts_by_expert=token_counts_by_expert, - gather_indices=gather_indices, - topk_weights=routing_weights, - hidden_states_after_weight_merge=hidden_states_after_weight_merge, - first_gemm=first_gemm, - intermediate=intermediate, - second_gemm=second_gemm, - hidden_states_unpermute=hidden_states_unpermute, - shared_expert_out=shared_expert_out, - final_out=final_out, - router_logits=router_logits, + token_counts_by_expert = token_counts_by_expert, + gather_indices = gather_indices, + topk_weights = routing_weights, + hidden_states_after_weight_merge = hidden_states_after_weight_merge, + first_gemm = first_gemm, + intermediate = intermediate, + second_gemm = second_gemm, + hidden_states_unpermute = hidden_states_unpermute, + shared_expert_out = shared_expert_out, + final_out = final_out, + router_logits = router_logits, ) if self.debug else (final_out, routing_weights) @@ -257,7 +257,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): def __init__( self, config: Llama4TextConfig, - overlap_router_shared=False, + overlap_router_shared = False, permute_x: bool = False, permute_y: bool = True, autotune: bool = True, @@ -266,9 +266,9 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): kernel_config_bwd_dX: KernelConfigBackward_dX = None, dW_only: bool = False, dX_only: bool = False, - verbose=False, + verbose = False, ): - super().__init__(config, overlap_router_shared=overlap_router_shared) + super().__init__(config, overlap_router_shared = overlap_router_shared) assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights" self.permute_x = permute_x self.permute_y = permute_y @@ -321,7 +321,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = self.router(hidden_states) routing_weights, selected_experts = torch.topk( - router_logits, self.top_k, dim=-1 + router_logits, self.top_k, dim = -1 ) routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype) @@ -380,7 +380,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): ) if self.top_k > 1: - hidden_states = hidden_states.sum(dim=1) + hidden_states = hidden_states.sum(dim = 1) hidden_states = hidden_states.view(-1, hidden_dim) # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order @@ -395,37 +395,37 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): # Start expert computation hidden_states = grouped_gemm( - X=hidden_states, - W=self.experts.gate_up_proj, - m_sizes=token_counts_by_expert, - gather_indices=gather_indices, - topk=self.top_k, - permute_x=self.permute_x, - permute_y=False, # output of first grouped gemm should never be permuted - autotune=self.autotune, - kernel_config_fwd=self.kernel_config_fwd, - kernel_config_bwd_dW=self.kernel_config_bwd_dW, - kernel_config_bwd_dX=self.kernel_config_bwd_dX, - is_first_gemm=True, - dW_only=self.dW_only, - dX_only=self.dX_only, + X = hidden_states, + W = self.experts.gate_up_proj, + m_sizes = token_counts_by_expert, + gather_indices = gather_indices, + topk = self.top_k, + permute_x = self.permute_x, + permute_y = False, # output of first grouped gemm should never be permuted + autotune = self.autotune, + kernel_config_fwd = self.kernel_config_fwd, + kernel_config_bwd_dW = self.kernel_config_bwd_dW, + kernel_config_bwd_dX = self.kernel_config_bwd_dX, + is_first_gemm = True, + dW_only = self.dW_only, + dX_only = self.dX_only, ) hidden_states = self.act_and_mul(hidden_states) hidden_states = grouped_gemm( - X=hidden_states, - W=self.experts.down_proj, - m_sizes=token_counts_by_expert, - gather_indices=gather_indices, - topk=self.top_k, - permute_x=False, - permute_y=self.permute_y, - autotune=self.autotune, - kernel_config_fwd=self.kernel_config_fwd, - kernel_config_bwd_dW=self.kernel_config_bwd_dW, - kernel_config_bwd_dX=self.kernel_config_bwd_dX, - is_first_gemm=False, - dW_only=self.dW_only, - dX_only=self.dX_only, + X = hidden_states, + W = self.experts.down_proj, + m_sizes = token_counts_by_expert, + gather_indices = gather_indices, + topk = self.top_k, + permute_x = False, + permute_y = self.permute_y, + autotune = self.autotune, + kernel_config_fwd = self.kernel_config_fwd, + kernel_config_bwd_dW = self.kernel_config_bwd_dW, + kernel_config_bwd_dX = self.kernel_config_bwd_dX, + is_first_gemm = False, + dW_only = self.dW_only, + dX_only = self.dX_only, ) # Post-processing diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py index cbccf19cb..0d497f380 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py @@ -80,14 +80,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock): gate, gate_up_proj, down_proj, - permute_x=permute_x, - permute_y=permute_y, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, - dW_only=dW_only, - dX_only=dX_only, + permute_x = permute_x, + permute_y = permute_y, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, + dW_only = dW_only, + dX_only = dX_only, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -112,37 +112,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock): hidden_states = permute(hidden_states, gather_indices, self.top_k) # Start expert computation hidden_states = grouped_gemm( - X=hidden_states, - W=self.gate_up_proj, - m_sizes=token_counts_by_expert, - gather_indices=gather_indices, - topk=self.top_k, - permute_x=self.permute_x, - permute_y=False, # output of first grouped gemm should never be permuted - autotune=self.autotune, - kernel_config_fwd=self.kernel_config_fwd, - kernel_config_bwd_dW=self.kernel_config_bwd_dW, - kernel_config_bwd_dX=self.kernel_config_bwd_dX, - is_first_gemm=True, - dW_only=self.dW_only, - dX_only=self.dX_only, + X = hidden_states, + W = self.gate_up_proj, + m_sizes = token_counts_by_expert, + gather_indices = gather_indices, + topk = self.top_k, + permute_x = self.permute_x, + permute_y = False, # output of first grouped gemm should never be permuted + autotune = self.autotune, + kernel_config_fwd = self.kernel_config_fwd, + kernel_config_bwd_dW = self.kernel_config_bwd_dW, + kernel_config_bwd_dX = self.kernel_config_bwd_dX, + is_first_gemm = True, + dW_only = self.dW_only, + dX_only = self.dX_only, ) hidden_states = self.act_and_mul(hidden_states) hidden_states = grouped_gemm( - X=hidden_states, - W=self.down_proj, - m_sizes=token_counts_by_expert, - gather_indices=gather_indices, - topk=self.top_k, - permute_x=False, - permute_y=self.permute_y, - autotune=self.autotune, - kernel_config_fwd=self.kernel_config_fwd, - kernel_config_bwd_dW=self.kernel_config_bwd_dW, - kernel_config_bwd_dX=self.kernel_config_bwd_dX, - is_first_gemm=False, - dW_only=self.dW_only, - dX_only=self.dX_only, + X = hidden_states, + W = self.down_proj, + m_sizes = token_counts_by_expert, + gather_indices = gather_indices, + topk = self.top_k, + permute_x = False, + permute_y = self.permute_y, + autotune = self.autotune, + kernel_config_fwd = self.kernel_config_fwd, + kernel_config_bwd_dW = self.kernel_config_bwd_dW, + kernel_config_bwd_dX = self.kernel_config_bwd_dX, + is_first_gemm = False, + dW_only = self.dW_only, + dX_only = self.dX_only, ) # Post-processing @@ -155,7 +155,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock): hidden_states.view(num_tokens, self.top_k, hidden_dim) * routing_weights[..., None] ) - hidden_states = hidden_states.sum(dim=1) + hidden_states = hidden_states.sum(dim = 1) hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) return hidden_states, router_logits diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py index 821b06c97..46d9c3c51 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py @@ -56,7 +56,7 @@ def calculate_topk( gating_output.dtype ) else: - scores = F.softmax(gating_output.to(torch.float32), dim=1).to( + scores = F.softmax(gating_output.to(torch.float32), dim = 1).to( gating_output.dtype ) @@ -67,13 +67,13 @@ def calculate_topk( else: scores = gating_output - topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1) + topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1) if post_act: topk_weights = _activation(topk_weights) if renormalize: - topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to( + topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to( gating_output.dtype ) @@ -94,12 +94,12 @@ def get_routing_indices( # group tokens together by expert indices from 0 to num_experts and pass that to experts forward token_counts_by_expert = torch.histc( selected_experts.view(-1), - bins=num_experts, - min=0, - max=num_experts, + bins = num_experts, + min = 0, + max = num_experts, ) # token_indices_experts_sorted shape (bs*slen*top_k,) - gather_indices = torch.argsort(selected_experts.view(-1), stable=True) + gather_indices = torch.argsort(selected_experts.view(-1), stable = True) if return_scatter_indices: scatter_indices = gather_indices.argsort() return token_counts_by_expert, gather_indices, scatter_indices @@ -107,7 +107,7 @@ def get_routing_indices( return token_counts_by_expert, gather_indices -def torch_grouped_gemm(X, W, m_sizes, transpose=True): +def torch_grouped_gemm(X, W, m_sizes, transpose = True): """ X: [M, K] if forward, else [M, N] W: [E, N, K] @@ -127,7 +127,7 @@ def torch_grouped_gemm(X, W, m_sizes, transpose=True): N = W.shape[1] - result = torch.zeros((M, N), dtype=X.dtype, device=X.device) + result = torch.zeros((M, N), dtype = X.dtype, device = X.device) m_start = 0 for g in range(E): diff --git a/unsloth/kernels/moe/tests/test_llama4_moe.py b/unsloth/kernels/moe/tests/test_llama4_moe.py index 27d5d99a2..13ad552bf 100644 --- a/unsloth/kernels/moe/tests/test_llama4_moe.py +++ b/unsloth/kernels/moe/tests/test_llama4_moe.py @@ -37,7 +37,7 @@ NUM_AUTOTUNE_CONFIGS = 50 @contextmanager -def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80): +def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80): print(char * num_chars) print(prelude) yield @@ -81,7 +81,7 @@ def prep_triton_kernel_traits(autotune): def sparse_to_dense(t: torch.Tensor): - t = t.sum(dim=0).view(-1) + t = t.sum(dim = 0).view(-1) return t @@ -91,9 +91,9 @@ def _check_diff( t2: torch.Tensor, atol, rtol, - precision=".6f", - verbose=False, - msg="", + precision = ".6f", + verbose = False, + msg = "", ): t2 = t2.view_as(t1) diff = t1.sub(t2).abs().max().item() @@ -101,7 +101,7 @@ def _check_diff( if msg == "": msg = "diff" print(f"{msg}: {diff:{precision}}") - assert torch.allclose(t1, t2, atol=atol, rtol=rtol) + assert torch.allclose(t1, t2, atol = atol, rtol = rtol) def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module): @@ -115,19 +115,19 @@ def _check_grads( m2: torch.nn.Module, atol, rtol, - precision=".6f", - verbose=False, - msg="", + precision = ".6f", + verbose = False, + msg = "", ): for name, param in m1.named_parameters(): _check_diff( param.grad, m2.get_parameter(name).grad, - atol=atol, - rtol=rtol, - precision=precision, - verbose=verbose, - msg=f"{msg}:{name}.grad", + atol = atol, + rtol = rtol, + precision = precision, + verbose = verbose, + msg = f"{msg}:{name}.grad", ) @@ -139,19 +139,19 @@ def model_config(): @pytest.mark.parametrize( "overlap_router_shared", [False, True], - ids=lambda x: "overlap_router_shared" if x else "no_overlap", + ids = lambda x: "overlap_router_shared" if x else "no_overlap", ) @pytest.mark.parametrize( - "permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y" + "permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y" ) @pytest.mark.parametrize( - "permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x" + "permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x" ) # Llama4 does not support permute_x @pytest.mark.parametrize( - "autotune", [True], ids=lambda x: "autotune" if x else "manual" + "autotune", [True], ids = lambda x: "autotune" if x else "manual" ) -@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}") -@pytest.mark.parametrize("dtype", DTYPES, ids=str) +@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}") +@pytest.mark.parametrize("dtype", DTYPES, ids = str) def test_llama4_ref( dtype: torch.dtype, seqlen, @@ -161,9 +161,9 @@ def test_llama4_ref( overlap_router_shared: bool, model_config: Llama4TextConfig, # test fixture bs: int = 1, - device="cuda", - precision=".6f", - verbose=False, + device = "cuda", + precision = ".6f", + verbose = False, ): torch.manual_seed( SEED @@ -172,24 +172,24 @@ def test_llama4_ref( hidden_dim = model_config.hidden_size atol, rtol = TOLERANCES[dtype] check_diff = partial( - _check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose + _check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose ) check_grads = partial( - _check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose + _check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose ) # Reference op -- HF - llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device) + llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device) # Torch grouped gemm impl llama4_gg_ref = Llama4GroupedGemmTextMoe( - model_config, overlap_router_shared=overlap_router_shared - ).to(dtype=dtype, device=device) + model_config, overlap_router_shared = overlap_router_shared + ).to(dtype = dtype, device = device) llama4_gg_ref.copy_weights(llama4_ref) llama4_gg_ref.check_weights(llama4_ref) x_ref = torch.randn( - bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True + bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True ) x_torch_gg = x_ref.detach().clone().requires_grad_() x_triton = x_ref.detach().clone().requires_grad_() @@ -198,9 +198,9 @@ def test_llama4_ref( y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg) assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}" with annotated_context("Testing torch grouped gemm Llama4TextMoe"): - check_diff(y_ref, y_torch_gg, msg="y_torch_gg") + check_diff(y_ref, y_torch_gg, msg = "y_torch_gg") check_diff( - sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg" + sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg" ) kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = ( @@ -209,38 +209,38 @@ def test_llama4_ref( llama4_triton = Llama4TritonTextMoe( model_config, - overlap_router_shared=overlap_router_shared, - permute_x=permute_x, - permute_y=permute_y, - autotune=autotune, - kernel_config_fwd=kernel_config_fwd, - kernel_config_bwd_dW=kernel_config_bwd_dW, - kernel_config_bwd_dX=kernel_config_bwd_dX, - ).to(device=device, dtype=dtype) + overlap_router_shared = overlap_router_shared, + permute_x = permute_x, + permute_y = permute_y, + autotune = autotune, + kernel_config_fwd = kernel_config_fwd, + kernel_config_bwd_dW = kernel_config_bwd_dW, + kernel_config_bwd_dX = kernel_config_bwd_dX, + ).to(device = device, dtype = dtype) llama4_triton.copy_weights(llama4_ref) llama4_triton.check_weights(llama4_ref) y_triton, routing_triton = llama4_triton(x_triton) with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"): - check_diff(y_ref, y_triton, msg="y_triton") - check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton") + check_diff(y_ref, y_triton, msg = "y_triton") + check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton") ref_grad = torch.randn_like(y_ref) run_backwards(y_ref, ref_grad, llama4_ref) run_backwards(y_torch_gg, ref_grad, llama4_gg_ref) with annotated_context("Testing torch group gemm Llama4TextMoe backward"): - check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg") + check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg") run_backwards(y_triton, ref_grad, llama4_triton) with annotated_context("Testing triton group gemm Llama4TextMoe backward"): - check_grads(llama4_ref, llama4_triton, msg="triton") + check_grads(llama4_ref, llama4_triton, msg = "triton") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--seqlen", type=int, default=1024) + parser.add_argument("--seqlen", type = int, default = 1024) parser.add_argument( - "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16" + "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16" ) args = parser.parse_args() args.dtype = getattr(torch, args.dtype) @@ -251,12 +251,12 @@ if __name__ == "__main__": text_config: Llama4TextConfig = get_text_config(model_id) for overlap in [False, True]: test_llama4_ref( - seqlen=args.seqlen, - model_config=text_config, - dtype=args.dtype, - autotune=True, - permute_x=False, - permute_y=True, - overlap_router_shared=overlap, - verbose=True, + seqlen = args.seqlen, + model_config = text_config, + dtype = args.dtype, + autotune = True, + permute_x = False, + permute_y = True, + overlap_router_shared = overlap, + verbose = True, ) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 359b9a41b..accbed11b 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -45,16 +45,16 @@ def _rms_layernorm_forward( X += row_idx * X_row_stride r += row_idx * r_row_stride - X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask=mask, other=0) # .to(tl.float32) + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0) # .to(tl.float32) - row_var = tl.sum(X_row * X_row, axis=0) / n_cols + row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) tl.store(r, inv_var) normed = X_row * inv_var normed = normed.to(W_row.dtype) # Exact copy from HF output = normed * W_row - tl.store(Y + col_offsets, output, mask=mask) + tl.store(Y + col_offsets, output, mask = mask) def _rms_layernorm_backward( @@ -92,9 +92,9 @@ def _rms_layernorm_backward( else: dX = dY - dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32) - X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32) + dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) # Get saved row variance inv_var = tl.load(r).to(tl.float32) @@ -105,9 +105,9 @@ def _rms_layernorm_backward( else: dY_W = dY_row * W_row - rowsum_dY_normed = tl.sum(dY_W * normed, axis=0) + rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0) output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed) - tl.store(dX + col_offsets, output, mask=mask) + tl.store(dX + col_offsets, output, mask = mask) _rms_layernorm_backward = triton.jit(_rms_layernorm_backward) @@ -143,16 +143,16 @@ def _gemma_rms_layernorm_forward( X += row_idx * X_row_stride r += row_idx * r_row_stride - X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32) - W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32) + X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) + W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) - row_var = tl.sum(X_row * X_row, axis=0) / n_cols + row_var = tl.sum(X_row * X_row, axis = 0) / n_cols inv_var = tl.math.rsqrt(row_var + eps) tl.store(r, inv_var) normed = X_row * inv_var output = normed * (W_row + 1.0) - tl.store(Y + col_offsets, output, mask=mask) + tl.store(Y + col_offsets, output, mask = mask) class Fast_RMS_Layernorm(torch.autograd.Function): @@ -169,8 +169,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function): BLOCK_SIZE, num_warps = calculate_settings(n_cols) device = X.device - Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device) - r = torch.empty(n_rows, dtype=torch.float32, device=device) + Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device) + r = torch.empty(n_rows, dtype = torch.float32, device = device) fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward with torch_gpu_device(device): @@ -185,8 +185,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function): r.stride(0), n_cols, eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE @@ -222,9 +222,9 @@ class Fast_RMS_Layernorm(torch.autograd.Function): # dW, dW.stride(0), n_cols, ctx.eps, - GEMMA=ctx.GEMMA, - BLOCK_SIZE=ctx.BLOCK_SIZE, - num_warps=ctx.num_warps, + GEMMA = ctx.GEMMA, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, ) dX = dX.view(*shape) return dX, None, None, None @@ -248,7 +248,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm class Unsloth_LlamaRMSNorm(LlamaRMSNorm): def forward(self, X): - return fast_rms_layernorm(self, X, gemma=False) + return fast_rms_layernorm(self, X, gemma = False) try: @@ -256,7 +256,7 @@ try: class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm): def forward(self, X): - return fast_rms_layernorm(self, X, gemma=False) + return fast_rms_layernorm(self, X, gemma = False) except: @@ -292,25 +292,25 @@ def unpatch_rms_layernorm(): def test_rms_layernorm( - dim=1024, - eps=1e-5, - dtype=torch.float16, - bsz=21, - random_state=3407, - seqlen=3341, + dim = 1024, + eps = 1e-5, + dtype = torch.float16, + bsz = 21, + random_state = 3407, + seqlen = 3341, ): from transformers.models.llama.modeling_llama import LlamaRMSNorm - layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda") + layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda") torch.cuda.manual_seed(random_state) torch.manual_seed(random_state) torch.nn.init.uniform_(layernorm.weight) - X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda") + X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda") XX = X.clone() X.requires_grad_(True) XX.requires_grad_(True) Y = layernorm(X) - YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True) + YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True) Y.backward(YY) correct_grad = X.grad.clone() # from unsloth.kernels import fast_rms_layernorm @@ -322,14 +322,14 @@ def test_rms_layernorm( def testing_suite_layernorm(): for dim in [512, 1024, 2048]: for dtype in [torch.float16, torch.bfloat16]: - with torch.autocast(device_type="cuda", dtype=dtype): + with torch.autocast(device_type = "cuda", dtype = dtype): for seqlen in [3341, 2048, 349]: for random_state in [3407, 42]: test_rms_layernorm( - dim=dim, - eps=1e-5, - dtype=dtype, - bsz=21, - random_state=random_state, - seqlen=seqlen, + dim = dim, + eps = 1e-5, + dtype = dtype, + bsz = 21, + random_state = random_state, + seqlen = seqlen, ) diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 90e151373..8eecec10c 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -50,16 +50,16 @@ def _rope_embedding( + (row_position % seqlen) * sin_row_stride + half_head_dim * 0 + col_offsets, - mask=mask, - other=0, + mask = mask, + other = 0, ) cos1 = tl.load( cos + (row_position % seqlen) * cos_row_stride + half_head_dim * 0 + col_offsets, - mask=mask, - other=0, + mask = mask, + other = 0, ) if BACKWARD_PASS: @@ -78,11 +78,11 @@ def _rope_embedding( ) # For Gemma - sometimes RoPE must be done in float32 and not bfloat16 - Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype) - Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype) + Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype) + Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype) - tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask) - tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask) + tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask) + tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask) _rope_embedding = triton.jit(_rope_embedding) @@ -134,9 +134,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function): seq_len, head_dim, n_heads, - BACKWARD_PASS=False, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + BACKWARD_PASS = False, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, ) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -177,9 +177,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function): seq_len, head_dim, n_heads, - BACKWARD_PASS=True, - BLOCK_SIZE=ctx.BLOCK_SIZE, - num_warps=ctx.num_warps, + BACKWARD_PASS = True, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, ) dY = dY.view(batch, seq_len, n_heads, head_dim) return ( @@ -211,7 +211,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function): # Q * cos + rotate_half(Q) * sin half = Q.shape[-1] // 2 - RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim=-1) + RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1) Q *= cos Q.addcmul_(RH_Q, sin) # RH_Q *= sin @@ -224,7 +224,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function): cos, sin = ctx.saved_tensors # Q * cos + rotate_half.T(Q) * sin half = dY.shape[-1] // 2 - RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim=-1) + RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1) dY *= cos dY.addcmul_(RH_dY, sin) # RH_dY *= sin diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py index 7547b9ca6..b321f5179 100644 --- a/unsloth/kernels/swiglu.py +++ b/unsloth/kernels/swiglu.py @@ -43,8 +43,8 @@ def _fg_kernel( offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32) - g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32) + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32) # f = e * sigmoid(e) f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row)) @@ -53,13 +53,13 @@ def _fg_kernel( h_row = f_row * g_row # Store h - tl.store(h + offsets, h_row, mask=mask) + tl.store(h + offsets, h_row, mask = mask) def swiglu_fg_kernel(e, g): batch, seq_len, hd = e.shape n_elements = e.numel() - h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=e.device) + h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device) grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) with torch_gpu_device(e.device): _fg_kernel[grid]( @@ -67,8 +67,8 @@ def swiglu_fg_kernel(e, g): g, h, n_elements, - BLOCK_SIZE=BLOCK_SIZE, - LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return h @@ -101,9 +101,9 @@ def _DWf_DW_dfg_kernel( offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32) - e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32) - g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32) + DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32) + e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) + g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32) # e = e.float() # se = 1.0 / (1.0 + torch.exp(-e)) @@ -122,9 +122,9 @@ def _DWf_DW_dfg_kernel( de_row = de_row.to(DW_row.dtype) # Store derivatives in buffers - tl.store(DW + offsets, h_row, mask=mask) # h = f * g - tl.store(e + offsets, df_row, mask=mask) # df = DW * f - tl.store(g + offsets, de_row, mask=mask) # de + tl.store(DW + offsets, h_row, mask = mask) # h = f * g + tl.store(e + offsets, df_row, mask = mask) # df = DW * f + tl.store(g + offsets, de_row, mask = mask) # de def swiglu_DWf_DW_dfg_kernel(DW, e, g): @@ -137,7 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g): e, g, n_elements, - BLOCK_SIZE=BLOCK_SIZE, - LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1, + BLOCK_SIZE = BLOCK_SIZE, + LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1, ) return DW, e, g diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c3f743cc0..5dcc7c232 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -47,12 +47,12 @@ if Version(torch.__version__) < Version("2.4.0"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd torch_amp_custom_bwd = torch.cuda.amp.custom_bwd else: - torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") - torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") + torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") + torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") if DEVICE_TYPE == "xpu": - torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu") - torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu") + torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu") + torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu") # tl.math.tanh now is libdevice.tanh @@ -349,7 +349,7 @@ def _maybe_fake_quantize_activations( if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: @torch.inference_mode - def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): # TODO: After adding XPU BNB support, check this function if isinstance(W, Float8Tensor): return W.dequantize() @@ -390,13 +390,13 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index] if WEIGHT_BUFFER is None: WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty( - size, dtype=dtype, device=device, requires_grad=False + size, dtype = dtype, device = device, requires_grad = False ) ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty( n_elements_absmax, - dtype=torch.float32, - device=device, - requires_grad=False, + dtype = torch.float32, + device = device, + requires_grad = False, ) if size > WEIGHT_BUFFER.numel(): @@ -409,16 +409,16 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: else: if out is None: out = torch_empty( - shape, dtype=dtype, device=device, requires_grad=False + shape, dtype = dtype, device = device, requires_grad = False ) else: assert out.shape == shape assert out.dtype == dtype out_absmax = torch_empty( n_elements_absmax, - dtype=torch_float32, - device=device, - requires_grad=False, + dtype = torch_float32, + device = device, + requires_grad = False, ) # NF4 dequantization of statistics @@ -458,7 +458,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: @torch.inference_mode - def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if isinstance(W, Float8Tensor): return W.dequantize() if quant_state is None: @@ -500,13 +500,13 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index] if WEIGHT_BUFFER is None: WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty( - size, dtype=dtype, device=device, requires_grad=False + size, dtype = dtype, device = device, requires_grad = False ) ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty( n_elements_absmax, - dtype=torch_float32, - device=device, - requires_grad=False, + dtype = torch_float32, + device = device, + requires_grad = False, ) if size > WEIGHT_BUFFER.numel(): @@ -519,16 +519,16 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: else: if out is None: out = torch_empty( - shape, dtype=dtype, device=device, requires_grad=False + shape, dtype = dtype, device = device, requires_grad = False ) else: assert out.shape == shape assert out.dtype == dtype out_absmax = torch_empty( n_elements_absmax, - dtype=torch_float32, - device=device, - requires_grad=False, + dtype = torch_float32, + device = device, + requires_grad = False, ) pass @@ -570,7 +570,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: else: @torch.inference_mode - def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if isinstance(W, Float8Tensor): return W.dequantize() if quant_state is None: @@ -601,12 +601,12 @@ else: # Create weight matrix if out is None: - out = torch_empty(shape, dtype=dtype, device=device, requires_grad=False) + out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False) else: assert out.shape == shape assert out.dtype == dtype out_absmax = torch_empty( - n_elements_absmax, dtype=torch_float32, device=device, requires_grad=False + n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False ) # Do dequantization @@ -645,9 +645,9 @@ else: # INTEL GPU Specific Logic if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: - def fast_gemv(X, W, quant_state, out=None): + def fast_gemv(X, W, quant_state, out = None): if quant_state is None: - return torch_matmul(X, W, out=out) + return torch_matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape @@ -686,8 +686,8 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: 1, bout, ), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) # else: # assert(out.shape == (1, 1, bout,)) @@ -710,7 +710,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: ldb = ctypes_c_int32(ldb) ldc = ctypes_c_int32(ldc) - df = torch_empty(absmax.shape, dtype=torch_float32, device=device) + df = torch_empty(absmax.shape, dtype = torch_float32, device = device) with torch_gpu_device(device): cdequantize_blockwise_fp32( get_ptr(code2), @@ -751,9 +751,9 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: - def fast_gemv(X, W, quant_state, out=None): + def fast_gemv(X, W, quant_state, out = None): if quant_state is None: - return torch_matmul(X, W, out=out) + return torch_matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape @@ -793,8 +793,8 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: 1, bout, ), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) # else: # assert(out.shape == (1, 1, bout,)) @@ -813,7 +813,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: ldb = ctypes_c_int32(ldb) ldc = ctypes_c_int32(ldc) - df = torch_empty(absmax.shape, dtype=torch_float32, device=device) + df = torch_empty(absmax.shape, dtype = torch_float32, device = device) with torch_gpu_device(device): cdequantize_blockwise_fp32( get_ptr(code2), @@ -856,9 +856,9 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM: pass else: - def fast_gemv(X, W, quant_state, out=None): + def fast_gemv(X, W, quant_state, out = None): if quant_state is None: - return torch_matmul(X, W, out=out) + return torch_matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape @@ -894,8 +894,8 @@ else: 1, bout, ), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) # else: # assert(out.shape == (1, 1, bout,)) @@ -914,7 +914,7 @@ else: ldb = ctypes_c_int32(ldb) ldc = ctypes_c_int32(ldc) - df = torch_empty(absmax.shape, dtype=torch_float32, device=device) + df = torch_empty(absmax.shape, dtype = torch_float32, device = device) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), @@ -953,21 +953,21 @@ else: pass -def fast_linear_forward(proj, X, temp_lora=None, out=None): +def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) bsz, q_len, in_dim = X.shape if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: - out = torch_matmul(X, W.t(), out=out) + out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: out = fp8_linear(X, W, W_quant, bias) elif bsz == 1 and q_len == 1: - out = fast_gemv(X, W, W_quant, out=out) + out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant, use_global_buffer=True) - out = torch_matmul(X, W, out=out) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + out = torch_matmul(X, W, out = out) # Add in LoRA weights if lora_A is not None: @@ -980,14 +980,14 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None): if bsz == 1: out = out.view(out_dim) - temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out=temp_lora) - out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S) + temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) + out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) else: out = out.view(bsz, out_dim) temp_lora = torch_mm( - X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora + X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora ) - out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S) + out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) out = out.view(bsz, 1, out_dim) if bias is not None: @@ -996,7 +996,7 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None): return out -def matmul_lora(X, W, W_quant, A, B, s, out=None): +def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype if X.dim() == 3: @@ -1015,12 +1015,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None): W = W.dequantize() else: W = W.contiguous() - out = torch_matmul(X, W.t(), out=out) + out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: out = fp8_linear(X, W, W_quant) else: - W = fast_dequantize(W, W_quant, use_global_buffer=True) - out = torch_matmul(X, W.t(), out=out) + W = fast_dequantize(W, W_quant, use_global_buffer = True) + out = torch_matmul(X, W.t(), out = out) if W_quant is not None: del W @@ -1028,7 +1028,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None): # LoRA is enabled A, B = A.t(), B.t() XA = torch_matmul(X, A.to(dtype)) - out.addmm_(XA, B.to(dtype), alpha=s) + out.addmm_(XA, B.to(dtype), alpha = s) # out += (X @ A.to(dtype)) @ (s * B.to(dtype)) return out.view(batch, seq_len, -1) if reshape else out diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e0766b352..d3c4a8d73 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -150,23 +150,23 @@ for temporary_patch in TEMPORARY_PATCHES: # ============================================= # Disable some warnings which can get annoying -warnings.filterwarnings(action="ignore", category=UserWarning, module="torch") -warnings.filterwarnings(action="ignore", category=FutureWarning, module="torch") -warnings.filterwarnings(action="ignore", category=UserWarning, module="huggingface_hub") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub") warnings.filterwarnings( - action="ignore", category=FutureWarning, module="huggingface_hub" + action = "ignore", category = FutureWarning, module = "huggingface_hub" ) -warnings.filterwarnings(action="ignore", category=UserWarning, module="trl") -warnings.filterwarnings(action="ignore", category=FutureWarning, module="trl") -warnings.filterwarnings(action="ignore", category=FutureWarning, module="xformers") -warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="subprocess") -warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers") -warnings.filterwarnings(action="ignore", category=FutureWarning, module="accelerate") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers") +warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate") warnings.filterwarnings( - action="ignore", category=RuntimeWarning, module="multiprocessing" + action = "ignore", category = RuntimeWarning, module = "multiprocessing" ) -warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="multiprocess") -warnings.filterwarnings(action="ignore", category=UserWarning, module="triton") +warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton") # Stop "Special tokens have been added in the vocabulary, ..." import logging @@ -345,10 +345,10 @@ except: # You passed `quantization_config` or equivalent parameters try: warnings.filterwarnings( - action="ignore", - message=r".*quantization_config.*", - category=UserWarning, - append=True, + action = "ignore", + message = r".*quantization_config.*", + category = UserWarning, + append = True, ) except: pass @@ -357,10 +357,10 @@ except: # Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463 try: warnings.filterwarnings( - action="ignore", - message=r".*Logical operators 'and' and 'or'.*", - category=UserWarning, - append=True, + action = "ignore", + message = r".*Logical operators 'and' and 'or'.*", + category = UserWarning, + append = True, ) except: pass @@ -465,7 +465,7 @@ def extract_quant_model_param_count(model): return count -def get_model_param_count(model, trainable_only=False): +def get_model_param_count(model, trainable_only = False): """ Calculate model's total param count. If trainable_only is True then count only those requiring grads """ @@ -596,14 +596,14 @@ if DEVICE_TYPE in ("cuda", "hip"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd torch_amp_custom_bwd = torch.cuda.amp.custom_bwd else: - torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") - torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") + torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") + torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") elif DEVICE_TYPE == "xpu": if Version(torch_version) < Version("2.6.0"): raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0") else: - torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu") - torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu") + torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu") + torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu") # ============================================= # ============================================= @@ -934,9 +934,9 @@ import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu patch_torch_compile( - debug=UNSLOTH_COMPILE_DEBUG, - O3=UNSLOTH_COMPILE_MAXIMUM, - ignore_errors=UNSLOTH_COMPILE_IGNORE_ERRORS, + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, ) torch_compile_options = { @@ -983,9 +983,9 @@ def patch_regional_compilation(): [ torch.compile( x, - dynamic=True, - options=torch_compile_options, - fullgraph=False, + dynamic = True, + options = torch_compile_options, + fullgraph = False, ) for x in args[0] ] @@ -1008,14 +1008,14 @@ def prepare_model_for_kbit_training( use_reentrant: Optional[bool] = True, ) -> Any: return prepare_model_for_training( - model=model, - use_gradient_checkpointing=use_gradient_checkpointing, - use_reentrant=use_reentrant, - full_finetuning=False, - train_layernorms=False, - train_embedding=False, - train_lm_head=False, - float32_mixed_precision=True, + model = model, + use_gradient_checkpointing = use_gradient_checkpointing, + use_reentrant = use_reentrant, + full_finetuning = False, + train_layernorms = False, + train_embedding = False, + train_lm_head = False, + float32_mixed_precision = True, ) @@ -1033,7 +1033,7 @@ if Version(peft_version) < Version("0.12.0"): text = "if weight is not None:\n" start = source.find(text) + len(text) end = source.find("self.to(weight.device)", start) - spaces = re.findall(r"^([ ]{1,})break", source, flags=re.MULTILINE)[0] + spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0] source = source.replace(source[start:end], spaces) spaces = len(re.match(r"[\s]{1,}", source).group(0)) lines = source.split("\n") @@ -1070,7 +1070,7 @@ import socket @functools.lru_cache(1) -def has_internet(host="8.8.8.8", port=53, timeout=3): +def has_internet(host = "8.8.8.8", port = 53, timeout = 3): if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1": return False try: @@ -1084,12 +1084,12 @@ def has_internet(host="8.8.8.8", port=53, timeout=3): import psutil -def _get_statistics(statistics=None, force_download=True): +def _get_statistics(statistics = None, force_download = True): # We log some basic stats about which environment is being used. # We simply download a README.md file from HF - all data is made public. # This is simply so we can check if some envs are broken or not. # You can disable this by commenting the below out - n_cpus = psutil.cpu_count(logical=False) + n_cpus = psutil.cpu_count(logical = False) keynames = "\n" + "\n".join(os.environ.keys()) # Check modelscope for down detection global USE_MODELSCOPE @@ -1150,12 +1150,12 @@ def _get_statistics(statistics=None, force_download=True): if has_internet(): def stats_check(): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as f: + with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f: snapshot_download( f"unslothai/{statistics}", - force_download=True, - cache_dir=f, - local_dir=f, + force_download = True, + cache_dir = f, + local_dir = f, ) time_limited_stats_check = execute_with_time_limit(120)(stats_check) @@ -1178,7 +1178,7 @@ def _get_statistics(statistics=None, force_download=True): stats_check() -def get_statistics(local_files_only=False): +def get_statistics(local_files_only = False): # We log some basic stats about which environment is being used. # This is also to check if HuggingFace is down or not! # We simply download a README.md file from HF - all data is made public. @@ -1201,7 +1201,7 @@ def get_statistics(local_files_only=False): disable_progress_bars() disabled = True _get_statistics(None) - _get_statistics("repeat", force_download=False) + _get_statistics("repeat", force_download = False) total_memory = ( torch.xpu.get_device_properties(0).total_memory if DEVICE_TYPE == "xpu" @@ -1242,7 +1242,7 @@ BitsAndBytesConfig__init__ = re.sub( r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n", "", BitsAndBytesConfig__init__, - flags=re.MULTILINE, + flags = re.MULTILINE, ) BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n") length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0)) @@ -1322,12 +1322,12 @@ def offload_to_disk( torch.save( W, filename, - pickle_module=pickle, - pickle_protocol=pickle.HIGHEST_PROTOCOL, + pickle_module = pickle, + pickle_protocol = pickle.HIGHEST_PROTOCOL, ) # We must use weights_only = False due to pickling offloaded_W = torch.load( - filename, map_location="cpu", mmap=True, weights_only=False + filename, map_location = "cpu", mmap = True, weights_only = False ) offloaded_W._offloaded_file_location = filename return offloaded_W @@ -1352,7 +1352,7 @@ def offload_output_embeddings( model.get_output_embeddings(), model, "output_embeddings", temporary_location ) - new_output_embeddings = torch.nn.Linear(1, 1, bias=None) + new_output_embeddings = torch.nn.Linear(1, 1, bias = None) del new_output_embeddings.weight new_output_embeddings.weight = offloaded_W new_output_embeddings.in_features = offloaded_W.shape[1] @@ -1376,10 +1376,10 @@ def is_vLLM_available(): # Patches models to add RoPE Scaling def patch_linear_scaling( - model_name="gemma2", - rope_module=None, - scaled_rope_module=None, - attention_module=None, + model_name = "gemma2", + rope_module = None, + scaled_rope_module = None, + attention_module = None, ): assert rope_module is not None and scaled_rope_module is not None assert attention_module is not None @@ -1430,13 +1430,13 @@ def patch_linear_scaling( pass """ fix_rope_function = fix_rope_function.format( - rope_function=rope_module.__name__, - scaled_rope_function=scaled_rope_module.__name__, + rope_function = rope_module.__name__, + scaled_rope_function = scaled_rope_module.__name__, ) rotary_emb = re.findall( r"self\.rotary\_emb \= .+?\)", function, - flags=re.DOTALL | re.MULTILINE, + flags = re.DOTALL | re.MULTILINE, ) if len(rotary_emb) == 0: return None, exec_code + "\n\n" + function @@ -1449,12 +1449,12 @@ def patch_linear_scaling( # Patches for Llama-3 LlamaExtendedRotaryEmbedding def patch_llama_rope_scaling( - model_name="llama", - rope_module=None, - scaled_rope_module=None, - extended_rope_module=None, - attention_module=None, - longrope_module=None, + model_name = "llama", + rope_module = None, + scaled_rope_module = None, + extended_rope_module = None, + attention_module = None, + longrope_module = None, ): assert ( rope_module is not None @@ -1528,17 +1528,17 @@ def patch_llama_rope_scaling( """ fix_rope_function = fix_rope_function.format( - rope_function=rope_module.__name__, - scaled_rope_function=scaled_rope_module.__name__, - extended_rope_function=extended_rope_module.__name__, - longrope_rope_function=( + rope_function = rope_module.__name__, + scaled_rope_function = scaled_rope_module.__name__, + extended_rope_function = extended_rope_module.__name__, + longrope_rope_function = ( longrope_module if longrope_module is not None else rope_module ).__name__, ) rotary_emb = re.findall( r"self\.rotary\_emb \= .+?\)", function, - flags=re.DOTALL | re.MULTILINE, + flags = re.DOTALL | re.MULTILINE, ) if len(rotary_emb) == 0: return None, function @@ -1548,15 +1548,15 @@ def patch_llama_rope_scaling( return init_name, function -def create_boolean_mask(n=4096, sliding_window=2048): +def create_boolean_mask(n = 4096, sliding_window = 2048): # Creates a boolean mask for attention - mask = torch.ones(n, n, dtype=torch.bool) + mask = torch.ones(n, n, dtype = torch.bool) if sliding_window == 0: - return torch.triu(mask, diagonal=1, out=mask) - torch.triu(mask, diagonal=0, out=mask) - torch.triu(mask.T, diagonal=-sliding_window, out=mask.T) + return torch.triu(mask, diagonal = 1, out = mask) + torch.triu(mask, diagonal = 0, out = mask) + torch.triu(mask.T, diagonal = -sliding_window, out = mask.T) mask = mask.T - torch.logical_not(mask, out=mask) + torch.logical_not(mask, out = mask) return mask @@ -1567,37 +1567,37 @@ def test_mask_creation(): for s in range(1, 23): correct_mask = ( AttentionMaskConverter( - is_causal=True, - sliding_window=s, + is_causal = True, + sliding_window = s, ) .to_causal_4d( 1, n, n, - dtype=torch.float16, + dtype = torch.float16, ) .squeeze(0) .squeeze(0) ) correct_mask = correct_mask == correct_mask.min() - our_mask = create_boolean_mask(n=n, sliding_window=s) + our_mask = create_boolean_mask(n = n, sliding_window = s) assert torch.all(correct_mask == our_mask) correct_mask = ( AttentionMaskConverter( - is_causal=True, - sliding_window=None, + is_causal = True, + sliding_window = None, ) .to_causal_4d( 1, n, n, - dtype=torch.float16, + dtype = torch.float16, ) .squeeze(0) .squeeze(0) ) correct_mask = correct_mask == correct_mask.min() - our_mask = create_boolean_mask(n=n, sliding_window=0) + our_mask = create_boolean_mask(n = n, sliding_window = 0) assert torch.all(correct_mask == our_mask) @@ -1812,34 +1812,34 @@ def unsloth_compile_transformers( dtype, model_name, model_types, - token=None, - revision=None, - trust_remote_code=False, - sdpa_dynamic_mask=True, - sdpa_bool_masks=True, - sdpa_gqa_replace=True, - sdpa_dynamic_compile=True, - compile_attention=True, - disable_causal_masks=True, - compile_torch_modules=True, - compile_custom_modules=True, - compile_function_calls=True, - fuse_lm_head=True, - gradient_checkpointing=True, - manual_replacements=True, - fast_lora_forwards=True, - fast_residual_stream=True, - accurate_accumulation=True, - epilogue_fusion=True, - max_autotune=False, - shape_padding=True, - cudagraphs=False, - debug=False, - fullgraph=True, - import_from_cache=False, - disable=False, - return_logits=False, - unsloth_force_compile=False, + token = None, + revision = None, + trust_remote_code = False, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = True, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = True, + import_from_cache = False, + disable = False, + return_logits = False, + unsloth_force_compile = False, ): if Version(torch_version) < Version("2.4.0"): print( @@ -1864,31 +1864,31 @@ def unsloth_compile_transformers( for model_type in model_types: _unsloth_compile_transformers( model_type, - sdpa_dynamic_mask=sdpa_dynamic_mask, - sdpa_bool_masks=sdpa_bool_masks, - sdpa_gqa_replace=sdpa_gqa_replace, - sdpa_dynamic_compile=sdpa_dynamic_compile, - compile_attention=compile_attention, - disable_causal_masks=disable_causal_masks, - compile_torch_modules=compile_torch_modules, - compile_custom_modules=compile_custom_modules, - compile_function_calls=compile_function_calls, - fuse_lm_head=fuse_lm_head, - gradient_checkpointing=gradient_checkpointing, - manual_replacements=manual_replacements, - fast_lora_forwards=fast_lora_forwards, - fast_residual_stream=fast_residual_stream, - accurate_accumulation=accurate_accumulation, - epilogue_fusion=epilogue_fusion, - max_autotune=max_autotune, - shape_padding=shape_padding, - cudagraphs=cudagraphs, - debug=debug, - fullgraph=fullgraph, - import_from_cache=import_from_cache, - disable=disable, - return_logits=return_logits, - supports_sdpa=supports_sdpa, + sdpa_dynamic_mask = sdpa_dynamic_mask, + sdpa_bool_masks = sdpa_bool_masks, + sdpa_gqa_replace = sdpa_gqa_replace, + sdpa_dynamic_compile = sdpa_dynamic_compile, + compile_attention = compile_attention, + disable_causal_masks = disable_causal_masks, + compile_torch_modules = compile_torch_modules, + compile_custom_modules = compile_custom_modules, + compile_function_calls = compile_function_calls, + fuse_lm_head = fuse_lm_head, + gradient_checkpointing = gradient_checkpointing, + manual_replacements = manual_replacements, + fast_lora_forwards = fast_lora_forwards, + fast_residual_stream = fast_residual_stream, + accurate_accumulation = accurate_accumulation, + epilogue_fusion = epilogue_fusion, + max_autotune = max_autotune, + shape_padding = shape_padding, + cudagraphs = cudagraphs, + debug = debug, + fullgraph = fullgraph, + import_from_cache = import_from_cache, + disable = disable, + return_logits = return_logits, + supports_sdpa = supports_sdpa, ) # Redo patches which override compiler for temporary_patch in TEMPORARY_PATCHES: @@ -1993,7 +1993,7 @@ def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, m "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n" "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`." ) - loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1) + loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1) if hasattr(model.config, "quantization_config"): raise ValueError( @@ -2069,9 +2069,9 @@ class TorchAOConfig: base_config_and_filter_fns: List[ Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]] ] = field( - default_factory=lambda: [ + default_factory = lambda: [ ( - Int4WeightOnlyConfig(group_size=128), + Int4WeightOnlyConfig(group_size = 128), lambda m, _: isinstance(m, torch.nn.Linear) and getattr(m, "in_features", 0) >= 128, ), @@ -2143,7 +2143,7 @@ def _convert_torchao_model(model): module_to_fqn_dict = {} for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns: - quantize_(model, QATConfig(base_config, step="convert"), filter_fn=filter_fn) + quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn) # Default filter function used for quantize_ if filter_fn is None: @@ -2166,7 +2166,7 @@ def _convert_torchao_model(model): kwargs["modules_to_not_convert"] = [] quant_config = ModuleFqnToConfig(module_to_fqn_dict) - quantization_config = TorchAoConfig(quant_type=quant_config, **kwargs) + quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs) model.config.quantization_config = quantization_config @@ -2199,17 +2199,17 @@ def _prepare_model_for_qat( and m.in_features >= group_size ) torchao_config = TorchAOConfig( - qat_scheme=qat_scheme, - base_config_and_filter_fns=[(base_config, filter_fn)], + qat_scheme = qat_scheme, + base_config_and_filter_fns = [(base_config, filter_fn)], ) elif qat_scheme == "fp8-fp8": from torchao.quantization import Float8DynamicActivationFloat8WeightConfig base_config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow() + granularity = PerRow() ) torchao_config = TorchAOConfig( - qat_scheme=qat_scheme, base_config_and_filter_fns=[(base_config, None)] + qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)] ) elif qat_scheme == "int8-int4": from torchao.quantization import ( @@ -2218,35 +2218,35 @@ def _prepare_model_for_qat( ) torchao_config = TorchAOConfig( - qat_scheme=qat_scheme, - base_config_and_filter_fns=[ + qat_scheme = qat_scheme, + base_config_and_filter_fns = [ ( IntxWeightOnlyConfig( - weight_dtype=torch.int8, granularity=PerAxis(0) + weight_dtype = torch.int8, granularity = PerAxis(0) ), lambda m, fqn: isinstance(m, torch.nn.Embedding), ), ( Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, weight_granularity=PerGroup(32) + weight_dtype = torch.int4, weight_granularity = PerGroup(32) ), None, ), ], - prequantization_transform=_untie_input_output_embeddings, + prequantization_transform = _untie_input_output_embeddings, ) elif qat_scheme == "int4": from torchao.quantization import Int4WeightOnlyConfig group_size = 128 - base_config = Int4WeightOnlyConfig(group_size=group_size) + base_config = Int4WeightOnlyConfig(group_size = group_size) filter_fn = ( lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size ) torchao_config = TorchAOConfig( - qat_scheme=qat_scheme, - base_config_and_filter_fns=[(base_config, filter_fn)], + qat_scheme = qat_scheme, + base_config_and_filter_fns = [(base_config, filter_fn)], ) else: raise ValueError(f"Unexpected QAT scheme {qat_scheme}") @@ -2264,7 +2264,7 @@ def _prepare_model_for_qat( if torchao_config.prequantization_transform is not None: torchao_config.prequantization_transform(model) for base_config, filter_fn in torchao_config.base_config_and_filter_fns: - quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=filter_fn) + quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn) return model diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 9f2d3cfc0..234fc2100 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -113,8 +113,8 @@ def Gemma2Attention_fast_forward( Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) if past_key_value is not None: - K = torch.cat([past_key_value[0], K], dim=2) - V = torch.cat([past_key_value[1], V], dim=2) + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) past_key_value = (K, V) if use_cache else None # Only enable if the attention_mask is True @@ -139,10 +139,10 @@ def Gemma2Attention_fast_forward( Q, K, V, - causal=True, - softcap=self.config.attn_logit_softcapping, - softmax_scale=self._flash_attention_softmax_scale, - window_size=window, + causal = True, + softcap = self.config.attn_logit_softcapping, + softmax_scale = self._flash_attention_softmax_scale, + window_size = window, ) A = A.reshape(bsz, q_len, n_heads * head_dim) else: @@ -174,7 +174,7 @@ def Gemma2DecoderLayer_fast_forward( self, "_flag_for_generation" ): # past_key_value is not None: out_weight = torch.empty( - self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0" + self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0" ) # Self Attention @@ -183,15 +183,15 @@ def Gemma2DecoderLayer_fast_forward( self.input_layernorm, hidden_states, out_weight ) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - _flag_for_generation=self._flag_for_generation, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + _flag_for_generation = self._flag_for_generation, ) hidden_states = fast_rms_layernorm_inference_gemma( self.post_attention_layernorm, hidden_states, out_weight @@ -211,31 +211,31 @@ def Gemma2DecoderLayer_fast_forward( else: residual = hidden_states hidden_states = fast_rms_layernorm( - self.input_layernorm, hidden_states, gemma=True + self.input_layernorm, hidden_states, gemma = True ) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, ) hidden_states = fast_rms_layernorm( - self.post_attention_layernorm, hidden_states, gemma=True + self.post_attention_layernorm, hidden_states, gemma = True ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm( - self.pre_feedforward_layernorm, hidden_states, gemma=True + self.pre_feedforward_layernorm, hidden_states, gemma = True ) hidden_states = self.mlp(hidden_states) hidden_states = fast_rms_layernorm( - self.post_feedforward_layernorm, hidden_states, gemma=True + self.post_feedforward_layernorm, hidden_states, gemma = True ) hidden_states = residual + hidden_states @@ -260,9 +260,9 @@ def Gemma2Attention_fast_forward_inference( hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]], position_ids, - do_prefill=False, - attention_mask=None, - use_sliding_window=False, + do_prefill = False, + attention_mask = None, + use_sliding_window = False, ): Xn = hidden_states bsz, _, hd = hidden_states.size() @@ -286,24 +286,24 @@ def Gemma2Attention_fast_forward_inference( if do_prefill: self.paged_attention = torch.empty( (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) self.paged_attention_K = self.paged_attention[:, 0] self.paged_attention_V = self.paged_attention[:, 1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) self.temp_QA = torch.empty( - (2, bsz, 1, attention_size), dtype=dtype, device=device + (2, bsz, 1, attention_size), dtype = dtype, device = device ) self.temp_KV = torch.empty( - (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device + (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device ) - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device) + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device) + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) self.attention = torch.empty( - (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device + (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device ) # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e @@ -331,9 +331,9 @@ def Gemma2Attention_fast_forward_inference( (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT) ) - Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0]) - Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0]) - Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1]) + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) @@ -348,7 +348,7 @@ def Gemma2Attention_fast_forward_inference( RH_Q = self.RH_Q RH_Q[:, :, :, :h] = Qn[:, :, :, h:] RH_Q[:, :, :, h:] = Qn[:, :, :, :h] - torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h]) + torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h]) Qn *= cos Qn.addcmul_(RH_Q, sin) @@ -357,7 +357,7 @@ def Gemma2Attention_fast_forward_inference( ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") RH_K[:, :, :, :h] = Kn[:, :, :, h:] RH_K[:, :, :, h:] = Kn[:, :, :, :h] - torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h]) + torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h]) Kn *= cos Kn.addcmul_(RH_K, sin) @@ -400,21 +400,21 @@ def Gemma2Attention_fast_forward_inference( self.scalar ) # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows - A = torch_matmul(Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]) + A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A *= self.reciprocal_t - torch_tanh(A, out=A) + torch_tanh(A, out = A) A *= self.t # Logit softcapping - A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype) - A = torch_matmul(A, Vnn, out=Qn) + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype) + A = torch_matmul(A, Vnn, out = Qn) # else: # A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) # pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) - A = fast_linear_forward(self.o_proj, A, out=self.temp_O) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) return A, (Kn, Vn) @@ -425,13 +425,13 @@ def Gemma2Model_fast_forward_inference( input_ids, past_key_values, position_ids, - attention_mask=None, + attention_mask = None, ): out_weights = tuple( torch.empty_like( self.model.layers[0].input_layernorm.weight, - dtype=torch.float32, - device=torch.device(x), + dtype = torch.float32, + device = torch.device(x), ) for x in range(DEVICE_COUNT) ) @@ -441,7 +441,7 @@ def Gemma2Model_fast_forward_inference( # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32 # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32 hidden_states *= torch.tensor( - math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype + math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype ) bsz, q_len, hd = hidden_states.shape @@ -456,7 +456,7 @@ def Gemma2Model_fast_forward_inference( (bsz, q_len), hidden_states, seq_len, - sliding_window=self.config.sliding_window, + sliding_window = self.config.sliding_window, ) GA = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, @@ -484,12 +484,12 @@ def Gemma2Model_fast_forward_inference( ) hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( decoder_layer.self_attn, - hidden_states=hidden_states, - past_key_value=past_key_values[idx], - position_ids=position_ids, - attention_mask=SWA if use_sliding_window else GA, - do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"), - use_sliding_window=use_sliding_window, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = SWA if use_sliding_window else GA, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + use_sliding_window = use_sliding_window, ) hidden_states = fast_rms_layernorm_inference_gemma( decoder_layer.post_attention_layernorm, @@ -518,10 +518,10 @@ def Gemma2Model_fast_forward_inference( ) return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=[], - attentions=[], + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], ) @@ -529,10 +529,10 @@ class FastGemma2Model(FastLlamaModel): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name="gemma2", - rope_module=GemmaFixedRotaryEmbedding, - scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding, - attention_module=Gemma2Attention, + model_name = "gemma2", + rope_module = GemmaFixedRotaryEmbedding, + scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding, + attention_module = Gemma2Attention, ) if init_name is not None: exec(function, globals()) @@ -564,7 +564,7 @@ class FastGemma2Model(FastLlamaModel): def post_patch(model, tokenizer): # Gemma does not downcast RoPE model, tokenizer = patch_model_and_tokenizer( - model, tokenizer, downcast_rope=False + model, tokenizer, downcast_rope = False ) # Add 1 to weight diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 7fd6bbd6f..0a816e399 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -109,8 +109,8 @@ def GraniteAttention_fast_forward( Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) if past_key_value is not None: - K = torch.cat([past_key_value[0], K], dim=2) - V = torch.cat([past_key_value[1], V], dim=2) + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) past_key_value = (K, V) if use_cache else None # Attention module @@ -135,7 +135,7 @@ def GraniteAttention_fast_forward( Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) A = xformers_attention( - Q, K, V, attn_bias=causal_mask, scale=self.scaling, p=dropout_p + Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p ) A = A.view(bsz, q_len, n_heads, head_dim) @@ -148,10 +148,10 @@ def GraniteAttention_fast_forward( Q, K, V, - causal=True, - window_size=window, - softmax_scale=self.scaling, - dropout_p=dropout_p, + causal = True, + window_size = window, + softmax_scale = self.scaling, + dropout_p = dropout_p, ) else: # Grouped query attention @@ -170,10 +170,10 @@ def GraniteAttention_fast_forward( Q, K, V, - attn_mask=attention_mask, - scale=self.scaling, - is_causal=False, - dropout_p=dropout_p, + attn_mask = attention_mask, + scale = self.scaling, + is_causal = False, + dropout_p = dropout_p, ) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() @@ -212,18 +212,18 @@ def GraniteDecoderLayer_fast_forward( self.input_layernorm, hidden_states ) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, - _flag_for_generation=self._flag_for_generation, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, + _flag_for_generation = self._flag_for_generation, ) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states @@ -231,28 +231,28 @@ def GraniteDecoderLayer_fast_forward( self.post_attention_layernorm, hidden_states ) hidden_states = fast_swiglu_inference(self.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) outputs = (hidden_states,) if output_attentions: @@ -275,9 +275,9 @@ def GraniteAttention_fast_forward_inference( hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]], position_ids, - do_prefill=False, - attention_mask=None, - use_sliding_window=False, + do_prefill = False, + attention_mask = None, + use_sliding_window = False, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): assert ( @@ -306,24 +306,24 @@ def GraniteAttention_fast_forward_inference( if do_prefill: self.paged_attention = torch.empty( (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) self.paged_attention_K = self.paged_attention[:, 0] self.paged_attention_V = self.paged_attention[:, 1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) self.temp_QA = torch.empty( - (2, bsz, 1, attention_size), dtype=dtype, device=device + (2, bsz, 1, attention_size), dtype = dtype, device = device ) self.temp_KV = torch.empty( - (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device + (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device ) - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device) + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device) + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) self.attention = torch.empty( - (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device + (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device ) self.half_head_dim = head_dim // 2 @@ -343,9 +343,9 @@ def GraniteAttention_fast_forward_inference( (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT) ) - Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0]) - Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0]) - Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1]) + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) @@ -359,7 +359,7 @@ def GraniteAttention_fast_forward_inference( RH_Q = self.RH_Q RH_Q[:, :, :, :h] = Qn[:, :, :, h:] RH_Q[:, :, :, h:] = Qn[:, :, :, :h] - torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h]) + torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h]) Qn *= cos Qn.addcmul_(RH_Q, sin) @@ -368,7 +368,7 @@ def GraniteAttention_fast_forward_inference( ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") RH_K[:, :, :, :h] = Kn[:, :, :, h:] RH_K[:, :, :, h:] = Kn[:, :, :, :h] - torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h]) + torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h]) Kn *= cos Kn.addcmul_(RH_K, sin) @@ -396,18 +396,18 @@ def GraniteAttention_fast_forward_inference( # pass Qn *= self.scaling - A = torch_matmul(Qn, Kn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]) + A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched - A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype) - A = torch_matmul(A, Vn, out=Qn) + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype) + A = torch_matmul(A, Vn, out = Qn) # else: # A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False) # pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) - A = fast_linear_forward(self.o_proj, A, out=self.temp_O) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) return A, (Kn, Vn) @@ -418,7 +418,7 @@ def GraniteModel_fast_forward_inference( input_ids, past_key_values, position_ids, - attention_mask=None, + attention_mask = None, ): input_ids = input_ids[:, : self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) @@ -459,37 +459,37 @@ def GraniteModel_fast_forward_inference( ) hidden_states, present_key_value = GraniteAttention_fast_forward_inference( decoder_layer.self_attn, - hidden_states=hidden_states, - past_key_value=past_key_values[idx], - position_ids=position_ids, - attention_mask=attention_mask, - do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"), - position_embeddings=position_embeddings, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = attention_mask, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) residual = hidden_states hidden_states = fast_rms_layernorm_inference( decoder_layer.post_attention_layernorm, hidden_states ) hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) next_decoder_cache.append(present_key_value) hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=[], - attentions=[], + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], ) class GraniteRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, config): - super().__init__(config=config) + super().__init__(config = config) def patched_init(original_init): @@ -510,10 +510,10 @@ class FastGraniteModel(FastLlamaModel): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name="granite", - rope_module=GraniteRotaryEmbedding, - scaled_rope_module=LlamaLinearScalingRotaryEmbedding, - attention_module=GraniteAttention, + model_name = "granite", + rope_module = GraniteRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = GraniteAttention, ) if init_name is not None: exec(function, globals()) @@ -548,7 +548,7 @@ class FastGraniteModel(FastLlamaModel): model.config.update({"unsloth_version": __version__}) # We also do this for the lm_head - lm_head = torch.nn.Linear(1, 1, bias=None) + lm_head = torch.nn.Linear(1, 1, bias = None) del lm_head.weight lm_head.weight = model.lm_head.weight lm_head.in_features = lm_head.weight.shape[1] @@ -560,7 +560,7 @@ class FastGraniteModel(FastLlamaModel): model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr() ): - lm_head = torch.nn.Linear(1, 1, bias=None) + lm_head = torch.nn.Linear(1, 1, bias = None) del lm_head.weight lm_head.weight = model.model.embed_tokens.weight lm_head.in_features = lm_head.weight.shape[1] diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b280c4978..1fe4fbffb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -143,7 +143,7 @@ SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__ def _fast_prepare_inputs_for_generation( self, input_ids, - attention_mask=None, + attention_mask = None, **kwargs, ): past_key_values = kwargs.get("past_key_values", None) @@ -185,7 +185,7 @@ def _fast_prepare_inputs_for_generation( "target_length": cache_length, "dtype": self.dtype, "cache_position": torch.arange( - cache_length, cache_length + 1, device=input_ids.device + cache_length, cache_length + 1, device = input_ids.device ), "batch_size": bs, "config": self.config, @@ -240,8 +240,8 @@ def LlamaAttention_fast_forward_inference( hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]], position_ids, - do_prefill=False, - attention_mask=None, + do_prefill = False, + attention_mask = None, ): """ https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406 @@ -293,29 +293,29 @@ def LlamaAttention_fast_forward_inference( if do_prefill: self.paged_attention = torch.empty( (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) self.paged_attention_K = self.paged_attention[:, 0] self.paged_attention_V = self.paged_attention[:, 1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) self.temp_QA = torch.empty( - (2, bsz, 1, attention_size), dtype=dtype, device=device + (2, bsz, 1, attention_size), dtype = dtype, device = device ) self.temp_KV = torch.empty( - (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device + (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device ) - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device) + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: - self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device) + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:, :, :hidden_size] self.attention = torch.empty( - (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device + (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device ) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -335,9 +335,9 @@ def LlamaAttention_fast_forward_inference( (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT) ) - Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0]) - Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0]) - Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1]) + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) @@ -414,30 +414,30 @@ def LlamaAttention_fast_forward_inference( Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows A = torch_matmul( - Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len] + Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len] ) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax( - A, dim=-1, dtype=torch.float32 + A, dim = -1, dtype = torch.float32 ) # .to(A.dtype) - A = torch_matmul(A, Vnn, out=Qn) + A = torch_matmul(A, Vnn, out = Qn) else: if SDPA_HAS_GQA: A = scaled_dot_product_attention( Qn, Knn, Vnn, - attn_mask=attention_mask, - is_causal=is_causal, - enable_gqa=True, + attn_mask = attention_mask, + is_causal = is_causal, + enable_gqa = True, ) else: A = scaled_dot_product_attention( - Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal + Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal ) A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) - A = fast_linear_forward(self.o_proj, A, out=self.temp_O) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) return A, (Kn, Vn) @@ -445,7 +445,7 @@ torch_nn_functional_silu = torch.nn.functional.silu def fast_swiglu_inference( - self, X, temp_gate=None, temp_up=None, gate_multiplier=None, down_multiplier=None + self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None ): # gate = self.gate_proj(X) # up = self.up_proj(X) @@ -453,18 +453,18 @@ def fast_swiglu_inference( # mlp_size = self.config.intermediate_size # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - gate = fast_linear_forward(self.gate_proj, X, out=temp_gate) + gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) if gate_multiplier is not None: gate *= gate_multiplier - up = fast_linear_forward(self.up_proj, X, out=temp_up) + up = fast_linear_forward(self.up_proj, X, out = temp_up) - gate = torch_nn_functional_silu(gate, inplace=True) + gate = torch_nn_functional_silu(gate, inplace = True) gate *= up # X = self.down_proj(gate) - down = fast_linear_forward(self.down_proj, gate, out=up[:, :, :hd]) + down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd]) if down_multiplier is not None: down *= down_multiplier @@ -476,14 +476,14 @@ torch_square = torch.square torch_mean = torch.mean -def fast_rms_layernorm_inference(self, X, XX=None, XX2=None, variance=None): +def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None): old_dtype = X.dtype if XX is None: XX = X.to(torch.float32) - variance = XX.square().mean(-1, keepdim=True) + variance = XX.square().mean(-1, keepdim = True) else: XX.copy_(X) - torch_mean(torch_square(XX, out=XX2), -1, keepdim=True, out=variance) + torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance) variance += self.variance_epsilon XX *= variance.rsqrt_() @@ -496,9 +496,9 @@ def fast_rms_layernorm_inference(self, X, XX=None, XX2=None, variance=None): return X -def fast_rms_layernorm_inference_gemma(self, X, out_weight=None): +def fast_rms_layernorm_inference_gemma(self, X, out_weight = None): XX = X.to(torch.float32) - variance = XX.square().mean(-1, keepdim=True) + variance = XX.square().mean(-1, keepdim = True) variance += self.variance_epsilon XX *= variance.rsqrt_() @@ -513,15 +513,15 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight=None): # Normal layernorm with mean removal -@torch.compile(fullgraph=False, dynamic=True, options=torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def fast_layernorm_compiled(layernorm, X): old_dtype = X.dtype X = X.float() - mean = X.mean(-1, keepdim=True) + mean = X.mean(-1, keepdim = True) Xbar = X - mean X = ( Xbar - * torch.rsqrt(Xbar.square().mean(-1, keepdim=True) + layernorm.variance_epsilon) + * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon) * layernorm.weight.float() ) return X.to(old_dtype) @@ -574,7 +574,7 @@ def LlamaAttention_fast_forward( else: # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb - rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len) + rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) # if position_ids is None: # # Useful for LongRoPE @@ -591,8 +591,8 @@ def LlamaAttention_fast_forward( Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: - K = torch.cat([past_key_value[0], K], dim=2) - V = torch.cat([past_key_value[1], V], dim=2) + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) past_key_value = (K, V) if use_cache else None # Attention module @@ -614,14 +614,14 @@ def LlamaAttention_fast_forward( V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) else: Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - A = xformers_attention(Q, K, V, attn_bias=causal_mask) + A = xformers_attention(Q, K, V, attn_bias = causal_mask) A = A.view(bsz, q_len, n_heads, head_dim) elif HAS_FLASH_ATTENTION and attention_mask is None: Q = Q.transpose(1, 2) K = K.transpose(1, 2) V = V.transpose(1, 2) - A = flash_attn_func(Q, K, V, causal=True) + A = flash_attn_func(Q, K, V, causal = True) else: # when qlen==vlen and attn_mask is None, we should use causal attention Q_len = Q.shape[-2] @@ -638,9 +638,9 @@ def LlamaAttention_fast_forward( Q, K, V, - attn_mask=attention_mask, - is_causal=is_causal, - enable_gqa=n_groups != 1, + attn_mask = attention_mask, + is_causal = is_causal, + enable_gqa = n_groups != 1, ) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2) # .contiguous() @@ -661,7 +661,7 @@ def LlamaAttention_fast_forward( # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention( - Q, K, V, attn_mask=attention_mask, is_causal=is_causal + Q, K, V, attn_mask = attention_mask, is_causal = is_causal ) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() @@ -676,7 +676,7 @@ def LlamaAttention_fast_forward( def LlamaDecoderLayer_fast_forward( self, hidden_states: torch.Tensor, - causal_mask=None, + causal_mask = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -706,15 +706,15 @@ def LlamaDecoderLayer_fast_forward( self.input_layernorm, hidden_states ) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) hidden_states += residual @@ -729,15 +729,15 @@ def LlamaDecoderLayer_fast_forward( residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) hidden_states = residual + hidden_states @@ -839,8 +839,8 @@ def LlamaModel_fast_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, - dtype=torch.int32, - device=f"{DEVICE_TYPE_TORCH}:0", + dtype = torch.int32, + device = f"{DEVICE_TYPE_TORCH}:0", ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) elif position_ids is not None: @@ -873,7 +873,7 @@ def LlamaModel_fast_forward( # Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32 # & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32 normalizer = torch.tensor( - math_sqrt(self.config.hidden_size), dtype=inputs_embeds.dtype + math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype ) if train_embed_tokens: @@ -930,7 +930,7 @@ def LlamaModel_fast_forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window = getattr(self.config, "sliding_window", None), ) # Must NOT convert to bool - weirdly this causes stuff to error out! # if attention_mask is not None: @@ -985,14 +985,14 @@ def LlamaModel_fast_forward( (batch_size, seq_length), inputs_embeds, past_key_values_length, - sliding_window=self.config.sliding_window, + sliding_window = self.config.sliding_window, ) dynamic_GA_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, - sliding_window=None, + sliding_window = None, ) use_static_mask = False @@ -1012,15 +1012,15 @@ def LlamaModel_fast_forward( self.SWA_mask = ( AttentionMaskConverter( - is_causal=True, - sliding_window=self.config.sliding_window, + is_causal = True, + sliding_window = self.config.sliding_window, ) .to_causal_4d( 1, n, n, - dtype=inputs_embeds.dtype, - device=DEVICE_TYPE_TORCH, + dtype = inputs_embeds.dtype, + device = DEVICE_TYPE_TORCH, ) .squeeze(0) .squeeze(0) @@ -1028,14 +1028,14 @@ def LlamaModel_fast_forward( self.GA_mask = ( AttentionMaskConverter( - is_causal=True, + is_causal = True, ) .to_causal_4d( 1, n, n, - dtype=inputs_embeds.dtype, - device=DEVICE_TYPE_TORCH, + dtype = inputs_embeds.dtype, + device = DEVICE_TYPE_TORCH, ) .squeeze(0) .squeeze(0) @@ -1087,8 +1087,8 @@ def LlamaModel_fast_forward( *inputs, past_key_value, output_attentions, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) return custom_forward @@ -1099,22 +1099,22 @@ def LlamaModel_fast_forward( mask, attention_mask, position_ids, - use_reentrant=True, - preserve_rng_state=False, + use_reentrant = True, + preserve_rng_state = False, ) hidden_states = layer_outputs[0] else: layer_outputs = decoder_layer( hidden_states, - causal_mask=mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + causal_mask = mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) hidden_states = layer_outputs[0] @@ -1139,10 +1139,10 @@ def LlamaModel_fast_forward( hidden_states = self.norm(hidden_states) elif IS_FALCON_H1: hidden_states = fast_rms_layernorm( - self.final_layernorm, hidden_states, gemma=IS_GEMMA + self.final_layernorm, hidden_states, gemma = IS_GEMMA ) else: - hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma=IS_GEMMA) + hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1155,17 +1155,17 @@ def LlamaModel_fast_forward( if v is not None ) return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, + last_hidden_state = hidden_states, + past_key_values = next_cache, + hidden_states = all_hidden_states, + attentions = all_self_attns, ) # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 def _LlamaModel_fast_forward_inference( - attention_fast_forward_inference=LlamaAttention_fast_forward_inference, - mlp_fast_forward_inference=fast_swiglu_inference, + attention_fast_forward_inference = LlamaAttention_fast_forward_inference, + mlp_fast_forward_inference = fast_swiglu_inference, ): # This makes the attention and MLP customisable. # Now for models like qwen3 or cohere which use custom attention operations, we can use this function @@ -1174,7 +1174,7 @@ def _LlamaModel_fast_forward_inference( input_ids, past_key_values, position_ids, - attention_mask=None, + attention_mask = None, ): input_ids = input_ids[:, : self.max_seq_length] bsz, q_len = input_ids.shape @@ -1187,17 +1187,17 @@ def _LlamaModel_fast_forward_inference( assert q_len == 1 # Get saved buffers to reduce memory movement residual = torch.empty( - (bsz, q_len, hd), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0" + (bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0" ) _XX = torch.empty( - (2, bsz, q_len, hd), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0" + (2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0" ) XX, XX2 = _XX[0], _XX[1] variance = torch.empty( - (bsz, q_len, 1), dtype=torch.float32, device=f"{DEVICE_TYPE_TORCH}:0" + (bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0" ) temp_mlp = torch.empty( - (2, bsz, 1, mlp_size), dtype=X.dtype, device=f"{DEVICE_TYPE_TORCH}:0" + (2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0" ) temp_gates, temp_ups = ( tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), @@ -1211,7 +1211,7 @@ def _LlamaModel_fast_forward_inference( (bsz, q_len), X, seq_len, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window = getattr(self.config, "sliding_window", None), ) else: attention_mask = None @@ -1227,17 +1227,17 @@ def _LlamaModel_fast_forward_inference( X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, X, - XX=XX, - XX2=XX2, - variance=variance, + XX = XX, + XX2 = XX2, + variance = variance, ) X, present_key_value = attention_fast_forward_inference( decoder_layer.self_attn, - hidden_states=X, - past_key_value=past_key_values[idx], - position_ids=position_ids, - attention_mask=attention_mask, - do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"), + hidden_states = X, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = attention_mask, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), ) X += residual @@ -1245,15 +1245,15 @@ def _LlamaModel_fast_forward_inference( X = fast_rms_layernorm_inference( decoder_layer.post_attention_layernorm, X, - XX=XX, - XX2=XX2, - variance=variance, + XX = XX, + XX2 = XX2, + variance = variance, ) X = mlp_fast_forward_inference( decoder_layer.mlp, X, - temp_gate=temp_gates[device_index], - temp_up=temp_ups[device_index], + temp_gate = temp_gates[device_index], + temp_up = temp_ups[device_index], ) X += residual @@ -1261,16 +1261,16 @@ def _LlamaModel_fast_forward_inference( X = fast_rms_layernorm_inference( self.model.norm, X, - XX=XX, - XX2=XX2, - variance=variance, + XX = XX, + XX2 = XX2, + variance = variance, ) return BaseModelOutputWithPast( - last_hidden_state=X, - past_key_values=next_decoder_cache, - hidden_states=[], - attentions=[], + last_hidden_state = X, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], ) return LlamaModel_fast_forward_inference_custom @@ -1304,8 +1304,8 @@ def CausalLM_fast_forward(fast_forward_inference): self, input_ids, past_key_values, - position_ids=position_ids, - attention_mask=attention_mask, + position_ids = position_ids, + attention_mask = attention_mask, ) else: causal_mask = ( @@ -1328,16 +1328,16 @@ def CausalLM_fast_forward(fast_forward_inference): # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, ) hidden_states = outputs[0] @@ -1360,11 +1360,11 @@ def CausalLM_fast_forward(fast_forward_inference): if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] return CausalLMOutputWithPast( - loss=None, - logits=hidden_states, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = None, + logits = hidden_states, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions = outputs.attentions, ) if bsz == 1 and q_len == 1: @@ -1397,28 +1397,28 @@ def CausalLM_fast_forward(fast_forward_inference): # logit_softcapping = logit_softcapping, # ) loss = unsloth_fused_ce_loss( - trainer=None, - hidden_states=hidden_states, - lm_head_weight=lm_head, - lm_head_bias=None, - labels=labels, - mask=None, - n_items=n_items, - scaling=getattr(self, "accelerator_scaler", None), - target_gb=None, - torch_compile=True, - logit_softcapping=logit_softcapping, + trainer = None, + hidden_states = hidden_states, + lm_head_weight = lm_head, + lm_head_bias = None, + labels = labels, + mask = None, + n_items = n_items, + scaling = getattr(self, "accelerator_scaler", None), + target_gb = None, + torch_compile = True, + logit_softcapping = logit_softcapping, ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output output = CausalLMOutputWithPast( - loss=loss, - logits=EMPTY_LOGITS, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = loss, + logits = EMPTY_LOGITS, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions = outputs.attentions, ) return output pass @@ -1451,11 +1451,11 @@ def CausalLM_fast_forward(fast_forward_inference): if n_items is None: n_items = kwargs.get("n_items", None) loss = fast_cross_entropy_loss( - logits=shift_logits, - labels=shift_labels, - logit_softcapping=logit_softcapping, - logit_scaling=logit_scaling, - n_items=n_items, + logits = shift_logits, + labels = shift_labels, + logit_softcapping = logit_softcapping, + logit_scaling = logit_scaling, + n_items = n_items, ) else: if logit_scaling != 0: @@ -1470,18 +1470,18 @@ def CausalLM_fast_forward(fast_forward_inference): logits = logit_softcapping * logits else: logits *= 1.0 / logit_softcapping - torch.tanh(logits, out=logits) + torch.tanh(logits, out = logits) logits *= logit_softcapping if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = loss, + logits = logits, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions = outputs.attentions, ) return _CausalLM_fast_forward @@ -1490,43 +1490,43 @@ def CausalLM_fast_forward(fast_forward_inference): @torch._disable_dynamo def PeftModel_fast_forward( self, - input_ids=None, - causal_mask=None, - attention_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - task_ids=None, - num_logits_to_keep=0, - logits_to_keep=0, + input_ids = None, + causal_mask = None, + attention_mask = None, + inputs_embeds = None, + labels = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + task_ids = None, + num_logits_to_keep = 0, + logits_to_keep = 0, **kwargs, ): is_classification = "Classification" in str(type(self.base_model.model)) if is_classification: return self.base_model( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids = input_ids, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, **kwargs, ) else: return self.base_model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, - logits_to_keep=logits_to_keep, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, + num_logits_to_keep = num_logits_to_keep, + logits_to_keep = logits_to_keep, **kwargs, ) @@ -1542,11 +1542,11 @@ class LlamaRotaryEmbedding(torch.nn.Module): # The precision of RoPE buffers is not correct, so we cast to int64. def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - config=None, # [TODO] Hack to pass in config - need to remove later + dim = None, + max_position_embeddings = 2048, + base = 10000, + device = None, + config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() if config is not None: @@ -1578,17 +1578,17 @@ class LlamaRotaryEmbedding(torch.nn.Module): # Build here to make `torch.jit.trace` work. for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache( - seq_len=self.current_rope_size, - device=torch.device(device_idx), - dtype=torch.get_default_dtype(), + seq_len = self.current_rope_size, + device = torch.device(device_idx), + dtype = torch.get_default_dtype(), ) # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.sin_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1598,27 +1598,27 @@ class LlamaRotaryEmbedding(torch.nn.Module): inv_freq = 1.0 / ( self.base ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() + torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float() / self.dim ) ) t = torch.arange( - self.current_rope_size, device="cpu", dtype=torch.int64 + self.current_rope_size, device = "cpu", dtype = torch.int64 ).float() freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) - sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + emb = torch.cat((freqs, freqs), dim = -1) + cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True) + sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True) self.multi_gpu_cos_cached[device.index] = cos self.multi_gpu_sin_cached[device.index] = sin return cos, sin - def forward(self, x, position_ids=None, seq_len=None): + def forward(self, x, position_ids = None, seq_len = None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len is not None and seq_len > self.current_rope_size: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype) device_index = x.device.index return ( @@ -1626,7 +1626,7 @@ class LlamaRotaryEmbedding(torch.nn.Module): self.multi_gpu_sin_cached[device_index][:seq_len], ) - def get_cached(self, seq_len=None, device_index=None): + def get_cached(self, seq_len = None, device_index = None): if device_index is None: device_index = get_current_device() return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[ @@ -1640,7 +1640,7 @@ class LlamaRotaryEmbedding(torch.nn.Module): self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache( - self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype + self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype ) @@ -1652,20 +1652,20 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): # The precision of RoPE buffers is not correct, so we cast to int64. def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - config=None, # [TODO] Hack to pass in config - need to remove later + dim = None, + max_position_embeddings = 2048, + base = 10000, + device = None, + scaling_factor = 1.0, + config = None, # [TODO] Hack to pass in config - need to remove later ): self.scaling_factor = scaling_factor super().__init__( - dim=dim, - max_position_embeddings=max_position_embeddings, - base=base, - device=device, - config=config, + dim = dim, + max_position_embeddings = max_position_embeddings, + base = base, + device = device, + config = config, ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1673,20 +1673,20 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): inv_freq = 1.0 / ( self.base ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() + torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float() / self.dim ) ) t = torch.arange( - self.current_rope_size, device="cpu", dtype=torch.int64 + self.current_rope_size, device = "cpu", dtype = torch.int64 ).float() t = t / self.scaling_factor freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) - sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + emb = torch.cat((freqs, freqs), dim = -1) + cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True) + sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True) self.multi_gpu_cos_cached[device.index] = cos self.multi_gpu_sin_cached[device.index] = sin return cos, sin @@ -1697,11 +1697,11 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): class LlamaExtendedRotaryEmbedding(torch.nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - config=None, # [TODO] Hack to pass in config - need to remove later + dim = None, + max_position_embeddings = 2048, + base = 10000, + device = None, + config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() if config is not None: @@ -1728,27 +1728,27 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module): inv_freq = 1.0 / ( self.base ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() + torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float() / self.dim ) ) inv_freq = self.apply_scaling(inv_freq) - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("inv_freq", inv_freq, persistent = False) # Build here to make `torch.jit.trace` work. for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache( - seq_len=self.current_rope_size, - device=torch.device(device_idx), - dtype=torch.get_default_dtype(), + seq_len = self.current_rope_size, + device = torch.device(device_idx), + dtype = torch.get_default_dtype(), ) # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.sin_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1757,29 +1757,29 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module): self.current_rope_size = seq_len t = torch.arange( - self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64 + self.current_rope_size, device = self.inv_freq.device, dtype = torch.int64 ).float() freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) - sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + emb = torch.cat((freqs, freqs), dim = -1) + cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True) + sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True) self.multi_gpu_cos_cached[device.index] = cos self.multi_gpu_sin_cached[device.index] = sin return cos, sin - def forward(self, x, position_ids=None, seq_len=None): + def forward(self, x, position_ids = None, seq_len = None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len is not None and seq_len > self.current_rope_size: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype) device_index = x.device.index return ( self.multi_gpu_cos_cached[device_index][:seq_len], self.multi_gpu_sin_cached[device_index][:seq_len], ) - def get_cached(self, seq_len=None, device_index=None): + def get_cached(self, seq_len = None, device_index = None): if device_index is None: device_index = get_current_device() return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[ @@ -1793,7 +1793,7 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module): self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache( - self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype + self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype ) # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41 @@ -1819,21 +1819,21 @@ class LlamaExtendedRotaryEmbedding(torch.nn.Module): high_freq_factor - low_freq_factor ) new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device) class LongRopeRotaryEmbedding(torch.nn.Module): # For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py def __init__( self, - dim=None, - max_position_embeddings=131072, - original_max_position_embeddings=4096, - base=10000, - short_factor=None, - long_factor=None, - device=None, - config=None, # [TODO] Hack to pass in config - need to remove later + dim = None, + max_position_embeddings = 131072, + original_max_position_embeddings = 4096, + base = 10000, + short_factor = None, + long_factor = None, + device = None, + config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() assert short_factor is not None @@ -1868,11 +1868,11 @@ class LongRopeRotaryEmbedding(torch.nn.Module): # Long RoPE similar to RoPE except short sequences have 1 cos / sin # and long sequences have another cos / sin inv_freq_shape = ( - torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() + torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float() / self.dim ) - short_factor = torch.tensor(short_factor, device="cpu", dtype=torch.float32) - long_factor = torch.tensor(long_factor, device="cpu", dtype=torch.float32) + short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32) + long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32) short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape) long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape) @@ -1887,43 +1887,43 @@ class LongRopeRotaryEmbedding(torch.nn.Module): self.scaling_factor = scaling_factor # Short and long inv_freq - self.register_buffer("short_inv_freq", short_inv_freq, persistent=False) - self.register_buffer("long_inv_freq", long_inv_freq, persistent=False) + self.register_buffer("short_inv_freq", short_inv_freq, persistent = False) + self.register_buffer("long_inv_freq", long_inv_freq, persistent = False) # Build here to make `torch.jit.trace` work. # Initialize short sequences cache for all devices dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange( original_max_position_embeddings, - device=self.short_inv_freq.device, - dtype=torch.int64, + device = self.short_inv_freq.device, + dtype = torch.int64, ).float() freqs = torch.outer(t, self.short_inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) + emb = torch.cat((freqs, freqs), dim = -1) for device_idx in range(DEVICE_COUNT): device_obj = torch.device(device_idx) cos_cached = (emb.cos() * self.scaling_factor).to( - dtype=dtype, device=device_obj, non_blocking=True + dtype = dtype, device = device_obj, non_blocking = True ) sin_cached = (emb.sin() * self.scaling_factor).to( - dtype=dtype, device=device_obj, non_blocking=True + dtype = dtype, device = device_obj, non_blocking = True ) self.multi_gpu_short_cos_cached[device_idx] = cos_cached self.multi_gpu_short_sin_cached[device_idx] = sin_cached # dummy so that patch_utils doesn't fail for now self.short_cos_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.short_sin_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.long_cos_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) self.long_sin_cached = torch.empty( - 1, device=get_current_device(), dtype=torch.get_default_dtype() + 1, device = get_current_device(), dtype = torch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1932,25 +1932,25 @@ class LongRopeRotaryEmbedding(torch.nn.Module): self.current_rope_size = seq_len t = torch.arange( - self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64 + self.current_rope_size, device = self.long_inv_freq.device, dtype = torch.int64 ).float() # Long sequences freqs = torch.outer(t, self.long_inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) + emb = torch.cat((freqs, freqs), dim = -1) cos_cached = (emb.cos() * self.scaling_factor).to( - dtype=dtype, device=device, non_blocking=True + dtype = dtype, device = device, non_blocking = True ) sin_cached = (emb.sin() * self.scaling_factor).to( - dtype=dtype, device=device, non_blocking=True + dtype = dtype, device = device, non_blocking = True ) self.multi_gpu_long_cos_cached[device.index] = cos_cached self.multi_gpu_long_sin_cached[device.index] = sin_cached return cos_cached, sin_cached - def forward(self, x, position_ids=None, seq_len=None): + def forward(self, x, position_ids = None, seq_len = None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len is not None and seq_len > self.current_rope_size: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype) device_index = x.device.index @@ -1965,7 +1965,7 @@ class LongRopeRotaryEmbedding(torch.nn.Module): self.multi_gpu_long_sin_cached[device_index][:seq_len], ) - def get_cached(self, seq_len=None, device_index=None): + def get_cached(self, seq_len = None, device_index = None): if device_index is None: device_index = get_current_device() if seq_len is not None and seq_len < self.original_max_position_embeddings: @@ -1983,7 +1983,7 @@ class LongRopeRotaryEmbedding(torch.nn.Module): self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache( - self.current_rope_size, device=torch.device(device_idx), dtype=x.dtype + self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype ) @@ -2041,7 +2041,7 @@ def unsloth_fast_generate( # Mixed precision autocast with ( _get_inference_mode_context_manager(self), - torch.autocast(device_type=DEVICE_TYPE_TORCH, dtype=dtype), + torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype), ): output = self._old_generate(*args, **kwargs) @@ -2065,12 +2065,12 @@ class FastLlamaModel: @staticmethod def pre_patch(): init_name, function = patch_llama_rope_scaling( - model_name="llama", - rope_module=LlamaRotaryEmbedding, - scaled_rope_module=LlamaLinearScalingRotaryEmbedding, - extended_rope_module=LlamaExtendedRotaryEmbedding, - attention_module=LlamaAttention, - longrope_module=LongRopeRotaryEmbedding, + model_name = "llama", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + extended_rope_module = LlamaExtendedRotaryEmbedding, + attention_module = LlamaAttention, + longrope_module = LongRopeRotaryEmbedding, ) if init_name is not None: exec(function, globals()) @@ -2103,27 +2103,27 @@ class FastLlamaModel: @staticmethod def from_pretrained( - model_name="unsloth/llama-3-8b-bnb-4bit", - max_seq_length=None, - dtype=None, - load_in_4bit=True, - token=None, - device_map="sequential", - rope_scaling=None, - fix_tokenizer=True, - model_patcher=None, - tokenizer_name=None, - trust_remote_code=False, - revision=None, - fast_inference=False, # uses vLLM - gpu_memory_utilization=0.5, - float8_kv_cache=False, - random_state=3407, - max_lora_rank=16, - disable_log_stats=False, - unsloth_vllm_standby=False, - num_labels=None, - qat_scheme=None, + model_name = "unsloth/llama-3-8b-bnb-4bit", + max_seq_length = None, + dtype = None, + load_in_4bit = True, + token = None, + device_map = "sequential", + rope_scaling = None, + fix_tokenizer = True, + model_patcher = None, + tokenizer_name = None, + trust_remote_code = False, + revision = None, + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = False, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, + unsloth_vllm_standby = False, + num_labels = None, + qat_scheme = None, **kwargs, ): os.environ["UNSLOTH_USE_NEW_MODEL"] = "0" @@ -2249,8 +2249,8 @@ class FastLlamaModel: # RoPE Scaling model_config = AutoConfig.from_pretrained( model_name, - token=token, - attn_implementation="sdpa", + token = token, + attn_implementation = "sdpa", ) model_config.model_name = model_name model_max_seq_length = model_config.max_position_embeddings @@ -2263,7 +2263,7 @@ class FastLlamaModel: has_rope_scaling = False try: - with open(inspect.getfile(model_function), "r", encoding="utf-8") as file: + with open(inspect.getfile(model_function), "r", encoding = "utf-8") as file: has_rope_scaling = "self.config.rope_scaling" in file.read() except: pass @@ -2310,11 +2310,11 @@ class FastLlamaModel: # we cannot quantize out_proj layer due to mamba kernels: https://github.com/tiiuae/Falcon-H1/issues/13#issuecomment-2918671274 llm_int8_skip_modules.append("out_proj") bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=dtype, - llm_int8_skip_modules=llm_int8_skip_modules, + load_in_4bit = True, + bnb_4bit_use_double_quant = True, + bnb_4bit_quant_type = "nf4", + bnb_4bit_compute_dtype = dtype, + llm_int8_skip_modules = llm_int8_skip_modules, ) # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12 @@ -2332,26 +2332,26 @@ class FastLlamaModel: if num_labels is not None: model = AutoModelForSequenceClassification.from_pretrained( model_name, - device_map=device_map, + device_map = device_map, # torch_dtype = dtype, # transformers changed torch_dtype to dtype - num_labels=num_labels, + num_labels = num_labels, # quantization_config = bnb_config, - token=token, - max_position_embeddings=max_position_embeddings, - trust_remote_code=trust_remote_code, - attn_implementation="eager", + token = token, + max_position_embeddings = max_position_embeddings, + trust_remote_code = trust_remote_code, + attn_implementation = "eager", **kwargs, ) elif not fast_inference: model = AutoModelForCausalLM.from_pretrained( model_name, - device_map=device_map, + device_map = device_map, # torch_dtype = dtype, # transformers changed torch_dtype to dtype # quantization_config = bnb_config, - token=token, - max_position_embeddings=max_position_embeddings, - trust_remote_code=trust_remote_code, - attn_implementation="eager", + token = token, + max_position_embeddings = max_position_embeddings, + trust_remote_code = trust_remote_code, + attn_implementation = "eager", **kwargs, ) model.fast_generate = model.generate @@ -2366,17 +2366,17 @@ class FastLlamaModel: allowed_args = inspect.getfullargspec(load_vllm).args load_vllm_kwargs = dict( - model_name=model_name, - config=model_config, - gpu_memory_utilization=gpu_memory_utilization, - max_seq_length=max_seq_length, - dtype=dtype, - float8_kv_cache=float8_kv_cache, - enable_lora=True, - max_lora_rank=max_lora_rank, - disable_log_stats=disable_log_stats, - use_bitsandbytes=load_in_4bit, - unsloth_vllm_standby=unsloth_vllm_standby, + model_name = model_name, + config = model_config, + gpu_memory_utilization = gpu_memory_utilization, + max_seq_length = max_seq_length, + dtype = dtype, + float8_kv_cache = float8_kv_cache, + enable_lora = True, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + use_bitsandbytes = load_in_4bit, + unsloth_vllm_standby = unsloth_vllm_standby, ) for allowed_arg in allowed_args: if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: @@ -2387,7 +2387,7 @@ class FastLlamaModel: llm = load_vllm(**load_vllm_kwargs) # Convert to HF format - _, quant_state_dict = get_vllm_state_dict(llm, config=model_config) + _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) model = convert_vllm_to_huggingface( quant_state_dict, model_config, dtype, bnb_config ) @@ -2403,12 +2403,12 @@ class FastLlamaModel: # Counteract saved tokenizers tokenizer_name = model_name if tokenizer_name is None else tokenizer_name tokenizer = load_correct_tokenizer( - tokenizer_name=tokenizer_name, - model_max_length=max_position_embeddings, - padding_side="right", - token=token, - trust_remote_code=trust_remote_code, - fix_tokenizer=fix_tokenizer, + tokenizer_name = tokenizer_name, + model_max_length = max_position_embeddings, + padding_side = "right", + token = token, + trust_remote_code = trust_remote_code, + fix_tokenizer = fix_tokenizer, ) model, tokenizer = patch_tokenizer(model, tokenizer) @@ -2489,7 +2489,7 @@ class FastLlamaModel: front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0) inner_training_loop = re.sub( - r"^" + front_spaces, "", inner_training_loop, flags=re.MULTILINE + r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE ) inner_training_loop = inner_training_loop.replace( "train_dataloader = tpu_spmd_dataloader(train_dataloader)", @@ -2521,12 +2521,12 @@ class FastLlamaModel: # We check the tokenizer first for errors if fix_tokenizer: tokenizer = check_tokenizer( - model=model, - tokenizer=tokenizer, - model_name=model_name, - model_max_length=max_position_embeddings, - padding_side="right", - token=token, + model = model, + tokenizer = tokenizer, + model_name = model_name, + model_max_length = max_position_embeddings, + padding_side = "right", + token = token, ) patch_saving_functions(tokenizer) @@ -2597,15 +2597,15 @@ class FastLlamaModel: @staticmethod def post_patch(model, tokenizer): model, tokenizer = patch_model_and_tokenizer( - model, tokenizer, downcast_rope=True + model, tokenizer, downcast_rope = True ) return model, tokenizer @staticmethod def get_peft_model( model, - r=16, - target_modules=[ + r = 16, + target_modules = [ "q_proj", "k_proj", "v_proj", @@ -2614,20 +2614,20 @@ class FastLlamaModel: "up_proj", "down_proj", ], - lora_alpha=16, - lora_dropout=0.0, - bias="none", - layers_to_transform=None, - layers_pattern=None, - use_gradient_checkpointing="unsloth", - random_state=3407, - max_seq_length=2048, # not used anymore - use_rslora=False, - modules_to_save=None, - init_lora_weights=True, - loftq_config={}, - temporary_location="_unsloth_temporary_saved_buffers", - qat_scheme=None, + lora_alpha = 16, + lora_dropout = 0.0, + bias = "none", + layers_to_transform = None, + layers_pattern = None, + use_gradient_checkpointing = "unsloth", + random_state = 3407, + max_seq_length = 2048, # not used anymore + use_rslora = False, + modules_to_save = None, + init_lora_weights = True, + loftq_config = {}, + temporary_location = "_unsloth_temporary_saved_buffers", + qat_scheme = None, **kwargs, ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": @@ -2641,22 +2641,22 @@ class FastLlamaModel: if peft_arg not in kwargs: kwargs[peft_arg] = flag return FastBaseModel.get_peft_model( - model=model, - r=r, - target_modules=target_modules, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - bias=bias, - layers_to_transform=layers_to_transform, - layers_pattern=layers_pattern, - use_gradient_checkpointing=use_gradient_checkpointing, - random_state=random_state, - max_seq_length=max_seq_length, - use_rslora=use_rslora, - modules_to_save=modules_to_save, - init_lora_weights=init_lora_weights, - loftq_config=loftq_config, - temporary_location=temporary_location, + model = model, + r = r, + target_modules = target_modules, + lora_alpha = lora_alpha, + lora_dropout = lora_dropout, + bias = bias, + layers_to_transform = layers_to_transform, + layers_pattern = layers_pattern, + use_gradient_checkpointing = use_gradient_checkpointing, + random_state = random_state, + max_seq_length = max_seq_length, + use_rslora = use_rslora, + modules_to_save = modules_to_save, + init_lora_weights = init_lora_weights, + loftq_config = loftq_config, + temporary_location = temporary_location, **kwargs, ) if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1": @@ -2668,7 +2668,7 @@ class FastLlamaModel: if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing( - dtype=model.get_input_embeddings().weight.dtype + dtype = model.get_input_embeddings().weight.dtype ) if type(r) is not int: @@ -2744,7 +2744,7 @@ class FastLlamaModel: new_dtype = torch.float32 model.get_input_embeddings().modules_to_save.default.to( - device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True + device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True ) model.get_input_embeddings().modules_to_save.default.requires_grad_( True @@ -2752,7 +2752,7 @@ class FastLlamaModel: # [TODO] Move old embed_tokens to CPU - should be disk! model.get_input_embeddings().original_module.to( - device="cpu", non_blocking=True + device = "cpu", non_blocking = True ) model.get_input_embeddings().original_module.requires_grad_(False) @@ -2766,7 +2766,7 @@ class FastLlamaModel: new_dtype = torch.float32 model.get_output_embeddings().modules_to_save.default.to( - device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True + device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True ) model.get_output_embeddings().modules_to_save.default.requires_grad_( True @@ -2774,7 +2774,7 @@ class FastLlamaModel: # [TODO] Move old lm_head to CPU - should be disk! model.get_output_embeddings().original_module.to( - device="cpu", non_blocking=True + device = "cpu", non_blocking = True ) model.get_output_embeddings().original_module.requires_grad_(False) @@ -2829,7 +2829,7 @@ class FastLlamaModel: "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n" "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`." ) - loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1) + loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1) if hasattr(model.config, "quantization_config"): raise ValueError( @@ -2969,17 +2969,17 @@ class FastLlamaModel: is_classification = "Classification" in str(type(model)) arguments = dict( - r=r, - lora_alpha=lora_alpha, - target_modules=final_modules, - lora_dropout=lora_dropout, - bias=bias, - task_type=TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS, - layers_to_transform=layers_to_transform, - init_lora_weights=init_lora_weights, - loftq_config=loftq_config, - use_rslora=use_rslora, - modules_to_save=modules_to_save, + r = r, + lora_alpha = lora_alpha, + target_modules = final_modules, + lora_dropout = lora_dropout, + bias = bias, + task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS, + layers_to_transform = layers_to_transform, + init_lora_weights = init_lora_weights, + loftq_config = loftq_config, + use_rslora = use_rslora, + modules_to_save = modules_to_save, **kwargs, ) if not SUPPORTS_LOFTQ: @@ -3042,7 +3042,7 @@ class FastLlamaModel: new_dtype = torch.float32 model.get_input_embeddings().modules_to_save.default.to( - device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True + device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True ) model.get_input_embeddings().modules_to_save.default.requires_grad_(True) @@ -3059,7 +3059,7 @@ class FastLlamaModel: new_dtype = torch.float32 model.get_output_embeddings().modules_to_save.default.to( - device=DEVICE_TYPE_TORCH, dtype=new_dtype, non_blocking=True + device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True ) model.get_output_embeddings().modules_to_save.default.requires_grad_(True) @@ -3096,12 +3096,12 @@ class FastLlamaModel: @staticmethod def patch_peft_model( model, - use_gradient_checkpointing="unsloth", + use_gradient_checkpointing = "unsloth", ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": return FastBaseModel.patch_peft_model( - model=model, - use_gradient_checkpointing=use_gradient_checkpointing, + model = model, + use_gradient_checkpointing = use_gradient_checkpointing, ) if not isinstance(model, PeftModelForCausalLM) and not isinstance( model, PeftModelForSequenceClassification @@ -3138,8 +3138,8 @@ class FastLlamaModel: model = prepare_model_for_kbit_training( model, - use_gradient_checkpointing=use_gradient_checkpointing, - use_reentrant=True, + use_gradient_checkpointing = use_gradient_checkpointing, + use_reentrant = True, ) # Fix up config for transformers uploading PEFT @@ -3192,7 +3192,7 @@ class FastLlamaModel: # We also do not inplace edit QKV for Cohere! _apply_lora_mlp = ( - functools.partial(apply_lora_mlp, inplace=False) + functools.partial(apply_lora_mlp, inplace = False) if model_type == "cohere" else apply_lora_mlp ) @@ -3375,7 +3375,7 @@ class FastLlamaModel: return model @staticmethod - def for_training(model, use_gradient_checkpointing=True): + def for_training(model, use_gradient_checkpointing = True): if not hasattr(model, "parameters"): raise TypeError( "Unsloth: I think you're passing a tokenizer, not the model to for_training!" @@ -3424,4 +3424,4 @@ class FastLlamaModel: from .rl import PatchFastRL -PatchFastRL(FastLanguageModel=FastLlamaModel) +PatchFastRL(FastLanguageModel = FastLlamaModel) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 5389d84d2..e1c13315f 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -120,33 +120,33 @@ DISABLE_SDPA_MODEL_NAMES = [ class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( - model_name="unsloth/Llama-3.2-1B-Instruct", - max_seq_length=2048, - dtype=None, - load_in_4bit=True, # 4bit QLoRA - load_in_8bit=False, # 8bit LoRA - load_in_16bit=False, # 16bit LoRA - full_finetuning=False, - token=None, - device_map="sequential", - rope_scaling=None, - fix_tokenizer=True, - trust_remote_code=False, - use_gradient_checkpointing="unsloth", - resize_model_vocab=None, - revision=None, - use_exact_model_name=False, - offload_embedding=False, - float32_mixed_precision=None, # Forces float32 mixed precision - fast_inference=False, # uses vLLM - gpu_memory_utilization=0.5, - float8_kv_cache=False, - random_state=3407, - max_lora_rank=64, - disable_log_stats=True, - qat_scheme=None, - load_in_fp8=False, # fp8 LoRA (True, False, 'block') - unsloth_tiled_mlp=False, + model_name = "unsloth/Llama-3.2-1B-Instruct", + max_seq_length = 2048, + dtype = None, + load_in_4bit = True, # 4bit QLoRA + load_in_8bit = False, # 8bit LoRA + load_in_16bit = False, # 16bit LoRA + full_finetuning = False, + token = None, + device_map = "sequential", + rope_scaling = None, + fix_tokenizer = True, + trust_remote_code = False, + use_gradient_checkpointing = "unsloth", + resize_model_vocab = None, + revision = None, + use_exact_model_name = False, + offload_embedding = False, + float32_mixed_precision = None, # Forces float32 mixed precision + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = False, + random_state = 3407, + max_lora_rank = 64, + disable_log_stats = True, + qat_scheme = None, + load_in_fp8 = False, # fp8 LoRA (True, False, 'block') + unsloth_tiled_mlp = False, *args, **kwargs, ): @@ -157,40 +157,40 @@ class FastLanguageModel(FastLlamaModel): try: from huggingface_hub import login - login(token=token) + login(token = token) except: pass if load_in_8bit or full_finetuning or qat_scheme is not None: return FastModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=dtype, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - load_in_16bit=load_in_16bit, - full_finetuning=full_finetuning, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, # [TODO] No effect - fix_tokenizer=fix_tokenizer, # [TODO] No effect - trust_remote_code=trust_remote_code, - use_gradient_checkpointing=use_gradient_checkpointing, - resize_model_vocab=resize_model_vocab, # [TODO] No effect - revision=revision, - return_logits=False, # Return logits - fullgraph=True, # No graph breaks - use_exact_model_name=use_exact_model_name, - offload_embedding=offload_embedding, - float32_mixed_precision=float32_mixed_precision, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + load_in_16bit = load_in_16bit, + full_finetuning = full_finetuning, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, # [TODO] No effect + fix_tokenizer = fix_tokenizer, # [TODO] No effect + trust_remote_code = trust_remote_code, + use_gradient_checkpointing = use_gradient_checkpointing, + resize_model_vocab = resize_model_vocab, # [TODO] No effect + revision = revision, + return_logits = False, # Return logits + fullgraph = True, # No graph breaks + use_exact_model_name = use_exact_model_name, + offload_embedding = offload_embedding, + float32_mixed_precision = float32_mixed_precision, # Pass vLLM/inference parameters - fast_inference=fast_inference, - gpu_memory_utilization=gpu_memory_utilization, - float8_kv_cache=float8_kv_cache, - random_state=random_state, - max_lora_rank=max_lora_rank, - disable_log_stats=disable_log_stats, - qat_scheme=qat_scheme, - load_in_fp8=load_in_fp8, + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + qat_scheme = qat_scheme, + load_in_fp8 = load_in_fp8, *args, **kwargs, ) @@ -231,7 +231,7 @@ class FastLanguageModel(FastLlamaModel): fp8_mode = None if not use_exact_model_name: new_model_name = get_model_name( - model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8 + model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8 ) if new_model_name is None and load_in_fp8 != False: fp8_mode = _get_fp8_mode_and_check_settings( @@ -284,9 +284,9 @@ class FastLanguageModel(FastLlamaModel): try: model_config = AutoConfig.from_pretrained( model_name, - token=token, - revision=revision, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, ) is_model = True except Exception as error: @@ -300,9 +300,9 @@ class FastLanguageModel(FastLlamaModel): try: peft_config = PeftConfig.from_pretrained( model_name, - token=token, - revision=revision, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, ) is_peft = True except Exception as error: @@ -345,7 +345,7 @@ class FastLanguageModel(FastLlamaModel): both_exist = exist_adapter_config and exist_config else: # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows. - files = HfFileSystem(token=token).glob(f"{model_name}/*.json") + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = list(os.path.split(x)[-1] for x in files) if ( sum(x == "adapter_config.json" or x == "config.json" for x in files) @@ -393,8 +393,8 @@ class FastLanguageModel(FastLlamaModel): model_config = AutoConfig.from_pretrained( model_name, - token=token, - trust_remote_code=trust_remote_code, + token = token, + trust_remote_code = trust_remote_code, ) if not was_disabled: @@ -484,42 +484,42 @@ class FastLanguageModel(FastLlamaModel): # dispatch_model = FastGraniteModel else: return FastModel.from_pretrained( - model_name=old_model_name, - max_seq_length=max_seq_length, - dtype=dtype, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - load_in_16bit=load_in_16bit, - full_finetuning=full_finetuning, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, # [TODO] No effect - fix_tokenizer=fix_tokenizer, # [TODO] No effect - trust_remote_code=trust_remote_code, - use_gradient_checkpointing=use_gradient_checkpointing, - resize_model_vocab=resize_model_vocab, # [TODO] No effect - revision=revision, - return_logits=False, # Return logits - fullgraph=True, # No graph breaks - use_exact_model_name=use_exact_model_name, - offload_embedding=offload_embedding, - float32_mixed_precision=float32_mixed_precision, + model_name = old_model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + load_in_16bit = load_in_16bit, + full_finetuning = full_finetuning, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, # [TODO] No effect + fix_tokenizer = fix_tokenizer, # [TODO] No effect + trust_remote_code = trust_remote_code, + use_gradient_checkpointing = use_gradient_checkpointing, + resize_model_vocab = resize_model_vocab, # [TODO] No effect + revision = revision, + return_logits = False, # Return logits + fullgraph = True, # No graph breaks + use_exact_model_name = use_exact_model_name, + offload_embedding = offload_embedding, + float32_mixed_precision = float32_mixed_precision, # Pass vLLM/inference parameters - fast_inference=fast_inference, - gpu_memory_utilization=gpu_memory_utilization, - float8_kv_cache=float8_kv_cache, - random_state=random_state, - max_lora_rank=max_lora_rank, - disable_log_stats=disable_log_stats, - qat_scheme=qat_scheme, - load_in_fp8=load_in_fp8, - unsloth_tiled_mlp=unsloth_tiled_mlp, + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + qat_scheme = qat_scheme, + load_in_fp8 = load_in_fp8, + unsloth_tiled_mlp = unsloth_tiled_mlp, *args, **kwargs, ) if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing(dtype=dtype) + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) # Check if this is local model since the tokenizer gets overwritten if ( @@ -535,24 +535,24 @@ class FastLanguageModel(FastLlamaModel): fast_inference, model_name = fast_inference_setup(model_name, model_config) model, tokenizer = dispatch_model.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=_get_dtype(dtype), - load_in_4bit=load_in_4bit, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, - fix_tokenizer=fix_tokenizer, - model_patcher=dispatch_model, - tokenizer_name=tokenizer_name, - trust_remote_code=trust_remote_code, - revision=revision if not is_peft else None, - fast_inference=fast_inference, - gpu_memory_utilization=gpu_memory_utilization, - float8_kv_cache=float8_kv_cache, - random_state=random_state, - max_lora_rank=max_lora_rank, - disable_log_stats=disable_log_stats, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = _get_dtype(dtype), + load_in_4bit = load_in_4bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, + fix_tokenizer = fix_tokenizer, + model_patcher = dispatch_model, + tokenizer_name = tokenizer_name, + trust_remote_code = trust_remote_code, + revision = revision if not is_peft else None, + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, *args, **kwargs, ) @@ -601,10 +601,10 @@ class FastLanguageModel(FastLlamaModel): model = PeftModel.from_pretrained( model, old_model_name, - token=token, - revision=revision, - is_trainable=True, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + is_trainable = True, + trust_remote_code = trust_remote_code, ) # Patch it as well! model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing) @@ -615,7 +615,7 @@ class FastLanguageModel(FastLlamaModel): "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0" ) if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp: - patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice) + patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice) return model, tokenizer @@ -645,40 +645,40 @@ class FastModel(FastBaseModel): @staticmethod def from_pretrained( - model_name="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", - max_seq_length=2048, - dtype=None, - load_in_4bit=True, # 4bit QLoRA - load_in_8bit=False, # 8bit LoRA - load_in_16bit=False, # 16bit LoRA - full_finetuning=False, - token=None, - device_map="sequential", - rope_scaling=None, # [TODO] No effect - fix_tokenizer=True, # [TODO] No effect - trust_remote_code=False, - use_gradient_checkpointing="unsloth", - resize_model_vocab=None, # [TODO] No effect - revision=None, - return_logits=False, # Return logits - fullgraph=True, # No graph breaks - use_exact_model_name=False, - auto_model=None, - whisper_language=None, - whisper_task=None, - unsloth_force_compile=False, - offload_embedding=False, - float32_mixed_precision=None, # Forces float32 mixed precision + model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", + max_seq_length = 2048, + dtype = None, + load_in_4bit = True, # 4bit QLoRA + load_in_8bit = False, # 8bit LoRA + load_in_16bit = False, # 16bit LoRA + full_finetuning = False, + token = None, + device_map = "sequential", + rope_scaling = None, # [TODO] No effect + fix_tokenizer = True, # [TODO] No effect + trust_remote_code = False, + use_gradient_checkpointing = "unsloth", + resize_model_vocab = None, # [TODO] No effect + revision = None, + return_logits = False, # Return logits + fullgraph = True, # No graph breaks + use_exact_model_name = False, + auto_model = None, + whisper_language = None, + whisper_task = None, + unsloth_force_compile = False, + offload_embedding = False, + float32_mixed_precision = None, # Forces float32 mixed precision # Add the missing vLLM/inference parameters - fast_inference=False, # uses vLLM - gpu_memory_utilization=0.5, - float8_kv_cache=False, - random_state=3407, - max_lora_rank=64, - disable_log_stats=True, - qat_scheme=None, - load_in_fp8=False, # fp8 LoRA (True, False, 'block') - unsloth_tiled_mlp=False, + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = False, + random_state = 3407, + max_lora_rank = 64, + disable_log_stats = True, + qat_scheme = None, + load_in_fp8 = False, # fp8 LoRA (True, False, 'block') + unsloth_tiled_mlp = False, *args, **kwargs, ): @@ -689,7 +689,7 @@ class FastModel(FastBaseModel): try: from huggingface_hub import login - login(token=token) + login(token = token) except: pass if whisper_language is not None: @@ -765,7 +765,7 @@ class FastModel(FastBaseModel): fp8_mode = None if not use_exact_model_name: new_model_name = get_model_name( - model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8 + model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8 ) if new_model_name is None and load_in_fp8 != False: fp8_mode = _get_fp8_mode_and_check_settings( @@ -819,9 +819,9 @@ class FastModel(FastBaseModel): try: model_config = AutoConfig.from_pretrained( model_name, - token=token, - revision=revision, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, ) is_model = True except Exception as error: @@ -835,9 +835,9 @@ class FastModel(FastBaseModel): try: peft_config = PeftConfig.from_pretrained( model_name, - token=token, - revision=revision, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, ) is_peft = True except Exception as error: @@ -1011,7 +1011,7 @@ class FastModel(FastBaseModel): exist_config = os.path.exists(os.path.join(model_name, "config.json")) both_exist = exist_adapter_config and exist_config else: - files = HfFileSystem(token=token).glob(f"{model_name}/*.json") + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = list(os.path.split(x)[-1] for x in files) if ( sum(x == "adapter_config.json" or x == "config.json" for x in files) @@ -1059,8 +1059,8 @@ class FastModel(FastBaseModel): model_config = AutoConfig.from_pretrained( model_name, - token=token, - trust_remote_code=trust_remote_code, + token = token, + trust_remote_code = trust_remote_code, ) if not was_disabled: @@ -1092,40 +1092,40 @@ class FastModel(FastBaseModel): break # Patch gradient checkpointing if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing(dtype=dtype) + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) with redirector: - patch_loss_functions(torch_compile=False) + patch_loss_functions(torch_compile = False) model_types, supports_sdpa = unsloth_compile_transformers( - dtype=dtype, - model_name=model_name, - model_types=model_types, - token=token, - sdpa_dynamic_mask=True, - sdpa_bool_masks=True, - sdpa_gqa_replace=True, - sdpa_dynamic_compile=True, - compile_attention=True, - disable_causal_masks=True, - compile_torch_modules=True, - compile_custom_modules=True, - compile_function_calls=True, - fuse_lm_head=True, - gradient_checkpointing=True, - manual_replacements=True, - fast_lora_forwards=True, - fast_residual_stream=False, - accurate_accumulation=True, - epilogue_fusion=True, - max_autotune=False, - shape_padding=True, - cudagraphs=False, - debug=False, - fullgraph=fullgraph, - import_from_cache=False, - disable=False, - return_logits=return_logits, - trust_remote_code=trust_remote_code, - unsloth_force_compile=unsloth_force_compile, + dtype = dtype, + model_name = model_name, + model_types = model_types, + token = token, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + trust_remote_code = trust_remote_code, + unsloth_force_compile = unsloth_force_compile, ) # Fix SDPA issues for model_type in DISABLE_SDPA_MODEL_NAMES: @@ -1152,34 +1152,34 @@ class FastModel(FastBaseModel): auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM model, tokenizer = FastBaseModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=_get_dtype(dtype), - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - load_in_16bit=load_in_16bit, - full_finetuning=full_finetuning, - token=token, - device_map=device_map, - trust_remote_code=trust_remote_code, - revision=revision if not is_peft else None, - model_types=model_types, - tokenizer_name=tokenizer_name, - auto_model=auto_model, - use_gradient_checkpointing=use_gradient_checkpointing, - supports_sdpa=supports_sdpa, - whisper_language=whisper_language, - whisper_task=whisper_task, - auto_config=model_config, - offload_embedding=offload_embedding, - float32_mixed_precision=float32_mixed_precision, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = _get_dtype(dtype), + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + load_in_16bit = load_in_16bit, + full_finetuning = full_finetuning, + token = token, + device_map = device_map, + trust_remote_code = trust_remote_code, + revision = revision if not is_peft else None, + model_types = model_types, + tokenizer_name = tokenizer_name, + auto_model = auto_model, + use_gradient_checkpointing = use_gradient_checkpointing, + supports_sdpa = supports_sdpa, + whisper_language = whisper_language, + whisper_task = whisper_task, + auto_config = model_config, + offload_embedding = offload_embedding, + float32_mixed_precision = float32_mixed_precision, # Pass vLLM/inference parameters - fast_inference=fast_inference, - gpu_memory_utilization=gpu_memory_utilization, - float8_kv_cache=float8_kv_cache, - random_state=random_state, - max_lora_rank=max_lora_rank, - disable_log_stats=disable_log_stats, + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, *args, **kwargs, ) @@ -1228,14 +1228,14 @@ class FastModel(FastBaseModel): model = PeftModel.from_pretrained( model, old_model_name, - token=token, - revision=revision, - is_trainable=True, - trust_remote_code=trust_remote_code, + token = token, + revision = revision, + is_trainable = True, + trust_remote_code = trust_remote_code, ) # Patch it as well! model = FastBaseModel.post_patch_model( - model, use_gradient_checkpointing, trust_remote_code=trust_remote_code + model, use_gradient_checkpointing, trust_remote_code = trust_remote_code ) # Apply QAT if specified @@ -1249,7 +1249,7 @@ class FastModel(FastBaseModel): "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0" ) if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp: - patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice) + patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice) return model, tokenizer diff --git a/unsloth/models/qwen2.py b/unsloth/models/qwen2.py index 5207e9708..3f819d6dc 100644 --- a/unsloth/models/qwen2.py +++ b/unsloth/models/qwen2.py @@ -39,10 +39,10 @@ class FastQwen2Model(FastLlamaModel): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name="qwen2", - rope_module=LlamaRotaryEmbedding, - scaled_rope_module=LlamaLinearScalingRotaryEmbedding, - attention_module=Qwen2Attention, + model_name = "qwen2", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = Qwen2Attention, ) if init_name is not None: exec(function, globals()) @@ -72,30 +72,30 @@ class FastQwen2Model(FastLlamaModel): @staticmethod def from_pretrained( - model_name="Qwen/Qwen2-7B", - max_seq_length=4096, - dtype=None, - load_in_4bit=True, - token=None, - device_map="sequential", - rope_scaling=None, # Qwen2 does not support RoPE scaling - fix_tokenizer=True, - model_patcher=None, - tokenizer_name=None, - trust_remote_code=False, + model_name = "Qwen/Qwen2-7B", + max_seq_length = 4096, + dtype = None, + load_in_4bit = True, + token = None, + device_map = "sequential", + rope_scaling = None, # Qwen2 does not support RoPE scaling + fix_tokenizer = True, + model_patcher = None, + tokenizer_name = None, + trust_remote_code = False, **kwargs, ): return FastLlamaModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=dtype, - load_in_4bit=load_in_4bit, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, - fix_tokenizer=fix_tokenizer, - model_patcher=FastQwen2Model, - tokenizer_name=tokenizer_name, - trust_remote_code=trust_remote_code, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, + fix_tokenizer = fix_tokenizer, + model_patcher = FastQwen2Model, + tokenizer_name = tokenizer_name, + trust_remote_code = trust_remote_code, **kwargs, ) diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 5c935203d..6f41579d8 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -115,7 +115,7 @@ def Qwen3Attention_fast_forward( else: # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb - rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len) + rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) device_index = Q.device.index if position_ids is None: @@ -126,8 +126,8 @@ def Qwen3Attention_fast_forward( Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: - K = torch.cat([past_key_value[0], K], dim=2) - V = torch.cat([past_key_value[1], V], dim=2) + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) past_key_value = (K, V) if use_cache else None # Attention module @@ -163,7 +163,7 @@ def Qwen3Attention_fast_forward( K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) - A = xformers_attention(Q, K, V, attn_bias=causal_mask) + A = xformers_attention(Q, K, V, attn_bias = causal_mask) A = A.view(bsz, q_len, n_heads, head_dim) elif HAS_FLASH_ATTENTION and attention_mask is None: @@ -172,7 +172,7 @@ def Qwen3Attention_fast_forward( V = V.transpose(1, 2) sw = kv_seq_len window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) - A = flash_attn_func(Q, K, V, causal=True, window_size=window) + A = flash_attn_func(Q, K, V, causal = True, window_size = window) else: # Grouped query attention # if n_groups != 1: @@ -195,7 +195,7 @@ def Qwen3Attention_fast_forward( is_causal = False A = scaled_dot_product_attention( - Q, K, V, attn_mask=attention_mask, is_causal=is_causal + Q, K, V, attn_mask = attention_mask, is_causal = is_causal ) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() @@ -214,8 +214,8 @@ def Qwen3Attention_fast_forward_inference( hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]], position_ids, - do_prefill=False, - attention_mask=None, + do_prefill = False, + attention_mask = None, ): """ https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406 @@ -267,29 +267,29 @@ def Qwen3Attention_fast_forward_inference( if do_prefill: self.paged_attention = torch.empty( (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim), - dtype=dtype, - device=device, + dtype = dtype, + device = device, ) self.paged_attention_K = self.paged_attention[:, 0] self.paged_attention_V = self.paged_attention[:, 1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) self.temp_QA = torch.empty( - (2, bsz, 1, attention_size), dtype=dtype, device=device + (2, bsz, 1, attention_size), dtype = dtype, device = device ) self.temp_KV = torch.empty( - (2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device + (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device ) - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device) + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: - self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device) + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:, :, :hidden_size] self.attention = torch.empty( - (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device + (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device ) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -309,9 +309,9 @@ def Qwen3Attention_fast_forward_inference( (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT) ) - Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0]) - Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0]) - Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1]) + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) Qn = Qn.view( bsz, 1, n_heads, head_dim ) # .transpose(1, 2) # we will transpose after normalisation @@ -399,30 +399,30 @@ def Qwen3Attention_fast_forward_inference( Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows A = torch_matmul( - Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len] + Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len] ) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax( - A, dim=-1, dtype=torch.float32 + A, dim = -1, dtype = torch.float32 ) # .to(A.dtype) - A = torch_matmul(A, Vnn, out=Qn) + A = torch_matmul(A, Vnn, out = Qn) else: if SDPA_HAS_GQA: A = scaled_dot_product_attention( Qn, Knn, Vnn, - attn_mask=attention_mask, - is_causal=is_causal, - enable_gqa=True, + attn_mask = attention_mask, + is_causal = is_causal, + enable_gqa = True, ) else: A = scaled_dot_product_attention( - Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal + Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal ) A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) - A = fast_linear_forward(self.o_proj, A, out=self.temp_O) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) return A, (Kn, Vn) @@ -430,10 +430,10 @@ class FastQwen3Model(FastLlamaModel): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name="Qwen3", - rope_module=LlamaRotaryEmbedding, - scaled_rope_module=LlamaLinearScalingRotaryEmbedding, - attention_module=Qwen3Attention, + model_name = "Qwen3", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = Qwen3Attention, ) if init_name is not None: exec(function, globals()) @@ -463,30 +463,30 @@ class FastQwen3Model(FastLlamaModel): @staticmethod def from_pretrained( # TODO: Change after release - model_name="Qwen/Qwen3-7B", - max_seq_length=4096, - dtype=None, - load_in_4bit=True, - token=None, - device_map="sequential", - rope_scaling=None, - fix_tokenizer=True, - model_patcher=None, - tokenizer_name=None, - trust_remote_code=False, + model_name = "Qwen/Qwen3-7B", + max_seq_length = 4096, + dtype = None, + load_in_4bit = True, + token = None, + device_map = "sequential", + rope_scaling = None, + fix_tokenizer = True, + model_patcher = None, + tokenizer_name = None, + trust_remote_code = False, **kwargs, ): return FastLlamaModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=dtype, - load_in_4bit=load_in_4bit, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, - fix_tokenizer=fix_tokenizer, - model_patcher=FastQwen3Model, - tokenizer_name=tokenizer_name, - trust_remote_code=trust_remote_code, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, + fix_tokenizer = fix_tokenizer, + model_patcher = FastQwen3Model, + tokenizer_name = tokenizer_name, + trust_remote_code = trust_remote_code, **kwargs, ) diff --git a/unsloth/models/qwen3_moe.py b/unsloth/models/qwen3_moe.py index ae0bf9dee..bec3fa7b0 100644 --- a/unsloth/models/qwen3_moe.py +++ b/unsloth/models/qwen3_moe.py @@ -49,29 +49,29 @@ from unsloth_zoo.utils import Version, _get_dtype torch_nn_functional_softmax = torch.nn.functional.softmax -def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate=None, temp_up=None): +def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None): # adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370 bsz, seq_len, hd = X.shape X = X.view(-1, hd) router_logits = fast_linear_forward( - self.gate_proj, X, out=temp_gate + self.gate_proj, X, out = temp_gate ) # pretty much the only change from transformers implementation. routing_weights = torch_nn_functional_softmax( - router_logits, dim=-1, dtype=torch.float32 + router_logits, dim = -1, dtype = torch.float32 ) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1) + routing_weights /= routing_weights.sum(dim = -1, keepdim = True) # we cast back to the input dtype routing_weights = routing_weights.to(X.dtype) - final_X = torch.zeros((bsz * seq_len, hd), dtype=torch.float32, device=X.device) + final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts + selected_experts, num_classes = self.num_experts ).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert @@ -119,16 +119,16 @@ def Qwen3MoeDecoderLayer_fast_forward( self.input_layernorm, hidden_states ) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, - _flag_for_generation=self._flag_for_generation, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, + _flag_for_generation = self._flag_for_generation, ) hidden_states += residual @@ -145,15 +145,15 @@ def Qwen3MoeDecoderLayer_fast_forward( residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - position_embeddings=position_embeddings, + hidden_states = hidden_states, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + padding_mask = padding_mask, + position_embeddings = position_embeddings, ) hidden_states = residual + hidden_states @@ -177,10 +177,10 @@ class FastQwen3MoeModel(FastQwen3Model): @staticmethod def pre_patch(): init_name, function = patch_linear_scaling( - model_name="Qwen3Moe", - rope_module=LlamaRotaryEmbedding, - scaled_rope_module=LlamaLinearScalingRotaryEmbedding, - attention_module=Qwen3MoeAttention, + model_name = "Qwen3Moe", + rope_module = LlamaRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = Qwen3MoeAttention, ) if init_name is not None: exec(function, globals()) @@ -214,30 +214,30 @@ class FastQwen3MoeModel(FastQwen3Model): @staticmethod def from_pretrained( # TODO: Change after release - model_name="Qwen/Qwen3-7B", - max_seq_length=4096, - dtype=None, - load_in_4bit=True, - token=None, - device_map="sequential", - rope_scaling=None, - fix_tokenizer=True, - model_patcher=None, - tokenizer_name=None, - trust_remote_code=False, + model_name = "Qwen/Qwen3-7B", + max_seq_length = 4096, + dtype = None, + load_in_4bit = True, + token = None, + device_map = "sequential", + rope_scaling = None, + fix_tokenizer = True, + model_patcher = None, + tokenizer_name = None, + trust_remote_code = False, **kwargs, ): return FastLlamaModel.from_pretrained( - model_name=model_name, - max_seq_length=max_seq_length, - dtype=dtype, - load_in_4bit=load_in_4bit, - token=token, - device_map=device_map, - rope_scaling=rope_scaling, - fix_tokenizer=fix_tokenizer, - model_patcher=FastQwen3Model, - tokenizer_name=tokenizer_name, - trust_remote_code=trust_remote_code, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, + fix_tokenizer = fix_tokenizer, + model_patcher = FastQwen3Model, + tokenizer_name = tokenizer_name, + trust_remote_code = trust_remote_code, **kwargs, ) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index bb6642456..4bbc852cd 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -30,70 +30,70 @@ class DeepseekR1ModelInfo(ModelInfo): # Deepseek V3 Model Meta DeepseekV3Meta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek", - instruct_tags=[None], - model_version="3", - model_sizes=[""], - model_info_cls=DeepseekV3ModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.BF16], + org = "deepseek-ai", + base_name = "DeepSeek", + instruct_tags = [None], + model_version = "3", + model_sizes = [""], + model_info_cls = DeepseekV3ModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.BF16], ) DeepseekV3_0324Meta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek", - instruct_tags=[None], - model_version="3-0324", - model_sizes=[""], - model_info_cls=DeepseekV3ModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.GGUF], + org = "deepseek-ai", + base_name = "DeepSeek", + instruct_tags = [None], + model_version = "3-0324", + model_sizes = [""], + model_info_cls = DeepseekV3ModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.GGUF], ) DeepseekR1Meta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek-R1", - instruct_tags=[None], - model_version="", - model_sizes=[""], - model_info_cls=DeepseekR1ModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF], + org = "deepseek-ai", + base_name = "DeepSeek-R1", + instruct_tags = [None], + model_version = "", + model_sizes = [""], + model_info_cls = DeepseekR1ModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF], ) DeepseekR1ZeroMeta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek-R1", - instruct_tags=[None], - model_version="Zero", - model_sizes=[""], - model_info_cls=DeepseekR1ModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.GGUF], + org = "deepseek-ai", + base_name = "DeepSeek-R1", + instruct_tags = [None], + model_version = "Zero", + model_sizes = [""], + model_info_cls = DeepseekR1ModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.GGUF], ) DeepseekR1DistillLlamaMeta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek-R1-Distill", - instruct_tags=[None], - model_version="Llama", - model_sizes=["8", "70"], - model_info_cls=DeepseekR1ModelInfo, - is_multimodal=False, - quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]}, + org = "deepseek-ai", + base_name = "DeepSeek-R1-Distill", + instruct_tags = [None], + model_version = "Llama", + model_sizes = ["8", "70"], + model_info_cls = DeepseekR1ModelInfo, + is_multimodal = False, + quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]}, ) # Deepseek R1 Distill Qwen Model Meta DeepseekR1DistillQwenMeta = ModelMeta( - org="deepseek-ai", - base_name="DeepSeek-R1-Distill", - instruct_tags=[None], - model_version="Qwen", - model_sizes=["1.5", "7", "14", "32"], - model_info_cls=DeepseekR1ModelInfo, - is_multimodal=False, - quant_types={ + org = "deepseek-ai", + base_name = "DeepSeek-R1-Distill", + instruct_tags = [None], + model_version = "Qwen", + model_sizes = ["1.5", "7", "14", "32"], + model_info_cls = DeepseekR1ModelInfo, + is_multimodal = False, + quant_types = { "1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], "7": [QuantType.UNSLOTH, QuantType.BNB], "14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], @@ -106,7 +106,7 @@ def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEK_V3_REGISTERED if _IS_DEEPSEEK_V3_REGISTERED: return - _register_models(DeepseekV3Meta, include_original_model=include_original_model) + _register_models(DeepseekV3Meta, include_original_model = include_original_model) _IS_DEEPSEEK_V3_REGISTERED = True @@ -114,7 +114,7 @@ def register_deepseek_v3_0324_models(include_original_model: bool = False): global _IS_DEEPSEEK_V3_0324_REGISTERED if _IS_DEEPSEEK_V3_0324_REGISTERED: return - _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) + _register_models(DeepseekV3_0324Meta, include_original_model = include_original_model) _IS_DEEPSEEK_V3_0324_REGISTERED = True @@ -122,7 +122,7 @@ def register_deepseek_r1_models(include_original_model: bool = False): global _IS_DEEPSEEK_R1_REGISTERED if _IS_DEEPSEEK_R1_REGISTERED: return - _register_models(DeepseekR1Meta, include_original_model=include_original_model) + _register_models(DeepseekR1Meta, include_original_model = include_original_model) _IS_DEEPSEEK_R1_REGISTERED = True @@ -130,7 +130,7 @@ def register_deepseek_r1_zero_models(include_original_model: bool = False): global _IS_DEEPSEEK_R1_ZERO_REGISTERED if _IS_DEEPSEEK_R1_ZERO_REGISTERED: return - _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) + _register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model) _IS_DEEPSEEK_R1_ZERO_REGISTERED = True @@ -139,7 +139,7 @@ def register_deepseek_r1_distill_llama_models(include_original_model: bool = Fal if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED: return _register_models( - DeepseekR1DistillLlamaMeta, include_original_model=include_original_model + DeepseekR1DistillLlamaMeta, include_original_model = include_original_model ) _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True @@ -149,21 +149,21 @@ def register_deepseek_r1_distill_qwen_models(include_original_model: bool = Fals if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED: return _register_models( - DeepseekR1DistillQwenMeta, include_original_model=include_original_model + DeepseekR1DistillQwenMeta, include_original_model = include_original_model ) _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True def register_deepseek_models(include_original_model: bool = False): - register_deepseek_v3_models(include_original_model=include_original_model) - register_deepseek_v3_0324_models(include_original_model=include_original_model) - register_deepseek_r1_models(include_original_model=include_original_model) - register_deepseek_r1_zero_models(include_original_model=include_original_model) + register_deepseek_v3_models(include_original_model = include_original_model) + register_deepseek_v3_0324_models(include_original_model = include_original_model) + register_deepseek_r1_models(include_original_model = include_original_model) + register_deepseek_r1_zero_models(include_original_model = include_original_model) register_deepseek_r1_distill_llama_models( - include_original_model=include_original_model + include_original_model = include_original_model ) register_deepseek_r1_distill_qwen_models( - include_original_model=include_original_model + include_original_model = include_original_model ) @@ -172,7 +172,7 @@ def _list_deepseek_r1_distill_models(): from unsloth.utils.hf_hub import list_models models: list[HfModelInfo] = list_models( - author="unsloth", search="Distill", limit=1000 + author = "unsloth", search = "Distill", limit = 1000 ) distill_models = [] for model in models: @@ -185,14 +185,14 @@ def _list_deepseek_r1_distill_models(): return distill_models -register_deepseek_models(include_original_model=True) +register_deepseek_models(include_original_model = True) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info MODEL_REGISTRY.clear() - register_deepseek_models(include_original_model=True) + register_deepseek_models(include_original_model = True) for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py index 775b4b762..c338128bc 100644 --- a/unsloth/registry/_gemma.py +++ b/unsloth/registry/_gemma.py @@ -15,26 +15,26 @@ class GemmaModelInfo(ModelInfo): # Gemma3 Base Model Meta GemmaMeta3Base = ModelMeta( - org="google", - base_name="gemma", - instruct_tags=["pt"], # pt = base - model_version="3", - model_sizes=["1", "4", "12", "27"], - model_info_cls=GemmaModelInfo, - is_multimodal=True, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + org = "google", + base_name = "gemma", + instruct_tags = ["pt"], # pt = base + model_version = "3", + model_sizes = ["1", "4", "12", "27"], + model_info_cls = GemmaModelInfo, + is_multimodal = True, + quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) # Gemma3 Instruct Model Meta GemmaMeta3Instruct = ModelMeta( - org="google", - base_name="gemma", - instruct_tags=["it"], # it = instruction tuned - model_version="3", - model_sizes=["1", "4", "12", "27"], - model_info_cls=GemmaModelInfo, - is_multimodal=True, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], + org = "google", + base_name = "gemma", + instruct_tags = ["it"], # it = instruction tuned + model_version = "3", + model_sizes = ["1", "4", "12", "27"], + model_info_cls = GemmaModelInfo, + is_multimodal = True, + quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) @@ -42,7 +42,7 @@ def register_gemma_3_base_models(include_original_model: bool = False): global _IS_GEMMA_3_BASE_REGISTERED if _IS_GEMMA_3_BASE_REGISTERED: return - _register_models(GemmaMeta3Base, include_original_model=include_original_model) + _register_models(GemmaMeta3Base, include_original_model = include_original_model) _IS_GEMMA_3_BASE_REGISTERED = True @@ -50,13 +50,13 @@ def register_gemma_3_instruct_models(include_original_model: bool = False): global _IS_GEMMA_3_INSTRUCT_REGISTERED if _IS_GEMMA_3_INSTRUCT_REGISTERED: return - _register_models(GemmaMeta3Instruct, include_original_model=include_original_model) + _register_models(GemmaMeta3Instruct, include_original_model = include_original_model) _IS_GEMMA_3_INSTRUCT_REGISTERED = True def register_gemma_models(include_original_model: bool = False): - register_gemma_3_base_models(include_original_model=include_original_model) - register_gemma_3_instruct_models(include_original_model=include_original_model) + register_gemma_3_base_models(include_original_model = include_original_model) + register_gemma_3_instruct_models(include_original_model = include_original_model) if __name__ == "__main__": @@ -64,7 +64,7 @@ if __name__ == "__main__": MODEL_REGISTRY.clear() - register_gemma_models(include_original_model=True) + register_gemma_models(include_original_model = True) for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py index 817fc1d5c..173d6cfde 100644 --- a/unsloth/registry/_mistral.py +++ b/unsloth/registry/_mistral.py @@ -23,14 +23,14 @@ class MistralSmallModelInfo(ModelInfo): MistralSmall_2503_Base_Meta = ModelMeta( - org="mistralai", - base_name="Mistral-Small", - instruct_tags=["Base"], - model_version=_MISTRAL_SMALL_03_25_VERSION, - model_sizes=["24"], - model_info_cls=MistralSmallModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB], + org = "mistralai", + base_name = "Mistral-Small", + instruct_tags = ["Base"], + model_version = _MISTRAL_SMALL_03_25_VERSION, + model_sizes = ["24"], + model_info_cls = MistralSmallModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB], ) MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta) @@ -54,23 +54,23 @@ def register_mistral_small_models(include_original_model: bool = False): if _IS_MISTRAL_SMALL_REGISTERED: return _register_models( - MistralSmall_2503_Base_Meta, include_original_model=include_original_model + MistralSmall_2503_Base_Meta, include_original_model = include_original_model ) _register_models( - MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model + MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model ) _register_models( - MistralSmall_2501_Base_Meta, include_original_model=include_original_model + MistralSmall_2501_Base_Meta, include_original_model = include_original_model ) _register_models( - MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model + MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model ) _IS_MISTRAL_SMALL_REGISTERED = True def register_mistral_models(include_original_model: bool = False): - register_mistral_small_models(include_original_model=include_original_model) + register_mistral_small_models(include_original_model = include_original_model) if __name__ == "__main__": @@ -78,7 +78,7 @@ if __name__ == "__main__": MODEL_REGISTRY.clear() - register_mistral_models(include_original_model=True) + register_mistral_models(include_original_model = True) for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index fa4042bbe..f852cb876 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -43,50 +43,50 @@ class QwenQVQPreviewModelInfo(ModelInfo): # Qwen2.5 Model Meta Qwen_2_5_Meta = ModelMeta( - org="Qwen", - base_name="Qwen", - instruct_tags=[None, "Instruct"], - model_version="2.5", - model_sizes=["3", "7"], - model_info_cls=QwenModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + org = "Qwen", + base_name = "Qwen", + instruct_tags = [None, "Instruct"], + model_version = "2.5", + model_sizes = ["3", "7"], + model_info_cls = QwenModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) # Qwen2.5 VL Model Meta Qwen_2_5_VLMeta = ModelMeta( - org="Qwen", - base_name="Qwen", - instruct_tags=["Instruct"], # No base, only instruction tuned - model_version="2.5", - model_sizes=["3", "7", "32", "72"], - model_info_cls=QwenVLModelInfo, - is_multimodal=True, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + org = "Qwen", + base_name = "Qwen", + instruct_tags = ["Instruct"], # No base, only instruction tuned + model_version = "2.5", + model_sizes = ["3", "7", "32", "72"], + model_info_cls = QwenVLModelInfo, + is_multimodal = True, + quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) # Qwen QwQ Model Meta QwenQwQMeta = ModelMeta( - org="Qwen", - base_name="QwQ", - instruct_tags=[None], - model_version="", - model_sizes=["32"], - model_info_cls=QwenQwQModelInfo, - is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], + org = "Qwen", + base_name = "QwQ", + instruct_tags = [None], + model_version = "", + model_sizes = ["32"], + model_info_cls = QwenQwQModelInfo, + is_multimodal = False, + quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) # Qwen QVQ Preview Model Meta QwenQVQPreviewMeta = ModelMeta( - org="Qwen", - base_name="QVQ", - instruct_tags=[None], - model_version="", - model_sizes=["72"], - model_info_cls=QwenQVQPreviewModelInfo, - is_multimodal=True, - quant_types=[QuantType.NONE, QuantType.BNB], + org = "Qwen", + base_name = "QVQ", + instruct_tags = [None], + model_version = "", + model_sizes = ["72"], + model_info_cls = QwenQVQPreviewModelInfo, + is_multimodal = True, + quant_types = [QuantType.NONE, QuantType.BNB], ) @@ -94,7 +94,7 @@ def register_qwen_2_5_models(include_original_model: bool = False): global _IS_QWEN_2_5_REGISTERED if _IS_QWEN_2_5_REGISTERED: return - _register_models(Qwen_2_5_Meta, include_original_model=include_original_model) + _register_models(Qwen_2_5_Meta, include_original_model = include_original_model) _IS_QWEN_2_5_REGISTERED = True @@ -102,7 +102,7 @@ def register_qwen_2_5_vl_models(include_original_model: bool = False): global _IS_QWEN_2_5_VL_REGISTERED if _IS_QWEN_2_5_VL_REGISTERED: return - _register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model) + _register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model) _IS_QWEN_2_5_VL_REGISTERED = True @@ -110,15 +110,15 @@ def register_qwen_qwq_models(include_original_model: bool = False): global _IS_QWEN_QWQ_REGISTERED if _IS_QWEN_QWQ_REGISTERED: return - _register_models(QwenQwQMeta, include_original_model=include_original_model) - _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model) + _register_models(QwenQwQMeta, include_original_model = include_original_model) + _register_models(QwenQVQPreviewMeta, include_original_model = include_original_model) _IS_QWEN_QWQ_REGISTERED = True def register_qwen_models(include_original_model: bool = False): - register_qwen_2_5_models(include_original_model=include_original_model) - register_qwen_2_5_vl_models(include_original_model=include_original_model) - register_qwen_qwq_models(include_original_model=include_original_model) + register_qwen_2_5_models(include_original_model = include_original_model) + register_qwen_2_5_vl_models(include_original_model = include_original_model) + register_qwen_qwq_models(include_original_model = include_original_model) if __name__ == "__main__": @@ -126,7 +126,7 @@ if __name__ == "__main__": MODEL_REGISTRY.clear() - register_qwen_models(include_original_model=True) + register_qwen_models(include_original_model = True) for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 880c9fb33..945301420 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -62,7 +62,7 @@ class ModelInfo: @classmethod def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag, key="" + cls, base_name, version, size, quant_type, instruct_tag, key = "" ): key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) @@ -81,10 +81,10 @@ class ModelMeta: base_name: str model_version: str model_info_cls: type[ModelInfo] - model_sizes: list[str] = field(default_factory=list) - instruct_tags: list[str] = field(default_factory=list) + model_sizes: list[str] = field(default_factory = list) + instruct_tags: list[str] = field(default_factory = list) quant_types: list[QuantType] | dict[str, list[QuantType]] = field( - default_factory=list + default_factory = list ) is_multimodal: bool = False @@ -104,11 +104,11 @@ def register_model( name: str = None, ): name = name or model_info_cls.construct_model_name( - base_name=base_name, - version=version, - size=size, - quant_type=quant_type, - instruct_tag=instruct_tag, + base_name = base_name, + version = version, + size = size, + quant_type = quant_type, + instruct_tag = instruct_tag, ) key = f"{org}/{name}" @@ -118,14 +118,14 @@ def register_model( ) MODEL_REGISTRY[key] = model_info_cls( - org=org, - base_name=base_name, - version=version, - size=size, - is_multimodal=is_multimodal, - instruct_tag=instruct_tag, - quant_type=quant_type, - name=name, + org = org, + base_name = base_name, + version = version, + size = size, + is_multimodal = is_multimodal, + instruct_tag = instruct_tag, + quant_type = quant_type, + name = name, ) @@ -137,7 +137,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): api = HfApi() try: - model_info: HfModelInfo = api.model_info(model_id, expand=properties) + model_info: HfModelInfo = api.model_info(model_id, expand = properties) except Exception as e: if isinstance(e, RepositoryNotFoundError): warnings.warn(f"{model_id} not found on Hugging Face") @@ -168,24 +168,24 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH _org = "unsloth" # unsloth models -- these are all quantized versions of the original model register_model( - model_info_cls=model_info_cls, - org=_org, - base_name=base_name, - version=model_version, - size=size, - instruct_tag=instruct_tag, - quant_type=quant_type, - is_multimodal=is_multimodal, + model_info_cls = model_info_cls, + org = _org, + base_name = base_name, + version = model_version, + size = size, + instruct_tag = instruct_tag, + quant_type = quant_type, + is_multimodal = is_multimodal, ) # include original model from releasing organization if include_original_model: register_model( - model_info_cls=model_info_cls, - org=org, - base_name=base_name, - version=model_version, - size=size, - instruct_tag=instruct_tag, - quant_type=QuantType.NONE, - is_multimodal=is_multimodal, + model_info_cls = model_info_cls, + org = org, + base_name = base_name, + version = model_version, + size = size, + instruct_tag = instruct_tag, + quant_type = QuantType.NONE, + is_multimodal = is_multimodal, ) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 196508ebb..6bdb5604d 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -79,7 +79,7 @@ def _create_unsloth_optimizer( model, optimizer_cls, optimizer_kwargs, - embedding_lr=5e-5, + embedding_lr = 5e-5, ): lr = optimizer_kwargs["lr"] weight_decay = optimizer_kwargs.get("weight_decay", 0.0) diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py index 30255b863..75df00fbf 100644 --- a/unsloth/utils/hf_hub.py +++ b/unsloth/utils/hf_hub.py @@ -36,7 +36,7 @@ def get_model_info( if _HFAPI is None: _HFAPI = HfApi() try: - model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties) + model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties) except Exception as e: print(f"Error getting model info for {model_id}: {e}") model_info = None @@ -68,11 +68,11 @@ def list_models( properties = None models: list[ModelInfo] = _HFAPI.list_models( - author=author, - search=search, - sort=sort, - limit=limit, - expand=properties, - full=full, + author = author, + search = search, + sort = sort, + limit = limit, + expand = properties, + full = full, ) return models