mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
This reverts commit cad158a56c.
This commit is contained in:
parent
cad158a56c
commit
66649d18bd
42 changed files with 2394 additions and 2394 deletions
|
|
@ -54,37 +54,37 @@ if __name__ == "__main__":
|
|||
seed = 42
|
||||
batch_size = 5
|
||||
num_generations = 5
|
||||
tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
|
||||
tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
|
||||
temperature = 0.8
|
||||
max_new_tokens = 20
|
||||
|
||||
peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
|
||||
model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
|
||||
peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
|
||||
model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[USER_MESSAGE], tokenize=False, add_generation_prompt=True
|
||||
[USER_MESSAGE], tokenize = False, add_generation_prompt = True
|
||||
)
|
||||
with header_footer_context("Test Prompt and Answer"):
|
||||
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
|
||||
|
||||
dataset: Dataset = create_dataset(
|
||||
tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
|
||||
tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
|
||||
)
|
||||
with header_footer_context("Dataset"):
|
||||
print(f"Dataset: {next(iter(dataset))}")
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir=output_dir,
|
||||
max_steps=max_steps,
|
||||
per_device_train_batch_size=batch_size,
|
||||
log_level="info",
|
||||
report_to="none",
|
||||
num_train_epochs=1,
|
||||
logging_steps=1,
|
||||
seed=seed,
|
||||
bf16=dtype == torch.bfloat16,
|
||||
fp16=dtype == torch.float16,
|
||||
save_strategy="no",
|
||||
output_dir = output_dir,
|
||||
max_steps = max_steps,
|
||||
per_device_train_batch_size = batch_size,
|
||||
log_level = "info",
|
||||
report_to = "none",
|
||||
num_train_epochs = 1,
|
||||
logging_steps = 1,
|
||||
seed = seed,
|
||||
bf16 = dtype == torch.bfloat16,
|
||||
fp16 = dtype == torch.float16,
|
||||
save_strategy = "no",
|
||||
)
|
||||
|
||||
with header_footer_context("Train Args"):
|
||||
|
|
@ -92,7 +92,7 @@ if __name__ == "__main__":
|
|||
print(peft_config)
|
||||
|
||||
trainer = setup_trainer(
|
||||
model, tokenizer, dataset, training_args, peft_config=peft_config
|
||||
model, tokenizer, dataset, training_args, peft_config = peft_config
|
||||
)
|
||||
|
||||
with header_footer_context("Model"):
|
||||
|
|
@ -108,11 +108,11 @@ if __name__ == "__main__":
|
|||
responses = sample_responses(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses before training"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
||||
with header_footer_context("Peft Weights before training"):
|
||||
for name, stats in itertools.islice(describe_peft_weights(model), 2):
|
||||
|
|
@ -129,11 +129,11 @@ if __name__ == "__main__":
|
|||
responses = sample_responses(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses after training"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
||||
model_copy = deepcopy(model)
|
||||
|
||||
|
|
@ -142,18 +142,18 @@ if __name__ == "__main__":
|
|||
responses = sample_responses(
|
||||
merged_model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses after custom merging to 16bit"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
||||
merged_model_peft = model_copy.merge_and_unload()
|
||||
responses = sample_responses(
|
||||
merged_model_peft,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses after peft merge_and_unload"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
|
|
|||
|
|
@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer(
|
|||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
return FastLanguageModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
load_in_4bit=load_in_4bit,
|
||||
fast_inference=fast_inference,
|
||||
max_lora_rank=max_lora_rank,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
dtype=dtype,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = load_in_4bit,
|
||||
fast_inference = fast_inference,
|
||||
max_lora_rank = max_lora_rank,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
dtype = dtype,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -69,11 +69,11 @@ def get_unsloth_peft_model(
|
|||
):
|
||||
return FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=lora_rank,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=lora_rank,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
random_state=random_state,
|
||||
r = lora_rank,
|
||||
target_modules = target_modules,
|
||||
lora_alpha = lora_rank,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
random_state = random_state,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -101,48 +101,48 @@ if __name__ == "__main__":
|
|||
|
||||
model, tokenizer = get_unsloth_model_and_tokenizer(
|
||||
model_name,
|
||||
max_seq_length=512,
|
||||
load_in_4bit=True,
|
||||
fast_inference=False,
|
||||
max_lora_rank=lora_rank,
|
||||
dtype=dtype,
|
||||
max_seq_length = 512,
|
||||
load_in_4bit = True,
|
||||
fast_inference = False,
|
||||
max_lora_rank = lora_rank,
|
||||
dtype = dtype,
|
||||
)
|
||||
temperature = 0.8
|
||||
max_new_tokens = 20
|
||||
|
||||
model = get_unsloth_peft_model(
|
||||
model,
|
||||
lora_rank=lora_rank,
|
||||
target_modules=target_modules,
|
||||
use_gradient_checkpointing=gradient_checkpointing,
|
||||
random_state=seed,
|
||||
lora_rank = lora_rank,
|
||||
target_modules = target_modules,
|
||||
use_gradient_checkpointing = gradient_checkpointing,
|
||||
random_state = seed,
|
||||
)
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[USER_MESSAGE], tokenize=False, add_generation_prompt=True
|
||||
[USER_MESSAGE], tokenize = False, add_generation_prompt = True
|
||||
)
|
||||
|
||||
with header_footer_context("Test Prompt and Answer"):
|
||||
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
|
||||
|
||||
dataset: Dataset = create_dataset(
|
||||
tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
|
||||
tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
|
||||
)
|
||||
with header_footer_context("Dataset"):
|
||||
print(f"Dataset: {next(iter(dataset))}")
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir=output_dir,
|
||||
max_steps=max_steps,
|
||||
per_device_train_batch_size=batch_size,
|
||||
log_level="info",
|
||||
report_to="none",
|
||||
num_train_epochs=1,
|
||||
logging_steps=1,
|
||||
seed=seed,
|
||||
bf16=dtype == torch.bfloat16,
|
||||
fp16=dtype == torch.float16,
|
||||
save_strategy="no",
|
||||
output_dir = output_dir,
|
||||
max_steps = max_steps,
|
||||
per_device_train_batch_size = batch_size,
|
||||
log_level = "info",
|
||||
report_to = "none",
|
||||
num_train_epochs = 1,
|
||||
logging_steps = 1,
|
||||
seed = seed,
|
||||
bf16 = dtype == torch.bfloat16,
|
||||
fp16 = dtype == torch.float16,
|
||||
save_strategy = "no",
|
||||
)
|
||||
|
||||
with header_footer_context("Train Args"):
|
||||
|
|
@ -163,11 +163,11 @@ if __name__ == "__main__":
|
|||
responses = sample_responses(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses before training"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
with header_footer_context("Peft Weights before training"):
|
||||
for name, stats in itertools.islice(describe_peft_weights(model), 2):
|
||||
print(f"{name}:\n{stats}")
|
||||
|
|
@ -183,29 +183,29 @@ if __name__ == "__main__":
|
|||
responses = sample_responses(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses after training"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
||||
model.save_pretrained_merged(
|
||||
unsloth_merged_path,
|
||||
tokenizer,
|
||||
save_method="merged_16bit",
|
||||
save_method = "merged_16bit",
|
||||
)
|
||||
merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
|
||||
unsloth_merged_path,
|
||||
max_seq_length=512,
|
||||
load_in_4bit=False,
|
||||
fast_inference=False,
|
||||
dtype=dtype,
|
||||
max_seq_length = 512,
|
||||
load_in_4bit = False,
|
||||
fast_inference = False,
|
||||
dtype = dtype,
|
||||
)
|
||||
responses = sample_responses(
|
||||
merged_model_unsloth,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
prompt = prompt,
|
||||
**generation_args,
|
||||
)
|
||||
with header_footer_context("Responses after unsloth merge to 16bit"):
|
||||
check_responses(responses, answer=ANSWER, prompt=prompt)
|
||||
check_responses(responses, answer = ANSWER, prompt = prompt)
|
||||
|
|
|
|||
|
|
@ -78,17 +78,17 @@ def formatting_prompts_func(examples):
|
|||
}
|
||||
|
||||
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
"""Load model and compute perplexity in subprocess"""
|
||||
from unsloth import FastLanguageModel
|
||||
from tests.utils.perplexity_eval import ppl_model
|
||||
|
||||
# Load model
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
model_name = "./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
)
|
||||
# Set up tokenizer
|
||||
# merged_tokenizer = get_chat_template(
|
||||
|
|
@ -98,7 +98,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Load dataset fresh in subprocess
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
|
@ -146,7 +146,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
"text": texts,
|
||||
}
|
||||
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
# Compute perplexity using the passed dataset
|
||||
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
|
||||
|
|
@ -172,7 +172,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
|
|||
|
||||
# Main execution code should be wrapped in this guard
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
compute_dtype = torch.bfloat16
|
||||
|
|
@ -182,31 +182,31 @@ if __name__ == "__main__":
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Qwen2.5-7B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Qwen2.5-7B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
dataset_train = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="train"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "train"
|
||||
)
|
||||
dataset_ppl = load_dataset(
|
||||
"allenai/openassistant-guanaco-reformatted", split="eval"
|
||||
"allenai/openassistant-guanaco-reformatted", split = "eval"
|
||||
)
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -215,40 +215,40 @@ if __name__ == "__main__":
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 200,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -260,7 +260,7 @@ if __name__ == "__main__":
|
|||
# saving and merging the model to local disk
|
||||
print("merge and save to local disk")
|
||||
model.save_pretrained_merged(
|
||||
save_directory="./unsloth_out/merged_qwen_text_model", tokenizer=tokenizer
|
||||
save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer
|
||||
)
|
||||
|
||||
# print("cleaning")
|
||||
|
|
@ -272,10 +272,10 @@ if __name__ == "__main__":
|
|||
# load model from local disk and test
|
||||
print("Loading merged model in 4 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
@ -284,7 +284,7 @@ if __name__ == "__main__":
|
|||
|
||||
print("Computing 8-bit model perplexity in subprocess...")
|
||||
result_queue = mp.Queue()
|
||||
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
|
||||
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -293,10 +293,10 @@ if __name__ == "__main__":
|
|||
|
||||
print("Loading merged model in 16 bit for perplexity test")
|
||||
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length=2048,
|
||||
load_in_4bit=False,
|
||||
load_in_8bit=False,
|
||||
model_name = "./unsloth_out/merged_qwen_text_model",
|
||||
max_seq_length = 2048,
|
||||
load_in_4bit = False,
|
||||
load_in_8bit = False,
|
||||
)
|
||||
|
||||
add_to_comparison(
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -52,34 +52,34 @@ else:
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
from unsloth.chat_templates import standardize_sharegpt
|
||||
|
||||
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
|
||||
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
|
||||
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
|
||||
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=30,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 30,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
)
|
||||
|
||||
# run training
|
||||
|
|
@ -160,7 +160,7 @@ try:
|
|||
print("\n" + "=" * 80)
|
||||
print("=== UPLOADING MODEL TO HUB ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
|
||||
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
|
||||
success["upload"] = True
|
||||
print("✅ Model uploaded successfully!")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
|
|||
convos = examples["messages"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -52,34 +52,34 @@ else:
|
|||
attn_implementation = "sdpa"
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=compute_dtype,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
full_finetuning=False,
|
||||
attn_implementation=attn_implementation,
|
||||
model_name = "unsloth/Llama-3.1-8B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = compute_dtype,
|
||||
load_in_4bit = True,
|
||||
load_in_8bit = False,
|
||||
full_finetuning = False,
|
||||
attn_implementation = attn_implementation,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
from unsloth.chat_templates import standardize_sharegpt
|
||||
|
||||
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
|
||||
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
|
||||
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
|
||||
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
|
||||
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
|
||||
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
|
||||
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
|
||||
|
||||
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
r = 16,
|
||||
target_modules = [
|
||||
"k_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
|
|
@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
|
|||
"down_proj",
|
||||
"up_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
use_rslora=False,
|
||||
loftq_config=None,
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
random_state = 3407,
|
||||
use_rslora = False,
|
||||
loftq_config = None,
|
||||
)
|
||||
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=2048,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_ratio=0.1,
|
||||
max_steps=30,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=50,
|
||||
optim="adamw_8bit",
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none",
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = dataset_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = 2048,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False,
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_ratio = 0.1,
|
||||
max_steps = 30,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 50,
|
||||
optim = "adamw_8bit",
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
)
|
||||
|
||||
# run training
|
||||
|
|
@ -161,7 +161,7 @@ try:
|
|||
print("\n" + "=" * 80)
|
||||
print("=== UPLOADING MODEL TO HUB ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
|
||||
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
|
||||
success["upload"] = True
|
||||
print("✅ Model uploaded successfully!")
|
||||
except Exception as e:
|
||||
|
|
@ -173,8 +173,8 @@ try:
|
|||
print("\n" + "=" * 80)
|
||||
print("=== VERIFYING REPO CONTENTS ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
fs = HfFileSystem(token=hf_token)
|
||||
file_list = fs.ls(repo_name, detail=True)
|
||||
fs = HfFileSystem(token = hf_token)
|
||||
file_list = fs.ls(repo_name, detail = True)
|
||||
safetensors_found = any(
|
||||
file["name"].endswith("model.safetensors.index.json") for file in file_list
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ max_seq_length = 2048 # Can increase for longer reasoning traces
|
|||
lora_rank = 64 # Larger rank = smarter, but slower
|
||||
|
||||
|
||||
def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
|
||||
def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):
|
||||
from unsloth import FastLanguageModel
|
||||
from tests.utils.aime_eval import evaluate_model_aime
|
||||
|
||||
|
|
@ -32,12 +32,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
|
|||
lora_rank = 64 # Larger rank = smarter, but slower
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./final_merged_model",
|
||||
max_seq_length=max_seq_length,
|
||||
load_in_4bit=True, # False for LoRA 16bit
|
||||
fast_inference=True, # Enable vLLM fast inference
|
||||
max_lora_rank=lora_rank,
|
||||
gpu_memory_utilization=0.8, # Reduce if out of memory
|
||||
model_name = "./final_merged_model",
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = True, # False for LoRA 16bit
|
||||
fast_inference = True, # Enable vLLM fast inference
|
||||
max_lora_rank = lora_rank,
|
||||
gpu_memory_utilization = 0.8, # Reduce if out of memory
|
||||
)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
|
|
@ -53,14 +53,14 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
|
|||
print(f"{'='*60}")
|
||||
|
||||
evaluate_model_aime(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_type=model_type,
|
||||
temperature=0.3,
|
||||
n_sampling=8,
|
||||
max_tokens=32768,
|
||||
top_p=0.95,
|
||||
seed=0,
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_type = model_type,
|
||||
temperature = 0.3,
|
||||
n_sampling = 8,
|
||||
max_tokens = 32768,
|
||||
top_p = 0.95,
|
||||
seed = 0,
|
||||
)
|
||||
|
||||
result_queue.put(results)
|
||||
|
|
@ -74,12 +74,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
|
|||
# Main execution code should be wrapped in this guard
|
||||
def training_run(result_queue):
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="meta-llama/Llama-3.2-3B-Instruct",
|
||||
max_seq_length=max_seq_length,
|
||||
load_in_4bit=False, # False for LoRA 16bit
|
||||
fast_inference=True, # Enable vLLM fast inference
|
||||
max_lora_rank=lora_rank,
|
||||
gpu_memory_utilization=0.8, # Reduce if out of memory
|
||||
model_name = "meta-llama/Llama-3.2-3B-Instruct",
|
||||
max_seq_length = max_seq_length,
|
||||
load_in_4bit = False, # False for LoRA 16bit
|
||||
fast_inference = True, # Enable vLLM fast inference
|
||||
max_lora_rank = lora_rank,
|
||||
gpu_memory_utilization = 0.8, # Reduce if out of memory
|
||||
)
|
||||
|
||||
"""### Helper Functions
|
||||
|
|
@ -166,10 +166,10 @@ def training_run(result_queue):
|
|||
lengths = dataset.map(
|
||||
lambda x: {
|
||||
"tokens": tokenizer.apply_chat_template(
|
||||
x["prompt"], add_generation_prompt=True, tokenize=True
|
||||
x["prompt"], add_generation_prompt = True, tokenize = True
|
||||
)
|
||||
},
|
||||
batched=True,
|
||||
batched = True,
|
||||
).map(lambda x: {"length": len(x["tokens"])})["length"]
|
||||
|
||||
max_length = max(lengths)
|
||||
|
|
@ -181,7 +181,7 @@ def training_run(result_queue):
|
|||
)
|
||||
return max_length, avg_length
|
||||
|
||||
def extract_unsloth_answer(text, start_tag="<SOLUTION>", end_tag="</SOLUTION>"):
|
||||
def extract_unsloth_answer(text, start_tag = "<SOLUTION>", end_tag = "</SOLUTION>"):
|
||||
"""Extract answer from Unsloth SOLUTION tags"""
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
|
@ -213,10 +213,10 @@ def training_run(result_queue):
|
|||
"""Count tokens in text"""
|
||||
if not text:
|
||||
return 0
|
||||
encoding = tokenizer_instance(text, return_tensors="pt")
|
||||
encoding = tokenizer_instance(text, return_tensors = "pt")
|
||||
return len(encoding["input_ids"][0])
|
||||
|
||||
def check_format_compliance(text, format_type="unsloth"):
|
||||
def check_format_compliance(text, format_type = "unsloth"):
|
||||
"""Check if response follows expected format"""
|
||||
if format_type == "unsloth":
|
||||
reasoning_start = "<start_reasoning>"
|
||||
|
|
@ -419,11 +419,11 @@ def training_run(result_queue):
|
|||
# Save comparison
|
||||
comparison_data = {
|
||||
"summary": all_results,
|
||||
"best_model": max(all_results, key=lambda x: x["exact_match_pct"]),
|
||||
"best_model": max(all_results, key = lambda x: x["exact_match_pct"]),
|
||||
}
|
||||
|
||||
with open("model_comparison_comprehensive.json", "w") as f:
|
||||
json.dump(comparison_data, f, indent=4)
|
||||
json.dump(comparison_data, f, indent = 4)
|
||||
|
||||
print(
|
||||
f"\nBest performing model: {comparison_data['best_model']['model_type']} "
|
||||
|
|
@ -449,10 +449,10 @@ def training_run(result_queue):
|
|||
from datasets import load_dataset
|
||||
|
||||
# Load GSM8K
|
||||
gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train")
|
||||
gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train")
|
||||
|
||||
# Load LIMO (adjust this based on your access method)
|
||||
limo_train = load_dataset("GAIR/LIMO", split="train")
|
||||
limo_train = load_dataset("GAIR/LIMO", split = "train")
|
||||
|
||||
# Prepare datasets
|
||||
gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)
|
||||
|
|
@ -466,28 +466,28 @@ def training_run(result_queue):
|
|||
|
||||
# Single temperature evaluation on combined dataset
|
||||
results = evaluate_model_aime(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_type="base",
|
||||
temperature=0.3,
|
||||
n_sampling=8,
|
||||
max_tokens=32768,
|
||||
top_p=0.95,
|
||||
seed=0,
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_type = "base",
|
||||
temperature = 0.3,
|
||||
n_sampling = 8,
|
||||
max_tokens = 32768,
|
||||
top_p = 0.95,
|
||||
seed = 0,
|
||||
)
|
||||
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="llama-3.1",
|
||||
chat_template = "llama-3.1",
|
||||
)
|
||||
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["prompt"]
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
convo, tokenize=False, add_generation_prompt=False
|
||||
convo, tokenize = False, add_generation_prompt = False
|
||||
)
|
||||
for convo in convos
|
||||
]
|
||||
|
|
@ -497,7 +497,7 @@ def training_run(result_queue):
|
|||
|
||||
limo_train = limo_train.map(
|
||||
formatting_prompts_func,
|
||||
batched=True,
|
||||
batched = True,
|
||||
)
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
|
@ -510,8 +510,8 @@ def training_run(result_queue):
|
|||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules=[
|
||||
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
|
|
@ -520,37 +520,37 @@ def training_run(result_queue):
|
|||
"up_proj",
|
||||
"down_proj",
|
||||
], # Remove QKVO if out of memory
|
||||
lora_alpha=lora_rank,
|
||||
use_gradient_checkpointing="unsloth", # Enable long context finetuning
|
||||
random_state=3407,
|
||||
lora_alpha = lora_rank,
|
||||
use_gradient_checkpointing = "unsloth", # Enable long context finetuning
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
if limo_train is not None:
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=limo_train,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=max_seq_length,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=2,
|
||||
packing=False, # Can make training 5x faster for short sequences.
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_steps=5,
|
||||
num_train_epochs=1, # Set this for 1 full training run.
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
train_dataset = limo_train,
|
||||
dataset_text_field = "text",
|
||||
max_seq_length = max_seq_length,
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
|
||||
dataset_num_proc = 2,
|
||||
packing = False, # Can make training 5x faster for short sequences.
|
||||
args = TrainingArguments(
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
warmup_steps = 5,
|
||||
num_train_epochs = 1, # Set this for 1 full training run.
|
||||
# max_steps = 60,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=1,
|
||||
optim="adamw_8bit",
|
||||
weight_decay=0.01,
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="outputs",
|
||||
report_to="none", # Use this for WandB etc
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bfloat16_supported(),
|
||||
bf16 = is_bfloat16_supported(),
|
||||
logging_steps = 1,
|
||||
optim = "adamw_8bit",
|
||||
weight_decay = 0.01,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "outputs",
|
||||
report_to = "none", # Use this for WandB etc
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -558,8 +558,8 @@ def training_run(result_queue):
|
|||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
)
|
||||
|
||||
# Train
|
||||
|
|
@ -588,7 +588,7 @@ def training_run(result_queue):
|
|||
PRINT_EVERY_STEPS = 5
|
||||
|
||||
match_numbers = re.compile(
|
||||
solution_start + r".*?([\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL
|
||||
solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL
|
||||
)
|
||||
|
||||
def check_numbers(prompts, completions, answer, **kwargs):
|
||||
|
|
@ -642,37 +642,37 @@ def training_run(result_queue):
|
|||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
training_args = GRPOConfig(
|
||||
learning_rate=5e-6,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type="cosine",
|
||||
optim="adamw_torch_fused",
|
||||
logging_steps=1,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Increase to 4 for smoother training
|
||||
num_generations=8, # Decrease if out of memory
|
||||
max_prompt_length=max_prompt_length,
|
||||
max_completion_length=max_seq_length - max_prompt_length,
|
||||
learning_rate = 5e-6,
|
||||
weight_decay = 0.1,
|
||||
warmup_ratio = 0.1,
|
||||
lr_scheduler_type = "cosine",
|
||||
optim = "adamw_torch_fused",
|
||||
logging_steps = 1,
|
||||
per_device_train_batch_size = 1,
|
||||
gradient_accumulation_steps = 4, # Increase to 4 for smoother training
|
||||
num_generations = 8, # Decrease if out of memory
|
||||
max_prompt_length = max_prompt_length,
|
||||
max_completion_length = max_seq_length - max_prompt_length,
|
||||
# num_train_epochs = 1, # Set to 1 for a full training run
|
||||
# max_steps = 250,
|
||||
max_steps=1000,
|
||||
save_steps=250,
|
||||
max_grad_norm=0.1,
|
||||
report_to="none", # Can use Weights & Biases
|
||||
output_dir="outputs",
|
||||
max_steps = 1000,
|
||||
save_steps = 250,
|
||||
max_grad_norm = 0.1,
|
||||
report_to = "none", # Can use Weights & Biases
|
||||
output_dir = "outputs",
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
model = model,
|
||||
processing_class = tokenizer,
|
||||
reward_funcs = [
|
||||
match_format_exactly,
|
||||
match_format_approximately,
|
||||
check_answer_correctness,
|
||||
check_numbers,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=gsm8k_train,
|
||||
args = training_args,
|
||||
train_dataset = gsm8k_train,
|
||||
)
|
||||
|
||||
# Train
|
||||
|
|
@ -696,14 +696,14 @@ def training_run(result_queue):
|
|||
print(f"{'='*60}")
|
||||
|
||||
grpo_results = evaluate_model_aime(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_type="grpo",
|
||||
temperature=0.3,
|
||||
n_sampling=8,
|
||||
max_tokens=32768,
|
||||
top_p=0.95,
|
||||
seed=0,
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_type = "grpo",
|
||||
temperature = 0.3,
|
||||
n_sampling = 8,
|
||||
max_tokens = 32768,
|
||||
top_p = 0.95,
|
||||
seed = 0,
|
||||
)
|
||||
|
||||
all_results.append(grpo_results)
|
||||
|
|
@ -716,7 +716,7 @@ def training_run(result_queue):
|
|||
# Save as merged model
|
||||
try:
|
||||
model.save_pretrained_merged(
|
||||
"final_merged_model", tokenizer, save_method="merged_16bit"
|
||||
"final_merged_model", tokenizer, save_method = "merged_16bit"
|
||||
)
|
||||
print("✅ Merged model saved to: final_merged_model/")
|
||||
except Exception as e:
|
||||
|
|
@ -774,12 +774,12 @@ def training_run(result_queue):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.set_start_method("spawn", force=True)
|
||||
mp.set_start_method("spawn", force = True)
|
||||
result_queue = mp.Queue()
|
||||
all_results = []
|
||||
|
||||
# run main finetuning and grpo loop
|
||||
p = mp.Process(target=training_run, args=(result_queue,))
|
||||
p = mp.Process(target = training_run, args = (result_queue,))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -787,7 +787,7 @@ if __name__ == "__main__":
|
|||
all_results = results
|
||||
|
||||
# evaluate merged model loaded 16bits
|
||||
p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
|
||||
p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -796,7 +796,7 @@ if __name__ == "__main__":
|
|||
safe_remove_directory("./unsloth_compiled_cache")
|
||||
|
||||
# Merged model load 8 bits model AIME eval
|
||||
p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))
|
||||
p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
@ -806,7 +806,7 @@ if __name__ == "__main__":
|
|||
safe_remove_directory("./unsloth_compiled_cache")
|
||||
|
||||
# Merged model load 4 bits model AIME eval
|
||||
p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
|
||||
p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
|
|
|||
|
|
@ -43,31 +43,31 @@ tokenizer_files = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=model_to_test)
|
||||
@pytest.fixture(scope = "session", params = model_to_test)
|
||||
def loaded_model_tokenizer(request):
|
||||
model_name = request.param
|
||||
print("Loading model and tokenizer...")
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name, # use small model
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
lora_alpha=16,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
lora_alpha = 16,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=torchao_models)
|
||||
@pytest.fixture(scope = "session", params = torchao_models)
|
||||
def fp16_model_tokenizer(request):
|
||||
"""Load model in FP16 for TorchAO quantization"""
|
||||
model_name = request.param
|
||||
|
|
@ -75,29 +75,29 @@ def fp16_model_tokenizer(request):
|
|||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name,
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=False, # No BnB quantization
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = False, # No BnB quantization
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
lora_alpha=16,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
lora_alpha = 16,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope = "session")
|
||||
def model(loaded_model_tokenizer):
|
||||
return loaded_model_tokenizer[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope = "session")
|
||||
def tokenizer(loaded_model_tokenizer):
|
||||
return loaded_model_tokenizer[1]
|
||||
|
||||
|
|
@ -133,7 +133,7 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
|
|||
)
|
||||
|
||||
model.save_pretrained_merged(
|
||||
save_path, tokenizer=tokenizer, save_method="merged_16bit"
|
||||
save_path, tokenizer = tokenizer, save_method = "merged_16bit"
|
||||
)
|
||||
|
||||
# Check model files
|
||||
|
|
@ -172,9 +172,9 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
|
|||
# Test loading the model from the saved path
|
||||
loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
|
||||
save_path,
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -186,7 +186,7 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
|
|||
)
|
||||
|
||||
model.save_pretrained_merged(
|
||||
save_path, tokenizer=tokenizer, save_method="merged_4bit_forced"
|
||||
save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced"
|
||||
)
|
||||
|
||||
# Check model files
|
||||
|
|
@ -230,15 +230,15 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
|
|||
# Test loading the model from the saved path
|
||||
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
|
||||
save_path,
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("torchao") is None,
|
||||
reason="require torchao to be installed",
|
||||
reason = "require torchao to be installed",
|
||||
)
|
||||
def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
||||
model, tokenizer = fp16_model_tokenizer
|
||||
|
|
@ -251,9 +251,9 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
|||
torchao_config = Int8DynamicActivationInt8WeightConfig()
|
||||
model.save_pretrained_torchao(
|
||||
save_path,
|
||||
tokenizer=tokenizer,
|
||||
torchao_config=torchao_config,
|
||||
push_to_hub=False,
|
||||
tokenizer = tokenizer,
|
||||
torchao_config = torchao_config,
|
||||
push_to_hub = False,
|
||||
)
|
||||
|
||||
weight_files_16bit = [
|
||||
|
|
@ -316,15 +316,15 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
|||
with torch.serialization.safe_globals([getattr]):
|
||||
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
|
||||
torchao_save_path,
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=False,
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("torchao") is None,
|
||||
reason="require torchao to be installed",
|
||||
reason = "require torchao to be installed",
|
||||
)
|
||||
def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
||||
model, tokenizer = fp16_model_tokenizer
|
||||
|
|
@ -343,9 +343,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
|||
# Save with TorchAO
|
||||
model.save_pretrained_torchao(
|
||||
save_path,
|
||||
tokenizer=tokenizer,
|
||||
torchao_config=torchao_config,
|
||||
push_to_hub=False,
|
||||
tokenizer = tokenizer,
|
||||
torchao_config = torchao_config,
|
||||
push_to_hub = False,
|
||||
)
|
||||
|
||||
torchao_save_path = save_path + "-torchao"
|
||||
|
|
@ -361,9 +361,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
|||
with torch.serialization.safe_globals([getattr]):
|
||||
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
|
||||
torchao_save_path,
|
||||
max_seq_length=128,
|
||||
dtype=None,
|
||||
load_in_4bit=False,
|
||||
max_seq_length = 128,
|
||||
dtype = None,
|
||||
load_in_4bit = False,
|
||||
)
|
||||
|
||||
FastModel.for_inference(loaded_model) # Enable native 2x faster inference
|
||||
|
|
@ -376,24 +376,24 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
|
|||
]
|
||||
inputs = loaded_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True, # Must add for generation
|
||||
return_tensors="pt",
|
||||
tokenize = True,
|
||||
add_generation_prompt = True, # Must add for generation
|
||||
return_tensors = "pt",
|
||||
).to("cuda")
|
||||
|
||||
outputs = loaded_model.generate( # ← Use loaded_model, not model
|
||||
input_ids=inputs,
|
||||
max_new_tokens=64,
|
||||
use_cache=False, # Avoid cache issues
|
||||
temperature=1.5,
|
||||
min_p=0.1,
|
||||
do_sample=True,
|
||||
pad_token_id=loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
|
||||
input_ids = inputs,
|
||||
max_new_tokens = 64,
|
||||
use_cache = False, # Avoid cache issues
|
||||
temperature = 1.5,
|
||||
min_p = 0.1,
|
||||
do_sample = True,
|
||||
pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
# Decode with the LOADED tokenizer
|
||||
generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens=True)
|
||||
generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)
|
||||
input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)
|
||||
response_part = generated_text[len(input_text) :].strip()
|
||||
|
||||
print(f"Input: {input_text}")
|
||||
|
|
|
|||
|
|
@ -26,11 +26,11 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name="unsloth/csm-1b",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
dtype=None, # Leave as None for auto-detection
|
||||
auto_model=CsmForConditionalGeneration,
|
||||
load_in_4bit=False, # Select True for 4bit - reduces memory usage
|
||||
model_name = "unsloth/csm-1b",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
dtype = None, # Leave as None for auto-detection
|
||||
auto_model = CsmForConditionalGeneration,
|
||||
load_in_4bit = False, # Select True for 4bit - reduces memory usage
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,8 +39,8 @@ base_model_class = model.__class__.__name__
|
|||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
r=32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules=[
|
||||
r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
|
|
@ -49,14 +49,14 @@ model = FastModel.get_peft_model(
|
|||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
lora_alpha=32,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
lora_alpha = 32,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
|
||||
print("✅ Model and LoRA adapters loaded successfully!")
|
||||
|
|
@ -110,11 +110,11 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, processor = FastModel.from_pretrained(
|
||||
model_name="./csm",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
dtype=None, # Leave as None for auto-detection
|
||||
auto_model=CsmForConditionalGeneration,
|
||||
load_in_4bit=False, # Select True for 4bit - reduces memory usage
|
||||
model_name = "./csm",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
dtype = None, # Leave as None for auto-detection
|
||||
auto_model = CsmForConditionalGeneration,
|
||||
load_in_4bit = False, # Select True for 4bit - reduces memory usage
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor
|
||||
|
|
@ -138,19 +138,19 @@ try:
|
|||
"We just finished fine tuning a text to speech model... and it's pretty good!"
|
||||
)
|
||||
speaker_id = 0
|
||||
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
|
||||
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda")
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
|
||||
max_new_tokens = 125, # 125 tokens is 10 seconds of audio, for longer speech increase this
|
||||
# play with these parameters to get the best results
|
||||
depth_decoder_temperature=0.6,
|
||||
depth_decoder_top_k=0,
|
||||
depth_decoder_top_p=0.9,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
depth_decoder_temperature = 0.6,
|
||||
depth_decoder_top_k = 0,
|
||||
depth_decoder_top_p = 0.9,
|
||||
temperature = 0.8,
|
||||
top_k = 50,
|
||||
top_p = 1.0,
|
||||
#########################################################
|
||||
output_audio=True,
|
||||
output_audio = True,
|
||||
)
|
||||
audio = audio_values[0].to(torch.float32).cpu().numpy()
|
||||
sf.write("example_without_context.wav", audio, 24000)
|
||||
|
|
|
|||
|
|
@ -42,10 +42,10 @@ print(f"{'='*80}")
|
|||
|
||||
max_seq_length = 2048
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/Llasa-1B",
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=None, # Select None for auto detection
|
||||
load_in_4bit=False, # Choose True for 4bit which reduces memory
|
||||
model_name = "unsloth/Llasa-1B",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = None, # Select None for auto detection
|
||||
load_in_4bit = False, # Choose True for 4bit which reduces memory
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -54,16 +54,16 @@ base_model_class = model.__class__.__name__
|
|||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_alpha=128,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = ["q_proj", "v_proj"],
|
||||
lora_alpha = 128,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
|
||||
print("✅ Model and LoRA adapters loaded successfully!")
|
||||
|
|
@ -117,10 +117,10 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="./lasa",
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=None, # Select None for auto detection
|
||||
load_in_4bit=False, # Choose True for 4bit which reduces memory
|
||||
model_name = "./lasa",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = None, # Select None for auto detection
|
||||
load_in_4bit = False, # Choose True for 4bit which reduces memory
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ def extract_speech_ids(speech_tokens_str):
|
|||
|
||||
# TTS start!
|
||||
with torch.inference_mode():
|
||||
with torch.amp.autocast("cuda", dtype=model.dtype):
|
||||
with torch.amp.autocast("cuda", dtype = model.dtype):
|
||||
formatted_text = (
|
||||
f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
|
||||
)
|
||||
|
|
@ -178,7 +178,7 @@ with torch.inference_mode():
|
|||
]
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
chat, tokenize=True, return_tensors="pt", continue_final_message=True
|
||||
chat, tokenize = True, return_tensors = "pt", continue_final_message = True
|
||||
)
|
||||
input_ids = input_ids.to("cuda")
|
||||
|
||||
|
|
@ -187,16 +187,16 @@ with torch.inference_mode():
|
|||
# Generate the speech autoregressively
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
max_length=2048, # We trained our model with a max length of 2048
|
||||
eos_token_id=speech_end_id,
|
||||
do_sample=True,
|
||||
top_p=1.2, # Adjusts the diversity of generated content
|
||||
temperature=1.2, # Controls randomness in output
|
||||
max_length = 2048, # We trained our model with a max length of 2048
|
||||
eos_token_id = speech_end_id,
|
||||
do_sample = True,
|
||||
top_p = 1.2, # Adjusts the diversity of generated content
|
||||
temperature = 1.2, # Controls randomness in output
|
||||
)
|
||||
# Extract the speech tokens
|
||||
generated_ids = outputs[0][input_ids.shape[1] : -1]
|
||||
|
||||
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)
|
||||
|
||||
# Convert token <|s_23456|> to int 23456
|
||||
speech_tokens = extract_speech_ids(speech_tokens)
|
||||
|
|
|
|||
|
|
@ -28,12 +28,12 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name="unsloth/whisper-large-v3",
|
||||
dtype=None, # Leave as None for auto detection
|
||||
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model=WhisperForConditionalGeneration,
|
||||
whisper_language="English",
|
||||
whisper_task="transcribe",
|
||||
model_name = "unsloth/whisper-large-v3",
|
||||
dtype = None, # Leave as None for auto detection
|
||||
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model = WhisperForConditionalGeneration,
|
||||
whisper_language = "English",
|
||||
whisper_task = "transcribe",
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -46,17 +46,17 @@ model.generation_config.forced_decoder_ids = None
|
|||
|
||||
model = FastModel.get_peft_model(
|
||||
model,
|
||||
r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_alpha=64,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
target_modules = ["q_proj", "v_proj"],
|
||||
lora_alpha = 64,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
task_type=None, # ** MUST set this for Whisper **
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
task_type = None, # ** MUST set this for Whisper **
|
||||
)
|
||||
|
||||
print("✅ Model and LoRA adapters loaded successfully!")
|
||||
|
|
@ -110,12 +110,12 @@ print(f"{'='*80}")
|
|||
|
||||
|
||||
model, tokenizer = FastModel.from_pretrained(
|
||||
model_name="./whisper",
|
||||
dtype=None, # Leave as None for auto detection
|
||||
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model=WhisperForConditionalGeneration,
|
||||
whisper_language="English",
|
||||
whisper_task="transcribe",
|
||||
model_name = "./whisper",
|
||||
dtype = None, # Leave as None for auto detection
|
||||
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
|
||||
auto_model = WhisperForConditionalGeneration,
|
||||
whisper_language = "English",
|
||||
whisper_task = "transcribe",
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
|
|
@ -135,7 +135,7 @@ try:
|
|||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
response = requests.get(audio_url, headers=headers)
|
||||
response = requests.get(audio_url, headers = headers)
|
||||
response.raise_for_status()
|
||||
with open(audio_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
|
@ -156,12 +156,12 @@ model.eval()
|
|||
# Create pipeline without specifying the device
|
||||
whisper = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=tokenizer.tokenizer,
|
||||
feature_extractor=tokenizer.feature_extractor,
|
||||
processor=tokenizer,
|
||||
return_language=True,
|
||||
torch_dtype=torch.float16, # Remove the device parameter
|
||||
model = model,
|
||||
tokenizer = tokenizer.tokenizer,
|
||||
feature_extractor = tokenizer.feature_extractor,
|
||||
processor = tokenizer,
|
||||
return_language = True,
|
||||
torch_dtype = torch.float16, # Remove the device parameter
|
||||
)
|
||||
# Example usage
|
||||
audio_file = "Speech_12dB_s16.flac"
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
|
|||
## Dataset Preparation"""
|
||||
|
||||
print("\n📊 Loading and preparing dataset...")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
|
||||
# To select the first 2000 examples
|
||||
train_dataset = dataset.select(range(2000))
|
||||
|
||||
|
|
@ -81,11 +81,11 @@ print("🤖 Loading base vision model...")
|
|||
try:
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
|
||||
model_name="unsloth/Qwen2-VL-7B-Instruct",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
load_in_4bit=True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning=False, # [NEW!] We have full finetuning now!
|
||||
model_name = "unsloth/Qwen2-VL-7B-Instruct",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load base model: {e}")
|
||||
|
|
@ -96,18 +96,18 @@ print("\n🔧 Setting up LoRA configuration...")
|
|||
try:
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=True, # Turn off for just text!
|
||||
finetune_language_layers=True, # Should leave on!
|
||||
finetune_attention_modules=True, # Attention good for GRPO
|
||||
finetune_mlp_modules=True, # SHould leave on always!
|
||||
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
lora_alpha=32,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
finetune_vision_layers = True, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # SHould leave on always!
|
||||
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
lora_alpha = 32,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
print("✅ LoRA configuration applied successfully!")
|
||||
print(f" 🎯 LoRA rank (r): 16")
|
||||
|
|
@ -128,40 +128,40 @@ FastVisionModel.for_training(model) # Enable for training!
|
|||
|
||||
try:
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset=train_dataset,
|
||||
args=SFTConfig(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset = train_dataset,
|
||||
args = SFTConfig(
|
||||
# per_device_train_batch_size = 4,
|
||||
# gradient_accumulation_steps = 8,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
gradient_checkpointing = True,
|
||||
gradient_checkpointing_kwargs = {
|
||||
"use_reentrant": False
|
||||
}, # use reentrant checkpointing
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03,
|
||||
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bf16_supported(),
|
||||
bf16=is_bf16_supported(),
|
||||
logging_steps=5,
|
||||
save_strategy="epoch",
|
||||
optim="adamw_torch_fused",
|
||||
weight_decay=0.01,
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="checkpoints",
|
||||
report_to="none", # For Weights and Biases
|
||||
max_steps = 10,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bf16_supported(),
|
||||
bf16 = is_bf16_supported(),
|
||||
logging_steps = 5,
|
||||
save_strategy = "epoch",
|
||||
optim = "adamw_torch_fused",
|
||||
weight_decay = 0.01,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "checkpoints",
|
||||
report_to = "none", # For Weights and Biases
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns=False,
|
||||
dataset_text_field="",
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
dataset_num_proc=4,
|
||||
max_seq_length=2048,
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
dataset_num_proc = 4,
|
||||
max_seq_length = 2048,
|
||||
),
|
||||
)
|
||||
print("✅ Trainer setup completed!")
|
||||
|
|
@ -221,7 +221,7 @@ try:
|
|||
print("=== UPLOADING MODEL TO HUB ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
print(f"🚀 Uploading to repository: {repo_name}")
|
||||
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
|
||||
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
|
||||
success["upload"] = True
|
||||
print("✅ Model uploaded successfully!")
|
||||
except Exception as e:
|
||||
|
|
@ -233,8 +233,8 @@ try:
|
|||
print("\n" + "=" * 80)
|
||||
print("=== VERIFYING REPO CONTENTS ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
fs = HfFileSystem(token=hf_token)
|
||||
file_list = fs.ls(repo_name, detail=True)
|
||||
fs = HfFileSystem(token = hf_token)
|
||||
file_list = fs.ls(repo_name, detail = True)
|
||||
safetensors_found = any(
|
||||
file["name"].endswith("model.safetensors.index.json") for file in file_list
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
|
|||
## Dataset Preparation"""
|
||||
|
||||
print("\n📊 Loading and preparing dataset...")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
|
||||
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
|
||||
# To select the first 2000 examples
|
||||
train_dataset = dataset.select(range(2000))
|
||||
|
||||
|
|
@ -82,11 +82,11 @@ print("🤖 Loading base vision model...")
|
|||
try:
|
||||
model, tokenizer = FastVisionModel.from_pretrained(
|
||||
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
|
||||
model_name="unsloth/Qwen2-VL-2B-Instruct",
|
||||
max_seq_length=2048, # Choose any for long context!
|
||||
load_in_4bit=True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning=False, # [NEW!] We have full finetuning now!
|
||||
model_name = "unsloth/Qwen2-VL-2B-Instruct",
|
||||
max_seq_length = 2048, # Choose any for long context!
|
||||
load_in_4bit = True, # 4 bit quantization to reduce memory
|
||||
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
|
||||
full_finetuning = False, # [NEW!] We have full finetuning now!
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load base model: {e}")
|
||||
|
|
@ -97,18 +97,18 @@ print("\n🔧 Setting up LoRA configuration...")
|
|||
try:
|
||||
model = FastVisionModel.get_peft_model(
|
||||
model,
|
||||
finetune_vision_layers=True, # Turn off for just text!
|
||||
finetune_language_layers=True, # Should leave on!
|
||||
finetune_attention_modules=True, # Attention good for GRPO
|
||||
finetune_mlp_modules=True, # SHould leave on always!
|
||||
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
lora_alpha=32,
|
||||
lora_dropout=0, # Supports any, but = 0 is optimized
|
||||
bias="none", # Supports any, but = "none" is optimized
|
||||
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
|
||||
random_state=3407,
|
||||
use_rslora=False, # We support rank stabilized LoRA
|
||||
loftq_config=None, # And LoftQ
|
||||
finetune_vision_layers = True, # Turn off for just text!
|
||||
finetune_language_layers = True, # Should leave on!
|
||||
finetune_attention_modules = True, # Attention good for GRPO
|
||||
finetune_mlp_modules = True, # SHould leave on always!
|
||||
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
|
||||
lora_alpha = 32,
|
||||
lora_dropout = 0, # Supports any, but = 0 is optimized
|
||||
bias = "none", # Supports any, but = "none" is optimized
|
||||
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
|
||||
random_state = 3407,
|
||||
use_rslora = False, # We support rank stabilized LoRA
|
||||
loftq_config = None, # And LoftQ
|
||||
)
|
||||
print("✅ LoRA configuration applied successfully!")
|
||||
print(f" 🎯 LoRA rank (r): 16")
|
||||
|
|
@ -129,40 +129,40 @@ FastVisionModel.for_training(model) # Enable for training!
|
|||
|
||||
try:
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset=train_dataset,
|
||||
args=SFTConfig(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
data_collator = UnslothVisionDataCollator(model, tokenizer),
|
||||
train_dataset = train_dataset,
|
||||
args = SFTConfig(
|
||||
# per_device_train_batch_size = 4,
|
||||
# gradient_accumulation_steps = 8,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={
|
||||
per_device_train_batch_size = 2,
|
||||
gradient_accumulation_steps = 4,
|
||||
gradient_checkpointing = True,
|
||||
gradient_checkpointing_kwargs = {
|
||||
"use_reentrant": False
|
||||
}, # use reentrant checkpointing
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03,
|
||||
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio = 0.03,
|
||||
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
fp16=not is_bf16_supported(),
|
||||
bf16=is_bf16_supported(),
|
||||
logging_steps=5,
|
||||
save_strategy="epoch",
|
||||
optim="adamw_torch_fused",
|
||||
weight_decay=0.01,
|
||||
lr_scheduler_type="linear",
|
||||
seed=3407,
|
||||
output_dir="checkpoints",
|
||||
report_to="none", # For Weights and Biases
|
||||
max_steps = 10,
|
||||
learning_rate = 2e-4,
|
||||
fp16 = not is_bf16_supported(),
|
||||
bf16 = is_bf16_supported(),
|
||||
logging_steps = 5,
|
||||
save_strategy = "epoch",
|
||||
optim = "adamw_torch_fused",
|
||||
weight_decay = 0.01,
|
||||
lr_scheduler_type = "linear",
|
||||
seed = 3407,
|
||||
output_dir = "checkpoints",
|
||||
report_to = "none", # For Weights and Biases
|
||||
# You MUST put the below items for vision finetuning:
|
||||
remove_unused_columns=False,
|
||||
dataset_text_field="",
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
dataset_num_proc=4,
|
||||
max_seq_length=2048,
|
||||
remove_unused_columns = False,
|
||||
dataset_text_field = "",
|
||||
dataset_kwargs = {"skip_prepare_dataset": True},
|
||||
dataset_num_proc = 4,
|
||||
max_seq_length = 2048,
|
||||
),
|
||||
)
|
||||
print("✅ Trainer setup completed!")
|
||||
|
|
@ -221,7 +221,7 @@ try:
|
|||
print("=== UPLOADING MODEL TO HUB ===".center(80))
|
||||
print("=" * 80 + "\n")
|
||||
print(f"🚀 Uploading to repository: {repo_name}")
|
||||
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
|
||||
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
|
||||
success["upload"] = True
|
||||
print("✅ Model uploaded successfully!")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
|
|||
"test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl",
|
||||
}
|
||||
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
os.makedirs(data_dir, exist_ok = True)
|
||||
combined_filepath = os.path.join(data_dir, "aime.jsonl")
|
||||
|
||||
# Check if combined file already exists
|
||||
|
|
@ -67,9 +67,9 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
|
|||
|
||||
# Write combined dataset
|
||||
if all_problems:
|
||||
with open(combined_filepath, "w", encoding="utf-8") as f:
|
||||
with open(combined_filepath, "w", encoding = "utf-8") as f:
|
||||
for problem in all_problems:
|
||||
f.write(json.dumps(problem, ensure_ascii=False) + "\n")
|
||||
f.write(json.dumps(problem, ensure_ascii = False) + "\n")
|
||||
|
||||
print(f"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets")
|
||||
print(f" Saved to: {combined_filepath}")
|
||||
|
|
@ -92,7 +92,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
|
|||
filepath = download_and_combine_aime_datasets(data_dir)
|
||||
|
||||
examples = []
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, "r", encoding = "utf-8") as f:
|
||||
for line_num, line in enumerate(f):
|
||||
line = line.strip()
|
||||
if line:
|
||||
|
|
@ -188,20 +188,20 @@ def get_num_tokens(text, tokenizer_instance):
|
|||
"""Count tokens in text"""
|
||||
if not text:
|
||||
return 0
|
||||
encoding = tokenizer_instance(text, return_tensors="pt")
|
||||
encoding = tokenizer_instance(text, return_tensors = "pt")
|
||||
return len(encoding["input_ids"][0])
|
||||
|
||||
|
||||
def evaluate_model_aime(
|
||||
model,
|
||||
tokenizer,
|
||||
model_type="base",
|
||||
lora_request=None,
|
||||
temperature=0.3,
|
||||
n_sampling=8,
|
||||
max_tokens=32768,
|
||||
top_p=0.95,
|
||||
seed=0,
|
||||
model_type = "base",
|
||||
lora_request = None,
|
||||
temperature = 0.3,
|
||||
n_sampling = 8,
|
||||
max_tokens = 32768,
|
||||
top_p = 0.95,
|
||||
seed = 0,
|
||||
):
|
||||
"""Evaluate model on combined AIME dataset with official configuration"""
|
||||
|
||||
|
|
@ -237,11 +237,11 @@ def evaluate_model_aime(
|
|||
|
||||
# Setup sampling parameters (AIME configuration)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
n=n_sampling, # Multiple samples per question
|
||||
seed=seed,
|
||||
temperature = temperature,
|
||||
top_p = top_p,
|
||||
max_tokens = max_tokens,
|
||||
n = n_sampling, # Multiple samples per question
|
||||
seed = seed,
|
||||
)
|
||||
|
||||
print(f"\n🔧 Configuration:")
|
||||
|
|
@ -272,13 +272,13 @@ def evaluate_model_aime(
|
|||
|
||||
# Main evaluation loop
|
||||
with tqdm(
|
||||
total=len(eval_dataset), desc="Processing AIME problems", unit="problem"
|
||||
total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem"
|
||||
) as pbar:
|
||||
for task_id, item in enumerate(eval_dataset):
|
||||
try:
|
||||
# Prepare prompt
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
item["prompt"], add_generation_prompt=True, tokenize=False
|
||||
item["prompt"], add_generation_prompt = True, tokenize = False
|
||||
)
|
||||
|
||||
input_tokens.append(get_num_tokens(prompt_text, tokenizer))
|
||||
|
|
@ -286,9 +286,9 @@ def evaluate_model_aime(
|
|||
# Generate multiple responses
|
||||
outputs = model.fast_generate(
|
||||
[prompt_text],
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
use_tqdm=False,
|
||||
sampling_params = sampling_params,
|
||||
lora_request = lora_request,
|
||||
use_tqdm = False,
|
||||
)[0].outputs
|
||||
|
||||
# Process all generated responses
|
||||
|
|
@ -413,8 +413,8 @@ def evaluate_model_aime(
|
|||
|
||||
# Save results
|
||||
filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json"
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump({"results": results, "records": records}, f, indent=4)
|
||||
with open(filename, "w", encoding = "utf-8") as f:
|
||||
json.dump({"results": results, "records": records}, f, indent = 4)
|
||||
|
||||
# Print comprehensive summary
|
||||
print(f"\n{'='*70}")
|
||||
|
|
@ -517,27 +517,27 @@ def compare_aime_results(all_results):
|
|||
if all_results and "source_accuracies" in all_results[0]:
|
||||
datasets = list(all_results[0]["source_accuracies"].keys())
|
||||
|
||||
print(f"{'Model':<15}", end="")
|
||||
print(f"{'Model':<15}", end = "")
|
||||
for dataset in datasets:
|
||||
print(f"{dataset:<15}", end="")
|
||||
print(f"{dataset:<15}", end = "")
|
||||
print()
|
||||
print("-" * (15 + 15 * len(datasets)))
|
||||
|
||||
for result in all_results:
|
||||
print(f"{result['model_type']:<15}", end="")
|
||||
print(f"{result['model_type']:<15}", end = "")
|
||||
for dataset in datasets:
|
||||
accuracy = result["source_accuracies"].get(dataset, 0)
|
||||
print(f"{accuracy:<15.1f}", end="")
|
||||
print(f"{accuracy:<15.1f}", end = "")
|
||||
print()
|
||||
|
||||
# Save comparison
|
||||
comparison_data = {
|
||||
"summary": all_results,
|
||||
"best_model": max(all_results, key=lambda x: x["accuracy"]),
|
||||
"best_model": max(all_results, key = lambda x: x["accuracy"]),
|
||||
}
|
||||
|
||||
with open("aime_model_comparison.json", "w") as f:
|
||||
json.dump(comparison_data, f, indent=4)
|
||||
json.dump(comparison_data, f, indent = 4)
|
||||
|
||||
print(
|
||||
f"\nBest performing model: {comparison_data['best_model']['model_type']} "
|
||||
|
|
|
|||
|
|
@ -79,27 +79,27 @@ def generate_responses(
|
|||
skip_special_tokens: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
):
|
||||
inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)]
|
||||
inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)]
|
||||
keys = inputs[0].keys()
|
||||
batched_inputs = {
|
||||
key: torch.cat([input[key] for input in inputs], dim=0).to(model.device)
|
||||
key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)
|
||||
for key in keys
|
||||
}
|
||||
|
||||
if dtype is not None:
|
||||
inference_context = torch.autocast(device_type="cuda", dtype=dtype)
|
||||
inference_context = torch.autocast(device_type = "cuda", dtype = dtype)
|
||||
else:
|
||||
inference_context = nullcontext()
|
||||
|
||||
with inference_context:
|
||||
outputs = model.generate(
|
||||
**batched_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
max_new_tokens = max_new_tokens,
|
||||
do_sample = do_sample,
|
||||
temperature = temperature,
|
||||
)
|
||||
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
|
||||
responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)
|
||||
return responses
|
||||
|
||||
|
||||
|
|
@ -117,11 +117,11 @@ def sample_responses(
|
|||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
temperature=temperature,
|
||||
num_generations=num_generations,
|
||||
max_new_tokens=max_new_tokens,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
dtype=dtype,
|
||||
temperature = temperature,
|
||||
num_generations = num_generations,
|
||||
max_new_tokens = max_new_tokens,
|
||||
skip_special_tokens = skip_special_tokens,
|
||||
dtype = dtype,
|
||||
)
|
||||
return responses
|
||||
|
||||
|
|
@ -136,32 +136,32 @@ def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
|
|||
def setup_model(
|
||||
model_name,
|
||||
quantize: bool = True,
|
||||
dtype=torch.bfloat16,
|
||||
peft_config=None,
|
||||
dtype = torch.bfloat16,
|
||||
peft_config = None,
|
||||
autocast_adapter: bool = True,
|
||||
):
|
||||
if quantize:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=dtype,
|
||||
load_in_4bit = True,
|
||||
bnb_4bit_use_double_quant = True,
|
||||
bnb_4bit_quant_type = "nf4",
|
||||
bnb_4bit_compute_dtype = dtype,
|
||||
)
|
||||
else:
|
||||
bnb_config = None
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_map="cuda:0",
|
||||
attn_implementation="sdpa",
|
||||
quantization_config=bnb_config,
|
||||
torch_dtype=dtype,
|
||||
device_map = "cuda:0",
|
||||
attn_implementation = "sdpa",
|
||||
quantization_config = bnb_config,
|
||||
torch_dtype = dtype,
|
||||
)
|
||||
model = prepare_model_for_kbit_training(model) if quantize else model
|
||||
|
||||
if peft_config is not None:
|
||||
model = get_peft_model(
|
||||
model, peft_config, autocast_adapter_dtype=autocast_adapter
|
||||
model, peft_config, autocast_adapter_dtype = autocast_adapter
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
@ -169,19 +169,19 @@ def setup_model(
|
|||
|
||||
def get_peft_config(
|
||||
lora_rank,
|
||||
lora_alpha=None,
|
||||
lora_dropout=0.0,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
lora_alpha = None,
|
||||
lora_dropout = 0.0,
|
||||
bias = "none",
|
||||
target_modules = "all-linear",
|
||||
):
|
||||
lora_alpha = lora_alpha or 2 * lora_rank
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
r=lora_rank,
|
||||
bias=bias,
|
||||
target_modules=target_modules,
|
||||
task_type="CAUSAL_LM",
|
||||
lora_alpha = lora_alpha,
|
||||
lora_dropout = lora_dropout,
|
||||
r = lora_rank,
|
||||
bias = bias,
|
||||
target_modules = target_modules,
|
||||
task_type = "CAUSAL_LM",
|
||||
)
|
||||
return peft_config
|
||||
|
||||
|
|
@ -191,18 +191,18 @@ def setup_trainer(
|
|||
tokenizer,
|
||||
dataset,
|
||||
train_args,
|
||||
peft_config=None,
|
||||
formatting_func=None,
|
||||
collator=None,
|
||||
peft_config = None,
|
||||
formatting_func = None,
|
||||
collator = None,
|
||||
):
|
||||
return SFTTrainer(
|
||||
model=model,
|
||||
peft_config=peft_config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
formatting_func=formatting_func,
|
||||
data_collator=collator,
|
||||
args=train_args,
|
||||
model = model,
|
||||
peft_config = peft_config,
|
||||
train_dataset = dataset,
|
||||
processing_class = tokenizer,
|
||||
formatting_func = formatting_func,
|
||||
data_collator = collator,
|
||||
args = train_args,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -212,17 +212,17 @@ def setup_lora(
|
|||
dataset,
|
||||
peft_config,
|
||||
train_args,
|
||||
formatting_func=None,
|
||||
collator=None,
|
||||
formatting_func = None,
|
||||
collator = None,
|
||||
):
|
||||
return LoraConfig(
|
||||
model=model,
|
||||
peft_config=peft_config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer,
|
||||
formatting_func=formatting_func,
|
||||
data_collator=collator,
|
||||
args=train_args,
|
||||
model = model,
|
||||
peft_config = peft_config,
|
||||
train_dataset = dataset,
|
||||
processing_class = tokenizer,
|
||||
formatting_func = formatting_func,
|
||||
data_collator = collator,
|
||||
args = train_args,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -236,7 +236,7 @@ def convert_weights_back_to_dtype(model, dtype):
|
|||
param.data = param.data.to(dtype)
|
||||
|
||||
|
||||
def fix_llama3_tokenizer(tokenizer, padding_side="right"):
|
||||
def fix_llama3_tokenizer(tokenizer, padding_side = "right"):
|
||||
tokenizer.padding_side = padding_side
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
pad_token = [w for w in added_vocab if "pad" in w]
|
||||
|
|
@ -276,12 +276,12 @@ def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
|
|||
w_dq = w_dq.to(original_dtype)
|
||||
|
||||
new_module = torch.nn.Linear(
|
||||
w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None
|
||||
w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None
|
||||
)
|
||||
new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False)
|
||||
new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)
|
||||
if module.lora_bias[adapter_name]:
|
||||
bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
|
||||
new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False)
|
||||
new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)
|
||||
return new_module
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ if already_imported:
|
|||
f"to ensure all optimizations are applied. Your code may run slower or encounter "
|
||||
f"memory issues without these optimizations.\n\n"
|
||||
f"Please restructure your imports with 'import unsloth' at the top of your file.",
|
||||
stacklevel=2,
|
||||
stacklevel = 2,
|
||||
)
|
||||
del already_imported, critical_modules
|
||||
|
||||
|
|
@ -141,7 +141,7 @@ if DEVICE_TYPE == "cuda":
|
|||
old_is_bf16_supported = torch.cuda.is_bf16_supported
|
||||
if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
|
||||
|
||||
def is_bf16_supported(including_emulation=False):
|
||||
def is_bf16_supported(including_emulation = False):
|
||||
return old_is_bf16_supported(including_emulation)
|
||||
|
||||
torch.cuda.is_bf16_supported = is_bf16_supported
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
downS,
|
||||
_forward_function,
|
||||
_backward_function,
|
||||
inplace=True,
|
||||
inplace = True,
|
||||
):
|
||||
dtype = X.dtype
|
||||
|
||||
|
|
@ -169,39 +169,39 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
# d_downB = (downA.t() @ h.t()) @ dY
|
||||
# d_downA *= downS
|
||||
# d_downB *= downS
|
||||
d_downA.addmm_(h.t(), dY @ downB.t(), alpha=downS, beta=0)
|
||||
d_downB.addmm_(downA.t() @ h.t(), dY, alpha=downS, beta=0)
|
||||
d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
|
||||
d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)
|
||||
|
||||
# Up projection LoRA weights
|
||||
# d_upA = X.t() @ (df @ upB.t())
|
||||
# d_upB = (upA.t() @ X.t()) @ df
|
||||
# d_upA *= upS
|
||||
# d_upB *= upS
|
||||
d_upA.addmm_(X.t(), df @ upB.t(), alpha=upS, beta=0)
|
||||
d_upB.addmm_(upA.t() @ X.t(), df, alpha=upS, beta=0)
|
||||
d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
|
||||
d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)
|
||||
|
||||
# Gate projection LoRA weights
|
||||
# d_gateA = X.t() @ (de @ gateB.t())
|
||||
# d_gateB = (gateA.t() @ X.t()) @ de
|
||||
# d_gateA *= gateS
|
||||
# d_gateB *= gateS
|
||||
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha=gateS, beta=0)
|
||||
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha=gateS, beta=0)
|
||||
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
|
||||
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)
|
||||
|
||||
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
|
||||
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
|
||||
upW = fast_dequantize(upW.t(), upW_quant)
|
||||
dX = torch.matmul(df, upW.t(), out=X if ctx.inplace else None)
|
||||
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
|
||||
del upW
|
||||
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
||||
dX.addmm_(df @ upB.t(), upA.t(), alpha=upS)
|
||||
dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)
|
||||
|
||||
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
||||
# dX += de @ gateW.t()
|
||||
dX.addmm_(de, gateW.t())
|
||||
del gateW
|
||||
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
||||
dX.addmm_(de @ gateB.t(), gateA.t(), alpha=gateS)
|
||||
dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
|
||||
|
||||
# gateW, gateW_quant, gateA, gateB, gateS,
|
||||
# upW, upW_quant, upA, upB, upS,
|
||||
|
|
@ -232,7 +232,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
||||
|
||||
|
||||
def apply_lora_mlp_swiglu(self, X, inplace=True):
|
||||
def apply_lora_mlp_swiglu(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.gate_proj)
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
|
|
@ -264,7 +264,7 @@ def apply_lora_mlp_swiglu(self, X, inplace=True):
|
|||
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
||||
|
||||
|
||||
def apply_lora_mlp_geglu_exact(self, X, inplace=True):
|
||||
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.gate_proj)
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
|
|
@ -375,7 +375,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
VA,
|
||||
VB,
|
||||
VS,
|
||||
inplace=True,
|
||||
inplace = True,
|
||||
):
|
||||
dtype = X.dtype
|
||||
|
||||
|
|
@ -452,32 +452,32 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
# d_QB = (QA.t() @ X.t()) @ dQ
|
||||
# d_QA *= QS
|
||||
# d_QB *= QS
|
||||
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha=QS, beta=0)
|
||||
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha=QS, beta=0)
|
||||
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
|
||||
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)
|
||||
|
||||
# K Projection
|
||||
# d_KA = X.t() @ (dK @ KB.t())
|
||||
# d_KB = (KA.t() @ X.t()) @ dK
|
||||
# d_KA *= KS
|
||||
# d_KB *= KS
|
||||
d_KA.addmm_(X.t(), dK @ KB.t(), alpha=KS, beta=0)
|
||||
d_KB.addmm_(KA.t() @ X.t(), dK, alpha=KS, beta=0)
|
||||
d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
|
||||
d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)
|
||||
|
||||
# V Projection
|
||||
# d_VA = X.t() @ (dV @ VB.t())
|
||||
# d_VB = (VA.t() @ X.t()) @ dV
|
||||
# d_VA *= VS
|
||||
# d_VB *= VS
|
||||
d_VA.addmm_(X.t(), dV @ VB.t(), alpha=VS, beta=0)
|
||||
d_VB.addmm_(VA.t() @ X.t(), dV, alpha=VS, beta=0)
|
||||
d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
|
||||
d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)
|
||||
|
||||
# Combine derivatives to find dX
|
||||
# dQ
|
||||
QW = fast_dequantize(QW.t(), QW_quant)
|
||||
dX = torch.matmul(dQ, QW.t(), out=X if ctx.inplace else None)
|
||||
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
|
||||
del QW
|
||||
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
|
||||
dX.addmm_(dQ @ QB.t(), QA.t(), alpha=QS)
|
||||
dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)
|
||||
|
||||
# dK
|
||||
KW = fast_dequantize(KW.t(), KW_quant)
|
||||
|
|
@ -485,7 +485,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
dX.addmm_(dK, KW.t())
|
||||
del KW
|
||||
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
|
||||
dX.addmm_(dK @ KB.t(), KA.t(), alpha=KS)
|
||||
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)
|
||||
|
||||
# dV
|
||||
VW = fast_dequantize(VW.t(), VW_quant)
|
||||
|
|
@ -493,7 +493,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
dX.addmm_(dV, VW.t())
|
||||
del VW
|
||||
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
|
||||
dX.addmm_(dV @ VB.t(), VA.t(), alpha=VS)
|
||||
dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
|
||||
|
||||
# QW, QW_quant, QA, QB, QS,
|
||||
# KW, KW_quant, KA, KB, KS,
|
||||
|
|
@ -519,7 +519,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
)
|
||||
|
||||
|
||||
def apply_lora_qkv(self, X, inplace=True):
|
||||
def apply_lora_qkv(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.q_proj)
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
|
|
@ -611,15 +611,15 @@ class LoRA_W(torch.autograd.Function):
|
|||
# d_B = (A.t() @ X.t()) @ dY
|
||||
# d_A *= S
|
||||
# d_B *= S
|
||||
d_A.addmm_(X.t(), dY @ B.t(), alpha=S, beta=0)
|
||||
d_B.addmm_(A.t() @ X.t(), dY, alpha=S, beta=0)
|
||||
d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
|
||||
d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)
|
||||
|
||||
# Get derivative for dX
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
|
||||
dX.addmm_(dY @ B.t(), A.t(), alpha=S)
|
||||
dX.addmm_(dY @ B.t(), A.t(), alpha = S)
|
||||
|
||||
# W, W_quant, A, B, S
|
||||
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
|
||||
|
|
@ -649,7 +649,7 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
|||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif adapter_names is not None:
|
||||
result = self._mixed_batch_forward(
|
||||
x, *args, adapter_names=adapter_names, **kwargs
|
||||
x, *args, adapter_names = adapter_names, **kwargs
|
||||
)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
|
@ -705,11 +705,11 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
|||
|
||||
result = result + self.lora_magnitude_vector[active_adapter](
|
||||
x,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
base_layer=self.get_base_layer(),
|
||||
base_result=base_result,
|
||||
lora_A = lora_A,
|
||||
lora_B = lora_B,
|
||||
scaling = scaling,
|
||||
base_layer = self.get_base_layer(),
|
||||
base_result = base_result,
|
||||
)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def run_benchmark_forward(
|
|||
hidden_size = config.hidden_size
|
||||
|
||||
X = torch.randn(
|
||||
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
|
||||
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
|
||||
)
|
||||
|
||||
# Forward
|
||||
|
|
@ -89,7 +89,7 @@ def run_benchmark_backward(
|
|||
config: AutoConfig,
|
||||
seqlen: int,
|
||||
dtype: torch.dtype,
|
||||
bs=1,
|
||||
bs = 1,
|
||||
):
|
||||
torch.manual_seed(
|
||||
SEED
|
||||
|
|
@ -98,7 +98,7 @@ def run_benchmark_backward(
|
|||
hidden_size = config.hidden_size
|
||||
|
||||
X = torch.randn(
|
||||
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
|
||||
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
|
||||
)
|
||||
X_test = X.detach().clone().requires_grad_(True)
|
||||
|
||||
|
|
@ -114,14 +114,14 @@ def run_benchmark_backward(
|
|||
|
||||
# Bench
|
||||
grad_output = torch.randn_like(output)
|
||||
bench_backward_ref = lambda: output.backward(grad_output, retain_graph=True) # noqa: E731
|
||||
bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph=True) # noqa: E731
|
||||
bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True) # noqa: E731
|
||||
bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True) # noqa: E731
|
||||
|
||||
ref_backward_time = do_bench(
|
||||
bench_backward_ref, grad_to_none=[X, *ref_model.parameters()]
|
||||
bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]
|
||||
)
|
||||
fused_backward_time = do_bench(
|
||||
bench_backward_fused, grad_to_none=[X_test, *tt_model.parameters()]
|
||||
bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]
|
||||
)
|
||||
print(
|
||||
f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x"
|
||||
|
|
@ -138,10 +138,10 @@ def setup_model(
|
|||
kernel_config_fwd,
|
||||
kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX,
|
||||
dX_only=False,
|
||||
dW_only=False,
|
||||
overlap_router_shared=False,
|
||||
device="cuda",
|
||||
dX_only = False,
|
||||
dW_only = False,
|
||||
overlap_router_shared = False,
|
||||
device = "cuda",
|
||||
):
|
||||
if isinstance(config, Qwen3MoeConfig):
|
||||
ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
|
||||
|
|
@ -149,29 +149,29 @@ def setup_model(
|
|||
# Triton kernel grouped gemm version of MoE Block -- this is what we're testing
|
||||
tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
|
||||
ref_model,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
dX_only=dX_only,
|
||||
dW_only=dW_only,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
dX_only = dX_only,
|
||||
dW_only = dW_only,
|
||||
).to(device, dtype)
|
||||
|
||||
elif isinstance(config, Llama4TextConfig):
|
||||
ref_model = Llama4TextMoe(config).to(device, dtype)
|
||||
tt_model = Llama4TritonTextMoe(
|
||||
config,
|
||||
overlap_router_shared=overlap_router_shared,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
dX_only=dX_only,
|
||||
dW_only=dW_only,
|
||||
overlap_router_shared = overlap_router_shared,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
dX_only = dX_only,
|
||||
dW_only = dW_only,
|
||||
).to(device, dtype)
|
||||
|
||||
else:
|
||||
|
|
@ -205,31 +205,31 @@ def run_benchmark(
|
|||
|
||||
ref_model, tt_model = setup_model(
|
||||
model_config,
|
||||
dtype=dtype,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
dX_only=dX_only,
|
||||
dW_only=dW_only,
|
||||
overlap_router_shared=overlap_router_shared,
|
||||
dtype = dtype,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
dX_only = dX_only,
|
||||
dW_only = dW_only,
|
||||
overlap_router_shared = overlap_router_shared,
|
||||
)
|
||||
|
||||
if mode == "forward":
|
||||
ref_time, fused_time = run_benchmark_forward(
|
||||
ref_model,
|
||||
tt_model,
|
||||
config=model_config,
|
||||
seqlen=seqlen,
|
||||
dtype=dtype,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
config = model_config,
|
||||
seqlen = seqlen,
|
||||
dtype = dtype,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
)
|
||||
else:
|
||||
ref_time, fused_time = run_benchmark_backward(
|
||||
ref_model, tt_model, config=model_config, seqlen=seqlen, dtype=dtype
|
||||
ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype
|
||||
)
|
||||
|
||||
if autotune:
|
||||
|
|
@ -251,60 +251,60 @@ def run_benchmark(
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--results_dir", type=str, default="benchmark_results")
|
||||
parser.add_argument("--model", type=str, choices=["llama4", "qwen3"], required=True)
|
||||
parser.add_argument("--seqlen", type=int, default=1024)
|
||||
parser.add_argument("--results_dir", type = str, default = "benchmark_results")
|
||||
parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True)
|
||||
parser.add_argument("--seqlen", type = int, default = 1024)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
|
||||
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
|
||||
)
|
||||
parser.add_argument("--permute_x", action="store_true")
|
||||
parser.add_argument("--permute_y", action="store_true")
|
||||
parser.add_argument("--autotune", action="store_true")
|
||||
parser.add_argument("--overlap_router_shared", action="store_true")
|
||||
parser.add_argument("--permute_x", action = "store_true")
|
||||
parser.add_argument("--permute_y", action = "store_true")
|
||||
parser.add_argument("--autotune", action = "store_true")
|
||||
parser.add_argument("--overlap_router_shared", action = "store_true")
|
||||
parser.add_argument(
|
||||
"--BLOCK_SIZE_M",
|
||||
nargs=2,
|
||||
type=int,
|
||||
default=[DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
|
||||
nargs = 2,
|
||||
type = int,
|
||||
default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--BLOCK_SIZE_N",
|
||||
nargs=2,
|
||||
type=int,
|
||||
default=[DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
|
||||
nargs = 2,
|
||||
type = int,
|
||||
default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--BLOCK_SIZE_K",
|
||||
nargs=2,
|
||||
type=int,
|
||||
default=[DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
|
||||
nargs = 2,
|
||||
type = int,
|
||||
default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warps",
|
||||
nargs=2,
|
||||
type=int,
|
||||
default=[DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
|
||||
nargs = 2,
|
||||
type = int,
|
||||
default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_stages",
|
||||
nargs=2,
|
||||
type=int,
|
||||
default=[DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
|
||||
nargs = 2,
|
||||
type = int,
|
||||
default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_tma_load_w", action="store_true"
|
||||
"--use_tma_load_w", action = "store_true"
|
||||
) # No need to specify, will automatically parametrize these for each kernel config
|
||||
parser.add_argument(
|
||||
"--use_tma_load_x", action="store_true"
|
||||
"--use_tma_load_x", action = "store_true"
|
||||
) # No need to specify, will automatically parametrize these for each kernel config
|
||||
parser.add_argument(
|
||||
"--use_tma_load_dy", action="store_true"
|
||||
"--use_tma_load_dy", action = "store_true"
|
||||
) # No need to specify, will automatically parametrize these for each kernel config
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["forward", "backward", "dW", "dX"],
|
||||
default="forward",
|
||||
type = str,
|
||||
choices = ["forward", "backward", "dW", "dX"],
|
||||
default = "forward",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
|
|
@ -324,13 +324,13 @@ if __name__ == "__main__":
|
|||
ref_time, fused_time = run_benchmark(
|
||||
args.mode,
|
||||
model_config,
|
||||
seqlen=args.seqlen,
|
||||
dtype=args.dtype,
|
||||
permute_x=args.permute_x,
|
||||
permute_y=args.permute_y,
|
||||
autotune=args.autotune,
|
||||
overlap_router_shared=args.overlap_router_shared,
|
||||
results_dir=args.results_dir,
|
||||
seqlen = args.seqlen,
|
||||
dtype = args.dtype,
|
||||
permute_x = args.permute_x,
|
||||
permute_y = args.permute_y,
|
||||
autotune = args.autotune,
|
||||
overlap_router_shared = args.overlap_router_shared,
|
||||
results_dir = args.results_dir,
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Total time: {end_time - start_time:.4f} seconds")
|
||||
|
|
@ -343,13 +343,13 @@ if __name__ == "__main__":
|
|||
kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)
|
||||
print(f"Running {len(kernel_configs)} kernel configs")
|
||||
default_kernel_config_fwd = KernelConfigForward(
|
||||
permute_x=args.permute_x, permute_y=args.permute_y
|
||||
permute_x = args.permute_x, permute_y = args.permute_y
|
||||
)
|
||||
default_kernel_config_bwd_dW = KernelConfigBackward_dW(
|
||||
permute_x=args.permute_x, permute_y=args.permute_y
|
||||
permute_x = args.permute_x, permute_y = args.permute_y
|
||||
)
|
||||
default_kernel_config_bwd_dX = KernelConfigBackward_dX(
|
||||
permute_x=args.permute_x, permute_y=args.permute_y
|
||||
permute_x = args.permute_x, permute_y = args.permute_y
|
||||
)
|
||||
results = []
|
||||
for kernel_config in kernel_configs:
|
||||
|
|
@ -374,21 +374,21 @@ if __name__ == "__main__":
|
|||
ref_time, fused_time = run_benchmark(
|
||||
args.mode,
|
||||
model_config,
|
||||
seqlen=args.seqlen,
|
||||
dtype=args.dtype,
|
||||
permute_x=kernel_config.permute_x,
|
||||
permute_y=kernel_config.permute_y,
|
||||
autotune=False,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
seqlen = args.seqlen,
|
||||
dtype = args.dtype,
|
||||
permute_x = kernel_config.permute_x,
|
||||
permute_y = kernel_config.permute_y,
|
||||
autotune = False,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
)
|
||||
results.append(
|
||||
KernelResult(
|
||||
torch_time=ref_time,
|
||||
triton_time=fused_time,
|
||||
speedup=ref_time / fused_time,
|
||||
kernel_config=kernel_config,
|
||||
torch_time = ref_time,
|
||||
triton_time = fused_time,
|
||||
speedup = ref_time / fused_time,
|
||||
kernel_config = kernel_config,
|
||||
)
|
||||
)
|
||||
df = post_process_results(
|
||||
|
|
|
|||
|
|
@ -37,15 +37,15 @@ def convert_args_to_list(args):
|
|||
|
||||
|
||||
def get_forward_configs(
|
||||
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
|
||||
TMA_LOAD_X=True,
|
||||
TMA_LOAD_W=True,
|
||||
TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
|
||||
num_warps=DEFAULT_NUM_WARPS,
|
||||
num_stages=DEFAULT_NUM_STAGES,
|
||||
num_ctas=DEFAULT_NUM_CTAS,
|
||||
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
|
||||
TMA_LOAD_X = True,
|
||||
TMA_LOAD_W = True,
|
||||
TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
|
||||
num_warps = DEFAULT_NUM_WARPS,
|
||||
num_stages = DEFAULT_NUM_STAGES,
|
||||
num_ctas = DEFAULT_NUM_CTAS,
|
||||
):
|
||||
(
|
||||
BLOCK_M,
|
||||
|
|
@ -95,16 +95,16 @@ def get_forward_configs(
|
|||
kernel_configs.append(
|
||||
triton.Config(
|
||||
dict(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
USE_TMA_LOAD_X=tma_load_x,
|
||||
USE_TMA_LOAD_W=tma_load_w,
|
||||
USE_TMA_STORE=tma_store,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
USE_TMA_LOAD_X = tma_load_x,
|
||||
USE_TMA_LOAD_W = tma_load_w,
|
||||
USE_TMA_STORE = tma_store,
|
||||
),
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
num_ctas=num_ctas,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
num_ctas = num_ctas,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -112,15 +112,15 @@ def get_forward_configs(
|
|||
|
||||
|
||||
def get_dX_kernel_configs(
|
||||
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
|
||||
TMA_LOAD_dY=True,
|
||||
TMA_LOAD_W=True,
|
||||
TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
|
||||
num_warps=DEFAULT_NUM_WARPS,
|
||||
num_stages=DEFAULT_NUM_STAGES,
|
||||
num_ctas=DEFAULT_NUM_CTAS,
|
||||
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
|
||||
TMA_LOAD_dY = True,
|
||||
TMA_LOAD_W = True,
|
||||
TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
|
||||
num_warps = DEFAULT_NUM_WARPS,
|
||||
num_stages = DEFAULT_NUM_STAGES,
|
||||
num_ctas = DEFAULT_NUM_CTAS,
|
||||
):
|
||||
(
|
||||
BLOCK_M,
|
||||
|
|
@ -170,16 +170,16 @@ def get_dX_kernel_configs(
|
|||
kernel_configs.append(
|
||||
triton.Config(
|
||||
dict(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
USE_TMA_LOAD_dY=tma_load_dy,
|
||||
USE_TMA_LOAD_W=tma_load_w,
|
||||
USE_TMA_STORE=tma_store,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
USE_TMA_LOAD_dY = tma_load_dy,
|
||||
USE_TMA_LOAD_W = tma_load_w,
|
||||
USE_TMA_STORE = tma_store,
|
||||
),
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
num_ctas=num_ctas,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
num_ctas = num_ctas,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -187,15 +187,15 @@ def get_dX_kernel_configs(
|
|||
|
||||
|
||||
def get_dW_kernel_configs(
|
||||
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
|
||||
num_warps=DEFAULT_NUM_WARPS,
|
||||
num_stages=DEFAULT_NUM_STAGES,
|
||||
num_ctas=DEFAULT_NUM_CTAS,
|
||||
TMA_LOAD_dY=True,
|
||||
TMA_LOAD_X=True,
|
||||
TMA_STORE=False,
|
||||
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
|
||||
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
|
||||
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
|
||||
num_warps = DEFAULT_NUM_WARPS,
|
||||
num_stages = DEFAULT_NUM_STAGES,
|
||||
num_ctas = DEFAULT_NUM_CTAS,
|
||||
TMA_LOAD_dY = True,
|
||||
TMA_LOAD_X = True,
|
||||
TMA_STORE = False,
|
||||
):
|
||||
(
|
||||
BLOCK_M,
|
||||
|
|
@ -245,16 +245,16 @@ def get_dW_kernel_configs(
|
|||
kernel_configs.append(
|
||||
triton.Config(
|
||||
dict(
|
||||
BLOCK_SIZE_M=block_m,
|
||||
BLOCK_SIZE_N=block_n,
|
||||
BLOCK_SIZE_K=block_k,
|
||||
USE_TMA_LOAD_dY=tma_load_dy,
|
||||
USE_TMA_LOAD_X=tma_load_x,
|
||||
USE_TMA_STORE=tma_store,
|
||||
BLOCK_SIZE_M = block_m,
|
||||
BLOCK_SIZE_N = block_n,
|
||||
BLOCK_SIZE_K = block_k,
|
||||
USE_TMA_LOAD_dY = tma_load_dy,
|
||||
USE_TMA_LOAD_X = tma_load_x,
|
||||
USE_TMA_STORE = tma_store,
|
||||
),
|
||||
num_warps=w,
|
||||
num_stages=s,
|
||||
num_ctas=num_ctas,
|
||||
num_warps = w,
|
||||
num_stages = s,
|
||||
num_ctas = num_ctas,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -84,18 +84,18 @@ def _grouped_gemm_dX_kernel(
|
|||
if USE_TMA_LOAD_dY:
|
||||
dY_desc = tl._experimental_make_tensor_descriptor(
|
||||
dY_ptr,
|
||||
shape=[TOTAL_TOKENS, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
shape = [TOTAL_TOKENS, N],
|
||||
strides = [N, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
if USE_TMA_LOAD_W:
|
||||
expert_stride = N * K
|
||||
w_desc = tl._experimental_make_tensor_descriptor(
|
||||
w_ptr,
|
||||
shape=[NUM_EXPERTS, N, K],
|
||||
strides=[expert_stride, K, 1],
|
||||
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
shape = [NUM_EXPERTS, N, K],
|
||||
strides = [expert_stride, K, 1],
|
||||
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
m_end = 0
|
||||
|
|
@ -104,7 +104,7 @@ def _grouped_gemm_dX_kernel(
|
|||
n_block_range = tl.arange(0, BLOCK_SIZE_N)
|
||||
k_block_range = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
for expert_idx in range(NUM_EXPERTS, flatten=FLATTEN):
|
||||
for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):
|
||||
m_start = m_end
|
||||
m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
|
||||
m_end = m_start + m_size
|
||||
|
|
@ -125,9 +125,9 @@ def _grouped_gemm_dX_kernel(
|
|||
)
|
||||
dX_desc = tl._experimental_make_tensor_descriptor(
|
||||
dX_ptr,
|
||||
shape=[m_end, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
shape = [m_end, K],
|
||||
strides = [K, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
# Lower bound and upper bound are defined relative to the total tiles processed so far
|
||||
|
|
@ -152,7 +152,7 @@ def _grouped_gemm_dX_kernel(
|
|||
)
|
||||
expert_token_idx = tl.load(
|
||||
gather_indices_ptr + indices_to_gather,
|
||||
mask=indices_to_gather < TOTAL_TOKENS,
|
||||
mask = indices_to_gather < TOTAL_TOKENS,
|
||||
)
|
||||
expert_token_offsets = expert_token_idx[:, None]
|
||||
|
||||
|
|
@ -210,13 +210,13 @@ def _grouped_gemm_dX_kernel(
|
|||
# col_mask = offs_bk[None, :] < K
|
||||
store_mask = row_mask # & col_mask
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)
|
||||
|
||||
# GEMM main loop
|
||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||
# dY block [M, N]
|
||||
if not USE_TMA_LOAD_dY:
|
||||
dY = tl.load(dY_ptrs, mask=row_mask)
|
||||
dY = tl.load(dY_ptrs, mask = row_mask)
|
||||
else:
|
||||
dY = dY_desc.load(
|
||||
[m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]
|
||||
|
|
@ -253,7 +253,7 @@ def _grouped_gemm_dX_kernel(
|
|||
tl.store(
|
||||
dX_ptr + store_idx + offs_bk[None, :],
|
||||
dX,
|
||||
mask=store_mask,
|
||||
mask = store_mask,
|
||||
)
|
||||
|
||||
# Move to the next tile within this expert group
|
||||
|
|
@ -264,9 +264,9 @@ def _grouped_gemm_dX_kernel(
|
|||
|
||||
|
||||
_autotuned_grouped_gemm_dX_kernel = triton.autotune(
|
||||
configs=get_dX_kernel_configs(),
|
||||
prune_configs_by={"early_config_prune": prune_dX_configs},
|
||||
key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
|
||||
configs = get_dX_kernel_configs(),
|
||||
prune_configs_by = {"early_config_prune": prune_dX_configs},
|
||||
key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
|
||||
)(_grouped_gemm_dX_kernel)
|
||||
|
||||
"""
|
||||
|
|
@ -324,17 +324,17 @@ def _grouped_gemm_dW_kernel(
|
|||
if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:
|
||||
dY_desc = tl._experimental_make_tensor_descriptor(
|
||||
dY_ptr,
|
||||
shape=[TOTAL_TOKENS, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
shape = [TOTAL_TOKENS, N],
|
||||
strides = [N, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:
|
||||
x_desc = tl._experimental_make_tensor_descriptor(
|
||||
x_ptr,
|
||||
shape=[TOTAL_TOKENS, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
shape = [TOTAL_TOKENS, K],
|
||||
strides = [K, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
# Output tiles per expert, since each expert weight matrix is [N, K]
|
||||
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
|
@ -351,9 +351,9 @@ def _grouped_gemm_dW_kernel(
|
|||
tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
|
||||
dW_desc = tl._experimental_make_tensor_descriptor(
|
||||
dW_ptr,
|
||||
shape=[NUM_EXPERTS, N, K],
|
||||
strides=[N * K, K, 1],
|
||||
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
shape = [NUM_EXPERTS, N, K],
|
||||
strides = [N * K, K, 1],
|
||||
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
for tile_idx in range(
|
||||
|
|
@ -377,7 +377,7 @@ def _grouped_gemm_dW_kernel(
|
|||
m_end = 0
|
||||
for expert_idx in range(NUM_EXPERTS):
|
||||
# We need to instantiate a fresh accumulator for each expert
|
||||
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=acc_dtype)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)
|
||||
|
||||
m_start = m_end
|
||||
# Need to figure out why this cast is needed, otherwise compiler complains about mismatching types
|
||||
|
|
@ -392,16 +392,16 @@ def _grouped_gemm_dW_kernel(
|
|||
if TMA_LOAD_BOTH:
|
||||
dY_desc = tl._experimental_make_tensor_descriptor(
|
||||
dY_ptr,
|
||||
shape=[m_end, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
shape = [m_end, N],
|
||||
strides = [N, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
x_desc = tl._experimental_make_tensor_descriptor(
|
||||
x_ptr,
|
||||
shape=[m_end, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
shape = [m_end, K],
|
||||
strides = [K, 1],
|
||||
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):
|
||||
|
|
@ -425,7 +425,7 @@ def _grouped_gemm_dW_kernel(
|
|||
# indices_to_gather = m_start + gather_offsets
|
||||
expert_token_idx = tl.load(
|
||||
gather_indices_ptr + indices_to_gather,
|
||||
mask=indices_to_gather < TOTAL_TOKENS,
|
||||
mask = indices_to_gather < TOTAL_TOKENS,
|
||||
)
|
||||
expert_token_offsets = expert_token_idx[:, None]
|
||||
|
||||
|
|
@ -461,7 +461,7 @@ def _grouped_gemm_dW_kernel(
|
|||
x_ptr
|
||||
+ x_row_load_idx
|
||||
+ (k_offset + block_range_k)[None, :],
|
||||
mask=mk_mask,
|
||||
mask = mk_mask,
|
||||
)
|
||||
|
||||
if USE_TMA_LOAD_dY:
|
||||
|
|
@ -471,7 +471,7 @@ def _grouped_gemm_dW_kernel(
|
|||
dY_ptr
|
||||
+ dY_row_load_idx
|
||||
+ (n_offset + block_range_n)[None, :],
|
||||
mask=mn_mask,
|
||||
mask = mn_mask,
|
||||
)
|
||||
|
||||
accumulator += tl.dot(
|
||||
|
|
@ -491,12 +491,12 @@ def _grouped_gemm_dW_kernel(
|
|||
+ store_row_offs[:, None] * K
|
||||
+ (k_offset + block_range_k)[None, :],
|
||||
y,
|
||||
mask=nk_mask,
|
||||
mask = nk_mask,
|
||||
)
|
||||
|
||||
|
||||
_autotuned_grouped_gemm_dW_kernel = triton.autotune(
|
||||
configs=get_dW_kernel_configs(),
|
||||
prune_configs_by={"early_config_prune": prune_kernel_configs_backward_dW},
|
||||
key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
|
||||
configs = get_dW_kernel_configs(),
|
||||
prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW},
|
||||
key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
|
||||
)(_grouped_gemm_dW_kernel)
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
def __init__(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
overlap_router_shared=False,
|
||||
verbose=False,
|
||||
debug=False,
|
||||
overlap_router_shared = False,
|
||||
verbose = False,
|
||||
debug = False,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.overlap_router_shared = overlap_router_shared
|
||||
|
|
@ -136,7 +136,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
hidden_states = hidden_states.view(-1, self.hidden_dim)
|
||||
router_logits = self.router(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
router_logits, self.top_k, dim=-1
|
||||
router_logits, self.top_k, dim = -1
|
||||
)
|
||||
|
||||
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
|
||||
|
|
@ -195,7 +195,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
)
|
||||
|
||||
if self.top_k > 1:
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
|
||||
|
|
@ -212,7 +212,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
|
||||
# Start expert computation
|
||||
first_gemm = torch_grouped_gemm(
|
||||
X=hidden_states, W=self.experts.gate_up_proj, m_sizes=token_counts_by_expert
|
||||
X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert
|
||||
)
|
||||
assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)
|
||||
|
||||
|
|
@ -221,7 +221,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
|
||||
# See comment above
|
||||
second_gemm = torch_grouped_gemm(
|
||||
X=intermediate, W=self.experts.down_proj, m_sizes=token_counts_by_expert
|
||||
X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert
|
||||
)
|
||||
assert second_gemm.shape == (total_tokens, hidden_dim)
|
||||
|
||||
|
|
@ -234,17 +234,17 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
|
|||
|
||||
result = (
|
||||
Llama4MoeResult(
|
||||
token_counts_by_expert=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk_weights=routing_weights,
|
||||
hidden_states_after_weight_merge=hidden_states_after_weight_merge,
|
||||
first_gemm=first_gemm,
|
||||
intermediate=intermediate,
|
||||
second_gemm=second_gemm,
|
||||
hidden_states_unpermute=hidden_states_unpermute,
|
||||
shared_expert_out=shared_expert_out,
|
||||
final_out=final_out,
|
||||
router_logits=router_logits,
|
||||
token_counts_by_expert = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk_weights = routing_weights,
|
||||
hidden_states_after_weight_merge = hidden_states_after_weight_merge,
|
||||
first_gemm = first_gemm,
|
||||
intermediate = intermediate,
|
||||
second_gemm = second_gemm,
|
||||
hidden_states_unpermute = hidden_states_unpermute,
|
||||
shared_expert_out = shared_expert_out,
|
||||
final_out = final_out,
|
||||
router_logits = router_logits,
|
||||
)
|
||||
if self.debug
|
||||
else (final_out, routing_weights)
|
||||
|
|
@ -257,7 +257,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
|
|||
def __init__(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
overlap_router_shared=False,
|
||||
overlap_router_shared = False,
|
||||
permute_x: bool = False,
|
||||
permute_y: bool = True,
|
||||
autotune: bool = True,
|
||||
|
|
@ -266,9 +266,9 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
|
|||
kernel_config_bwd_dX: KernelConfigBackward_dX = None,
|
||||
dW_only: bool = False,
|
||||
dX_only: bool = False,
|
||||
verbose=False,
|
||||
verbose = False,
|
||||
):
|
||||
super().__init__(config, overlap_router_shared=overlap_router_shared)
|
||||
super().__init__(config, overlap_router_shared = overlap_router_shared)
|
||||
assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
|
||||
self.permute_x = permute_x
|
||||
self.permute_y = permute_y
|
||||
|
|
@ -321,7 +321,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
|
|||
hidden_states = hidden_states.view(-1, self.hidden_dim)
|
||||
router_logits = self.router(hidden_states)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
router_logits, self.top_k, dim=-1
|
||||
router_logits, self.top_k, dim = -1
|
||||
)
|
||||
|
||||
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
|
||||
|
|
@ -380,7 +380,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
|
|||
)
|
||||
|
||||
if self.top_k > 1:
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
|
||||
|
|
@ -395,37 +395,37 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
|
|||
|
||||
# Start expert computation
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.experts.gate_up_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=self.permute_x,
|
||||
permute_y=False, # output of first grouped gemm should never be permuted
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=True,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.experts.gate_up_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = self.permute_x,
|
||||
permute_y = False, # output of first grouped gemm should never be permuted
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = True,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
hidden_states = self.act_and_mul(hidden_states)
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.experts.down_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=False,
|
||||
permute_y=self.permute_y,
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=False,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.experts.down_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = False,
|
||||
permute_y = self.permute_y,
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = False,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
|
||||
# Post-processing
|
||||
|
|
|
|||
|
|
@ -80,14 +80,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
gate,
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
dW_only=dW_only,
|
||||
dX_only=dX_only,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
dW_only = dW_only,
|
||||
dX_only = dX_only,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -112,37 +112,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
hidden_states = permute(hidden_states, gather_indices, self.top_k)
|
||||
# Start expert computation
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.gate_up_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=self.permute_x,
|
||||
permute_y=False, # output of first grouped gemm should never be permuted
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=True,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.gate_up_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = self.permute_x,
|
||||
permute_y = False, # output of first grouped gemm should never be permuted
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = True,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
hidden_states = self.act_and_mul(hidden_states)
|
||||
hidden_states = grouped_gemm(
|
||||
X=hidden_states,
|
||||
W=self.down_proj,
|
||||
m_sizes=token_counts_by_expert,
|
||||
gather_indices=gather_indices,
|
||||
topk=self.top_k,
|
||||
permute_x=False,
|
||||
permute_y=self.permute_y,
|
||||
autotune=self.autotune,
|
||||
kernel_config_fwd=self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
|
||||
is_first_gemm=False,
|
||||
dW_only=self.dW_only,
|
||||
dX_only=self.dX_only,
|
||||
X = hidden_states,
|
||||
W = self.down_proj,
|
||||
m_sizes = token_counts_by_expert,
|
||||
gather_indices = gather_indices,
|
||||
topk = self.top_k,
|
||||
permute_x = False,
|
||||
permute_y = self.permute_y,
|
||||
autotune = self.autotune,
|
||||
kernel_config_fwd = self.kernel_config_fwd,
|
||||
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
|
||||
is_first_gemm = False,
|
||||
dW_only = self.dW_only,
|
||||
dX_only = self.dX_only,
|
||||
)
|
||||
|
||||
# Post-processing
|
||||
|
|
@ -155,7 +155,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
|
|||
hidden_states.view(num_tokens, self.top_k, hidden_dim)
|
||||
* routing_weights[..., None]
|
||||
)
|
||||
hidden_states = hidden_states.sum(dim=1)
|
||||
hidden_states = hidden_states.sum(dim = 1)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
|
||||
return hidden_states, router_logits
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def calculate_topk(
|
|||
gating_output.dtype
|
||||
)
|
||||
else:
|
||||
scores = F.softmax(gating_output.to(torch.float32), dim=1).to(
|
||||
scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(
|
||||
gating_output.dtype
|
||||
)
|
||||
|
||||
|
|
@ -67,13 +67,13 @@ def calculate_topk(
|
|||
else:
|
||||
scores = gating_output
|
||||
|
||||
topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1)
|
||||
topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)
|
||||
|
||||
if post_act:
|
||||
topk_weights = _activation(topk_weights)
|
||||
|
||||
if renormalize:
|
||||
topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to(
|
||||
topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(
|
||||
gating_output.dtype
|
||||
)
|
||||
|
||||
|
|
@ -94,12 +94,12 @@ def get_routing_indices(
|
|||
# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
|
||||
token_counts_by_expert = torch.histc(
|
||||
selected_experts.view(-1),
|
||||
bins=num_experts,
|
||||
min=0,
|
||||
max=num_experts,
|
||||
bins = num_experts,
|
||||
min = 0,
|
||||
max = num_experts,
|
||||
)
|
||||
# token_indices_experts_sorted shape (bs*slen*top_k,)
|
||||
gather_indices = torch.argsort(selected_experts.view(-1), stable=True)
|
||||
gather_indices = torch.argsort(selected_experts.view(-1), stable = True)
|
||||
if return_scatter_indices:
|
||||
scatter_indices = gather_indices.argsort()
|
||||
return token_counts_by_expert, gather_indices, scatter_indices
|
||||
|
|
@ -107,7 +107,7 @@ def get_routing_indices(
|
|||
return token_counts_by_expert, gather_indices
|
||||
|
||||
|
||||
def torch_grouped_gemm(X, W, m_sizes, transpose=True):
|
||||
def torch_grouped_gemm(X, W, m_sizes, transpose = True):
|
||||
"""
|
||||
X: [M, K] if forward, else [M, N]
|
||||
W: [E, N, K]
|
||||
|
|
@ -127,7 +127,7 @@ def torch_grouped_gemm(X, W, m_sizes, transpose=True):
|
|||
|
||||
N = W.shape[1]
|
||||
|
||||
result = torch.zeros((M, N), dtype=X.dtype, device=X.device)
|
||||
result = torch.zeros((M, N), dtype = X.dtype, device = X.device)
|
||||
|
||||
m_start = 0
|
||||
for g in range(E):
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ NUM_AUTOTUNE_CONFIGS = 50
|
|||
|
||||
|
||||
@contextmanager
|
||||
def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
|
||||
def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
|
||||
print(char * num_chars)
|
||||
print(prelude)
|
||||
yield
|
||||
|
|
@ -81,7 +81,7 @@ def prep_triton_kernel_traits(autotune):
|
|||
|
||||
|
||||
def sparse_to_dense(t: torch.Tensor):
|
||||
t = t.sum(dim=0).view(-1)
|
||||
t = t.sum(dim = 0).view(-1)
|
||||
return t
|
||||
|
||||
|
||||
|
|
@ -91,9 +91,9 @@ def _check_diff(
|
|||
t2: torch.Tensor,
|
||||
atol,
|
||||
rtol,
|
||||
precision=".6f",
|
||||
verbose=False,
|
||||
msg="",
|
||||
precision = ".6f",
|
||||
verbose = False,
|
||||
msg = "",
|
||||
):
|
||||
t2 = t2.view_as(t1)
|
||||
diff = t1.sub(t2).abs().max().item()
|
||||
|
|
@ -101,7 +101,7 @@ def _check_diff(
|
|||
if msg == "":
|
||||
msg = "diff"
|
||||
print(f"{msg}: {diff:{precision}}")
|
||||
assert torch.allclose(t1, t2, atol=atol, rtol=rtol)
|
||||
assert torch.allclose(t1, t2, atol = atol, rtol = rtol)
|
||||
|
||||
|
||||
def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):
|
||||
|
|
@ -115,19 +115,19 @@ def _check_grads(
|
|||
m2: torch.nn.Module,
|
||||
atol,
|
||||
rtol,
|
||||
precision=".6f",
|
||||
verbose=False,
|
||||
msg="",
|
||||
precision = ".6f",
|
||||
verbose = False,
|
||||
msg = "",
|
||||
):
|
||||
for name, param in m1.named_parameters():
|
||||
_check_diff(
|
||||
param.grad,
|
||||
m2.get_parameter(name).grad,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
precision=precision,
|
||||
verbose=verbose,
|
||||
msg=f"{msg}:{name}.grad",
|
||||
atol = atol,
|
||||
rtol = rtol,
|
||||
precision = precision,
|
||||
verbose = verbose,
|
||||
msg = f"{msg}:{name}.grad",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -139,19 +139,19 @@ def model_config():
|
|||
@pytest.mark.parametrize(
|
||||
"overlap_router_shared",
|
||||
[False, True],
|
||||
ids=lambda x: "overlap_router_shared" if x else "no_overlap",
|
||||
ids = lambda x: "overlap_router_shared" if x else "no_overlap",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y"
|
||||
"permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x"
|
||||
"permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x"
|
||||
) # Llama4 does not support permute_x
|
||||
@pytest.mark.parametrize(
|
||||
"autotune", [True], ids=lambda x: "autotune" if x else "manual"
|
||||
"autotune", [True], ids = lambda x: "autotune" if x else "manual"
|
||||
)
|
||||
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
|
||||
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
|
||||
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
|
||||
@pytest.mark.parametrize("dtype", DTYPES, ids = str)
|
||||
def test_llama4_ref(
|
||||
dtype: torch.dtype,
|
||||
seqlen,
|
||||
|
|
@ -161,9 +161,9 @@ def test_llama4_ref(
|
|||
overlap_router_shared: bool,
|
||||
model_config: Llama4TextConfig, # test fixture
|
||||
bs: int = 1,
|
||||
device="cuda",
|
||||
precision=".6f",
|
||||
verbose=False,
|
||||
device = "cuda",
|
||||
precision = ".6f",
|
||||
verbose = False,
|
||||
):
|
||||
torch.manual_seed(
|
||||
SEED
|
||||
|
|
@ -172,24 +172,24 @@ def test_llama4_ref(
|
|||
hidden_dim = model_config.hidden_size
|
||||
atol, rtol = TOLERANCES[dtype]
|
||||
check_diff = partial(
|
||||
_check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose
|
||||
_check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose
|
||||
)
|
||||
check_grads = partial(
|
||||
_check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose
|
||||
_check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose
|
||||
)
|
||||
|
||||
# Reference op -- HF
|
||||
llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device)
|
||||
llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)
|
||||
|
||||
# Torch grouped gemm impl
|
||||
llama4_gg_ref = Llama4GroupedGemmTextMoe(
|
||||
model_config, overlap_router_shared=overlap_router_shared
|
||||
).to(dtype=dtype, device=device)
|
||||
model_config, overlap_router_shared = overlap_router_shared
|
||||
).to(dtype = dtype, device = device)
|
||||
llama4_gg_ref.copy_weights(llama4_ref)
|
||||
llama4_gg_ref.check_weights(llama4_ref)
|
||||
|
||||
x_ref = torch.randn(
|
||||
bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True
|
||||
bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True
|
||||
)
|
||||
x_torch_gg = x_ref.detach().clone().requires_grad_()
|
||||
x_triton = x_ref.detach().clone().requires_grad_()
|
||||
|
|
@ -198,9 +198,9 @@ def test_llama4_ref(
|
|||
y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)
|
||||
assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}"
|
||||
with annotated_context("Testing torch grouped gemm Llama4TextMoe"):
|
||||
check_diff(y_ref, y_torch_gg, msg="y_torch_gg")
|
||||
check_diff(y_ref, y_torch_gg, msg = "y_torch_gg")
|
||||
check_diff(
|
||||
sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg"
|
||||
sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg"
|
||||
)
|
||||
|
||||
kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (
|
||||
|
|
@ -209,38 +209,38 @@ def test_llama4_ref(
|
|||
|
||||
llama4_triton = Llama4TritonTextMoe(
|
||||
model_config,
|
||||
overlap_router_shared=overlap_router_shared,
|
||||
permute_x=permute_x,
|
||||
permute_y=permute_y,
|
||||
autotune=autotune,
|
||||
kernel_config_fwd=kernel_config_fwd,
|
||||
kernel_config_bwd_dW=kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX=kernel_config_bwd_dX,
|
||||
).to(device=device, dtype=dtype)
|
||||
overlap_router_shared = overlap_router_shared,
|
||||
permute_x = permute_x,
|
||||
permute_y = permute_y,
|
||||
autotune = autotune,
|
||||
kernel_config_fwd = kernel_config_fwd,
|
||||
kernel_config_bwd_dW = kernel_config_bwd_dW,
|
||||
kernel_config_bwd_dX = kernel_config_bwd_dX,
|
||||
).to(device = device, dtype = dtype)
|
||||
llama4_triton.copy_weights(llama4_ref)
|
||||
llama4_triton.check_weights(llama4_ref)
|
||||
|
||||
y_triton, routing_triton = llama4_triton(x_triton)
|
||||
with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"):
|
||||
check_diff(y_ref, y_triton, msg="y_triton")
|
||||
check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton")
|
||||
check_diff(y_ref, y_triton, msg = "y_triton")
|
||||
check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton")
|
||||
|
||||
ref_grad = torch.randn_like(y_ref)
|
||||
run_backwards(y_ref, ref_grad, llama4_ref)
|
||||
run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)
|
||||
with annotated_context("Testing torch group gemm Llama4TextMoe backward"):
|
||||
check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg")
|
||||
check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg")
|
||||
|
||||
run_backwards(y_triton, ref_grad, llama4_triton)
|
||||
with annotated_context("Testing triton group gemm Llama4TextMoe backward"):
|
||||
check_grads(llama4_ref, llama4_triton, msg="triton")
|
||||
check_grads(llama4_ref, llama4_triton, msg = "triton")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seqlen", type=int, default=1024)
|
||||
parser.add_argument("--seqlen", type = int, default = 1024)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
|
||||
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
|
|
@ -251,12 +251,12 @@ if __name__ == "__main__":
|
|||
text_config: Llama4TextConfig = get_text_config(model_id)
|
||||
for overlap in [False, True]:
|
||||
test_llama4_ref(
|
||||
seqlen=args.seqlen,
|
||||
model_config=text_config,
|
||||
dtype=args.dtype,
|
||||
autotune=True,
|
||||
permute_x=False,
|
||||
permute_y=True,
|
||||
overlap_router_shared=overlap,
|
||||
verbose=True,
|
||||
seqlen = args.seqlen,
|
||||
model_config = text_config,
|
||||
dtype = args.dtype,
|
||||
autotune = True,
|
||||
permute_x = False,
|
||||
permute_y = True,
|
||||
overlap_router_shared = overlap,
|
||||
verbose = True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -45,16 +45,16 @@ def _rms_layernorm_forward(
|
|||
X += row_idx * X_row_stride
|
||||
r += row_idx * r_row_stride
|
||||
|
||||
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
row_var = tl.sum(X_row * X_row, axis=0) / n_cols
|
||||
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
inv_var = tl.math.rsqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
normed = X_row * inv_var
|
||||
normed = normed.to(W_row.dtype) # Exact copy from HF
|
||||
output = normed * W_row
|
||||
tl.store(Y + col_offsets, output, mask=mask)
|
||||
tl.store(Y + col_offsets, output, mask = mask)
|
||||
|
||||
|
||||
def _rms_layernorm_backward(
|
||||
|
|
@ -92,9 +92,9 @@ def _rms_layernorm_backward(
|
|||
else:
|
||||
dX = dY
|
||||
|
||||
dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
# Get saved row variance
|
||||
inv_var = tl.load(r).to(tl.float32)
|
||||
|
|
@ -105,9 +105,9 @@ def _rms_layernorm_backward(
|
|||
else:
|
||||
dY_W = dY_row * W_row
|
||||
|
||||
rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
|
||||
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
||||
output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
|
||||
tl.store(dX + col_offsets, output, mask=mask)
|
||||
tl.store(dX + col_offsets, output, mask = mask)
|
||||
|
||||
|
||||
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
|
||||
|
|
@ -143,16 +143,16 @@ def _gemma_rms_layernorm_forward(
|
|||
X += row_idx * X_row_stride
|
||||
r += row_idx * r_row_stride
|
||||
|
||||
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
row_var = tl.sum(X_row * X_row, axis=0) / n_cols
|
||||
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
inv_var = tl.math.rsqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
normed = X_row * inv_var
|
||||
output = normed * (W_row + 1.0)
|
||||
|
||||
tl.store(Y + col_offsets, output, mask=mask)
|
||||
tl.store(Y + col_offsets, output, mask = mask)
|
||||
|
||||
|
||||
class Fast_RMS_Layernorm(torch.autograd.Function):
|
||||
|
|
@ -169,8 +169,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
device = X.device
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
|
||||
r = torch.empty(n_rows, dtype=torch.float32, device=device)
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = device)
|
||||
|
||||
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
|
||||
with torch_gpu_device(device):
|
||||
|
|
@ -185,8 +185,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
r.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
|
|
@ -222,9 +222,9 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
# dW, dW.stride(0),
|
||||
n_cols,
|
||||
ctx.eps,
|
||||
GEMMA=ctx.GEMMA,
|
||||
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
GEMMA = ctx.GEMMA,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
dX = dX.view(*shape)
|
||||
return dX, None, None, None
|
||||
|
|
@ -248,7 +248,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
|||
|
||||
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
|
||||
def forward(self, X):
|
||||
return fast_rms_layernorm(self, X, gemma=False)
|
||||
return fast_rms_layernorm(self, X, gemma = False)
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -256,7 +256,7 @@ try:
|
|||
|
||||
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
|
||||
def forward(self, X):
|
||||
return fast_rms_layernorm(self, X, gemma=False)
|
||||
return fast_rms_layernorm(self, X, gemma = False)
|
||||
|
||||
|
||||
except:
|
||||
|
|
@ -292,25 +292,25 @@ def unpatch_rms_layernorm():
|
|||
|
||||
|
||||
def test_rms_layernorm(
|
||||
dim=1024,
|
||||
eps=1e-5,
|
||||
dtype=torch.float16,
|
||||
bsz=21,
|
||||
random_state=3407,
|
||||
seqlen=3341,
|
||||
dim = 1024,
|
||||
eps = 1e-5,
|
||||
dtype = torch.float16,
|
||||
bsz = 21,
|
||||
random_state = 3407,
|
||||
seqlen = 3341,
|
||||
):
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
|
||||
layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda")
|
||||
layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
|
||||
torch.cuda.manual_seed(random_state)
|
||||
torch.manual_seed(random_state)
|
||||
torch.nn.init.uniform_(layernorm.weight)
|
||||
X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda")
|
||||
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
||||
XX = X.clone()
|
||||
X.requires_grad_(True)
|
||||
XX.requires_grad_(True)
|
||||
Y = layernorm(X)
|
||||
YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True)
|
||||
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
||||
Y.backward(YY)
|
||||
correct_grad = X.grad.clone()
|
||||
# from unsloth.kernels import fast_rms_layernorm
|
||||
|
|
@ -322,14 +322,14 @@ def test_rms_layernorm(
|
|||
def testing_suite_layernorm():
|
||||
for dim in [512, 1024, 2048]:
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
with torch.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.autocast(device_type = "cuda", dtype = dtype):
|
||||
for seqlen in [3341, 2048, 349]:
|
||||
for random_state in [3407, 42]:
|
||||
test_rms_layernorm(
|
||||
dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
bsz=21,
|
||||
random_state=random_state,
|
||||
seqlen=seqlen,
|
||||
dim = dim,
|
||||
eps = 1e-5,
|
||||
dtype = dtype,
|
||||
bsz = 21,
|
||||
random_state = random_state,
|
||||
seqlen = seqlen,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -50,16 +50,16 @@ def _rope_embedding(
|
|||
+ (row_position % seqlen) * sin_row_stride
|
||||
+ half_head_dim * 0
|
||||
+ col_offsets,
|
||||
mask=mask,
|
||||
other=0,
|
||||
mask = mask,
|
||||
other = 0,
|
||||
)
|
||||
cos1 = tl.load(
|
||||
cos
|
||||
+ (row_position % seqlen) * cos_row_stride
|
||||
+ half_head_dim * 0
|
||||
+ col_offsets,
|
||||
mask=mask,
|
||||
other=0,
|
||||
mask = mask,
|
||||
other = 0,
|
||||
)
|
||||
|
||||
if BACKWARD_PASS:
|
||||
|
|
@ -78,11 +78,11 @@ def _rope_embedding(
|
|||
)
|
||||
|
||||
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
|
||||
Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype)
|
||||
Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype)
|
||||
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
|
||||
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
|
||||
|
||||
tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)
|
||||
tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)
|
||||
tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)
|
||||
tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)
|
||||
|
||||
|
||||
_rope_embedding = triton.jit(_rope_embedding)
|
||||
|
|
@ -134,9 +134,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
|
|||
seq_len,
|
||||
head_dim,
|
||||
n_heads,
|
||||
BACKWARD_PASS=False,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
BACKWARD_PASS = False,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
|
|
@ -177,9 +177,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
|
|||
seq_len,
|
||||
head_dim,
|
||||
n_heads,
|
||||
BACKWARD_PASS=True,
|
||||
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
BACKWARD_PASS = True,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
dY = dY.view(batch, seq_len, n_heads, head_dim)
|
||||
return (
|
||||
|
|
@ -211,7 +211,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
|
||||
# Q * cos + rotate_half(Q) * sin
|
||||
half = Q.shape[-1] // 2
|
||||
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim=-1)
|
||||
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
|
||||
Q *= cos
|
||||
Q.addcmul_(RH_Q, sin)
|
||||
# RH_Q *= sin
|
||||
|
|
@ -224,7 +224,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
cos, sin = ctx.saved_tensors
|
||||
# Q * cos + rotate_half.T(Q) * sin
|
||||
half = dY.shape[-1] // 2
|
||||
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim=-1)
|
||||
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
||||
dY *= cos
|
||||
dY.addcmul_(RH_dY, sin)
|
||||
# RH_dY *= sin
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ def _fg_kernel(
|
|||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
|
||||
|
|
@ -53,13 +53,13 @@ def _fg_kernel(
|
|||
h_row = f_row * g_row
|
||||
|
||||
# Store h
|
||||
tl.store(h + offsets, h_row, mask=mask)
|
||||
tl.store(h + offsets, h_row, mask = mask)
|
||||
|
||||
|
||||
def swiglu_fg_kernel(e, g):
|
||||
batch, seq_len, hd = e.shape
|
||||
n_elements = e.numel()
|
||||
h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=e.device)
|
||||
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
with torch_gpu_device(e.device):
|
||||
_fg_kernel[grid](
|
||||
|
|
@ -67,8 +67,8 @@ def swiglu_fg_kernel(e, g):
|
|||
g,
|
||||
h,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return h
|
||||
|
||||
|
|
@ -101,9 +101,9 @@ def _DWf_DW_dfg_kernel(
|
|||
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
|
||||
|
||||
# e = e.float()
|
||||
# se = 1.0 / (1.0 + torch.exp(-e))
|
||||
|
|
@ -122,9 +122,9 @@ def _DWf_DW_dfg_kernel(
|
|||
de_row = de_row.to(DW_row.dtype)
|
||||
|
||||
# Store derivatives in buffers
|
||||
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask=mask) # de
|
||||
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask = mask) # de
|
||||
|
||||
|
||||
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
|
||||
|
|
@ -137,7 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
|
|||
e,
|
||||
g,
|
||||
n_elements,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
|
||||
)
|
||||
return DW, e, g
|
||||
|
|
|
|||
|
|
@ -47,12 +47,12 @@ if Version(torch.__version__) < Version("2.4.0"):
|
|||
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
|
||||
|
||||
if DEVICE_TYPE == "xpu":
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
|
||||
|
||||
|
||||
# tl.math.tanh now is libdevice.tanh
|
||||
|
|
@ -349,7 +349,7 @@ def _maybe_fake_quantize_activations(
|
|||
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
||||
|
||||
@torch.inference_mode
|
||||
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
|
||||
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
|
||||
# TODO: After adding XPU BNB support, check this function
|
||||
if isinstance(W, Float8Tensor):
|
||||
return W.dequantize()
|
||||
|
|
@ -390,13 +390,13 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
|
||||
if WEIGHT_BUFFER is None:
|
||||
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
|
||||
size, dtype=dtype, device=device, requires_grad=False
|
||||
size, dtype = dtype, device = device, requires_grad = False
|
||||
)
|
||||
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
|
||||
n_elements_absmax,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
requires_grad=False,
|
||||
dtype = torch.float32,
|
||||
device = device,
|
||||
requires_grad = False,
|
||||
)
|
||||
|
||||
if size > WEIGHT_BUFFER.numel():
|
||||
|
|
@ -409,16 +409,16 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
else:
|
||||
if out is None:
|
||||
out = torch_empty(
|
||||
shape, dtype=dtype, device=device, requires_grad=False
|
||||
shape, dtype = dtype, device = device, requires_grad = False
|
||||
)
|
||||
else:
|
||||
assert out.shape == shape
|
||||
assert out.dtype == dtype
|
||||
out_absmax = torch_empty(
|
||||
n_elements_absmax,
|
||||
dtype=torch_float32,
|
||||
device=device,
|
||||
requires_grad=False,
|
||||
dtype = torch_float32,
|
||||
device = device,
|
||||
requires_grad = False,
|
||||
)
|
||||
|
||||
# NF4 dequantization of statistics
|
||||
|
|
@ -458,7 +458,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
||||
|
||||
@torch.inference_mode
|
||||
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
|
||||
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
|
||||
if isinstance(W, Float8Tensor):
|
||||
return W.dequantize()
|
||||
if quant_state is None:
|
||||
|
|
@ -500,13 +500,13 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
|
||||
if WEIGHT_BUFFER is None:
|
||||
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
|
||||
size, dtype=dtype, device=device, requires_grad=False
|
||||
size, dtype = dtype, device = device, requires_grad = False
|
||||
)
|
||||
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
|
||||
n_elements_absmax,
|
||||
dtype=torch_float32,
|
||||
device=device,
|
||||
requires_grad=False,
|
||||
dtype = torch_float32,
|
||||
device = device,
|
||||
requires_grad = False,
|
||||
)
|
||||
|
||||
if size > WEIGHT_BUFFER.numel():
|
||||
|
|
@ -519,16 +519,16 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
else:
|
||||
if out is None:
|
||||
out = torch_empty(
|
||||
shape, dtype=dtype, device=device, requires_grad=False
|
||||
shape, dtype = dtype, device = device, requires_grad = False
|
||||
)
|
||||
else:
|
||||
assert out.shape == shape
|
||||
assert out.dtype == dtype
|
||||
out_absmax = torch_empty(
|
||||
n_elements_absmax,
|
||||
dtype=torch_float32,
|
||||
device=device,
|
||||
requires_grad=False,
|
||||
dtype = torch_float32,
|
||||
device = device,
|
||||
requires_grad = False,
|
||||
)
|
||||
pass
|
||||
|
||||
|
|
@ -570,7 +570,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
else:
|
||||
|
||||
@torch.inference_mode
|
||||
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
|
||||
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
|
||||
if isinstance(W, Float8Tensor):
|
||||
return W.dequantize()
|
||||
if quant_state is None:
|
||||
|
|
@ -601,12 +601,12 @@ else:
|
|||
|
||||
# Create weight matrix
|
||||
if out is None:
|
||||
out = torch_empty(shape, dtype=dtype, device=device, requires_grad=False)
|
||||
out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
|
||||
else:
|
||||
assert out.shape == shape
|
||||
assert out.dtype == dtype
|
||||
out_absmax = torch_empty(
|
||||
n_elements_absmax, dtype=torch_float32, device=device, requires_grad=False
|
||||
n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
|
||||
)
|
||||
|
||||
# Do dequantization
|
||||
|
|
@ -645,9 +645,9 @@ else:
|
|||
# INTEL GPU Specific Logic
|
||||
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
||||
|
||||
def fast_gemv(X, W, quant_state, out=None):
|
||||
def fast_gemv(X, W, quant_state, out = None):
|
||||
if quant_state is None:
|
||||
return torch_matmul(X, W, out=out)
|
||||
return torch_matmul(X, W, out = out)
|
||||
# For fast X @ W where seq_len == 1
|
||||
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
||||
_, q_len, hd = X.shape
|
||||
|
|
@ -686,8 +686,8 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
1,
|
||||
bout,
|
||||
),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
# else:
|
||||
# assert(out.shape == (1, 1, bout,))
|
||||
|
|
@ -710,7 +710,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
ldb = ctypes_c_int32(ldb)
|
||||
ldc = ctypes_c_int32(ldc)
|
||||
|
||||
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
|
||||
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
|
||||
with torch_gpu_device(device):
|
||||
cdequantize_blockwise_fp32(
|
||||
get_ptr(code2),
|
||||
|
|
@ -751,9 +751,9 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
|||
|
||||
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
||||
|
||||
def fast_gemv(X, W, quant_state, out=None):
|
||||
def fast_gemv(X, W, quant_state, out = None):
|
||||
if quant_state is None:
|
||||
return torch_matmul(X, W, out=out)
|
||||
return torch_matmul(X, W, out = out)
|
||||
# For fast X @ W where seq_len == 1
|
||||
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
||||
_, q_len, hd = X.shape
|
||||
|
|
@ -793,8 +793,8 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
1,
|
||||
bout,
|
||||
),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
# else:
|
||||
# assert(out.shape == (1, 1, bout,))
|
||||
|
|
@ -813,7 +813,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
ldb = ctypes_c_int32(ldb)
|
||||
ldc = ctypes_c_int32(ldc)
|
||||
|
||||
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
|
||||
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
|
||||
with torch_gpu_device(device):
|
||||
cdequantize_blockwise_fp32(
|
||||
get_ptr(code2),
|
||||
|
|
@ -856,9 +856,9 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
|
|||
pass
|
||||
else:
|
||||
|
||||
def fast_gemv(X, W, quant_state, out=None):
|
||||
def fast_gemv(X, W, quant_state, out = None):
|
||||
if quant_state is None:
|
||||
return torch_matmul(X, W, out=out)
|
||||
return torch_matmul(X, W, out = out)
|
||||
# For fast X @ W where seq_len == 1
|
||||
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
||||
_, q_len, hd = X.shape
|
||||
|
|
@ -894,8 +894,8 @@ else:
|
|||
1,
|
||||
bout,
|
||||
),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
# else:
|
||||
# assert(out.shape == (1, 1, bout,))
|
||||
|
|
@ -914,7 +914,7 @@ else:
|
|||
ldb = ctypes_c_int32(ldb)
|
||||
ldc = ctypes_c_int32(ldc)
|
||||
|
||||
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
|
||||
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
|
||||
cdequantize_blockwise_fp32(
|
||||
get_ptr(code2),
|
||||
get_ptr(absmax),
|
||||
|
|
@ -953,21 +953,21 @@ else:
|
|||
pass
|
||||
|
||||
|
||||
def fast_linear_forward(proj, X, temp_lora=None, out=None):
|
||||
def fast_linear_forward(proj, X, temp_lora = None, out = None):
|
||||
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
|
||||
bsz, q_len, in_dim = X.shape
|
||||
if q_len != 1:
|
||||
return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
|
||||
|
||||
if W_quant is None:
|
||||
out = torch_matmul(X, W.t(), out=out)
|
||||
out = torch_matmul(X, W.t(), out = out)
|
||||
elif W.dtype == torch.float8_e4m3fn:
|
||||
out = fp8_linear(X, W, W_quant, bias)
|
||||
elif bsz == 1 and q_len == 1:
|
||||
out = fast_gemv(X, W, W_quant, out=out)
|
||||
out = fast_gemv(X, W, W_quant, out = out)
|
||||
else:
|
||||
W = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
|
||||
out = torch_matmul(X, W, out=out)
|
||||
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
|
||||
out = torch_matmul(X, W, out = out)
|
||||
|
||||
# Add in LoRA weights
|
||||
if lora_A is not None:
|
||||
|
|
@ -980,14 +980,14 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
|
|||
|
||||
if bsz == 1:
|
||||
out = out.view(out_dim)
|
||||
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out=temp_lora)
|
||||
out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)
|
||||
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
|
||||
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
|
||||
else:
|
||||
out = out.view(bsz, out_dim)
|
||||
temp_lora = torch_mm(
|
||||
X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora
|
||||
X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
|
||||
)
|
||||
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)
|
||||
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
|
||||
out = out.view(bsz, 1, out_dim)
|
||||
|
||||
if bias is not None:
|
||||
|
|
@ -996,7 +996,7 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
|
|||
return out
|
||||
|
||||
|
||||
def matmul_lora(X, W, W_quant, A, B, s, out=None):
|
||||
def matmul_lora(X, W, W_quant, A, B, s, out = None):
|
||||
dtype = X.dtype
|
||||
|
||||
if X.dim() == 3:
|
||||
|
|
@ -1015,12 +1015,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
|
|||
W = W.dequantize()
|
||||
else:
|
||||
W = W.contiguous()
|
||||
out = torch_matmul(X, W.t(), out=out)
|
||||
out = torch_matmul(X, W.t(), out = out)
|
||||
elif W.dtype == torch.float8_e4m3fn:
|
||||
out = fp8_linear(X, W, W_quant)
|
||||
else:
|
||||
W = fast_dequantize(W, W_quant, use_global_buffer=True)
|
||||
out = torch_matmul(X, W.t(), out=out)
|
||||
W = fast_dequantize(W, W_quant, use_global_buffer = True)
|
||||
out = torch_matmul(X, W.t(), out = out)
|
||||
if W_quant is not None:
|
||||
del W
|
||||
|
||||
|
|
@ -1028,7 +1028,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
|
|||
# LoRA is enabled
|
||||
A, B = A.t(), B.t()
|
||||
XA = torch_matmul(X, A.to(dtype))
|
||||
out.addmm_(XA, B.to(dtype), alpha=s)
|
||||
out.addmm_(XA, B.to(dtype), alpha = s)
|
||||
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
||||
|
||||
return out.view(batch, seq_len, -1) if reshape else out
|
||||
|
|
|
|||
|
|
@ -150,23 +150,23 @@ for temporary_patch in TEMPORARY_PATCHES:
|
|||
|
||||
# =============================================
|
||||
# Disable some warnings which can get annoying
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="torch")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="huggingface_hub")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=FutureWarning, module="huggingface_hub"
|
||||
action = "ignore", category = FutureWarning, module = "huggingface_hub"
|
||||
)
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="trl")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="trl")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="xformers")
|
||||
warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="subprocess")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="accelerate")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
|
||||
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=RuntimeWarning, module="multiprocessing"
|
||||
action = "ignore", category = RuntimeWarning, module = "multiprocessing"
|
||||
)
|
||||
warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="multiprocess")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="triton")
|
||||
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
|
||||
# Stop "Special tokens have been added in the vocabulary, ..."
|
||||
import logging
|
||||
|
||||
|
|
@ -345,10 +345,10 @@ except:
|
|||
# You passed `quantization_config` or equivalent parameters
|
||||
try:
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
message=r".*quantization_config.*",
|
||||
category=UserWarning,
|
||||
append=True,
|
||||
action = "ignore",
|
||||
message = r".*quantization_config.*",
|
||||
category = UserWarning,
|
||||
append = True,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
|
@ -357,10 +357,10 @@ except:
|
|||
# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
|
||||
try:
|
||||
warnings.filterwarnings(
|
||||
action="ignore",
|
||||
message=r".*Logical operators 'and' and 'or'.*",
|
||||
category=UserWarning,
|
||||
append=True,
|
||||
action = "ignore",
|
||||
message = r".*Logical operators 'and' and 'or'.*",
|
||||
category = UserWarning,
|
||||
append = True,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
|
@ -465,7 +465,7 @@ def extract_quant_model_param_count(model):
|
|||
return count
|
||||
|
||||
|
||||
def get_model_param_count(model, trainable_only=False):
|
||||
def get_model_param_count(model, trainable_only = False):
|
||||
"""
|
||||
Calculate model's total param count. If trainable_only is True then count only those requiring grads
|
||||
"""
|
||||
|
|
@ -596,14 +596,14 @@ if DEVICE_TYPE in ("cuda", "hip"):
|
|||
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
|
||||
elif DEVICE_TYPE == "xpu":
|
||||
if Version(torch_version) < Version("2.6.0"):
|
||||
raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
|
||||
else:
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
|
||||
# =============================================
|
||||
|
||||
# =============================================
|
||||
|
|
@ -934,9 +934,9 @@ import torch._inductor.utils
|
|||
|
||||
torch._inductor.utils.is_big_gpu = is_big_gpu
|
||||
patch_torch_compile(
|
||||
debug=UNSLOTH_COMPILE_DEBUG,
|
||||
O3=UNSLOTH_COMPILE_MAXIMUM,
|
||||
ignore_errors=UNSLOTH_COMPILE_IGNORE_ERRORS,
|
||||
debug = UNSLOTH_COMPILE_DEBUG,
|
||||
O3 = UNSLOTH_COMPILE_MAXIMUM,
|
||||
ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,
|
||||
)
|
||||
|
||||
torch_compile_options = {
|
||||
|
|
@ -983,9 +983,9 @@ def patch_regional_compilation():
|
|||
[
|
||||
torch.compile(
|
||||
x,
|
||||
dynamic=True,
|
||||
options=torch_compile_options,
|
||||
fullgraph=False,
|
||||
dynamic = True,
|
||||
options = torch_compile_options,
|
||||
fullgraph = False,
|
||||
)
|
||||
for x in args[0]
|
||||
]
|
||||
|
|
@ -1008,14 +1008,14 @@ def prepare_model_for_kbit_training(
|
|||
use_reentrant: Optional[bool] = True,
|
||||
) -> Any:
|
||||
return prepare_model_for_training(
|
||||
model=model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_reentrant=use_reentrant,
|
||||
full_finetuning=False,
|
||||
train_layernorms=False,
|
||||
train_embedding=False,
|
||||
train_lm_head=False,
|
||||
float32_mixed_precision=True,
|
||||
model = model,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
use_reentrant = use_reentrant,
|
||||
full_finetuning = False,
|
||||
train_layernorms = False,
|
||||
train_embedding = False,
|
||||
train_lm_head = False,
|
||||
float32_mixed_precision = True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1033,7 +1033,7 @@ if Version(peft_version) < Version("0.12.0"):
|
|||
text = "if weight is not None:\n"
|
||||
start = source.find(text) + len(text)
|
||||
end = source.find("self.to(weight.device)", start)
|
||||
spaces = re.findall(r"^([ ]{1,})break", source, flags=re.MULTILINE)[0]
|
||||
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
|
||||
source = source.replace(source[start:end], spaces)
|
||||
spaces = len(re.match(r"[\s]{1,}", source).group(0))
|
||||
lines = source.split("\n")
|
||||
|
|
@ -1070,7 +1070,7 @@ import socket
|
|||
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def has_internet(host="8.8.8.8", port=53, timeout=3):
|
||||
def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
|
||||
if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
|
||||
return False
|
||||
try:
|
||||
|
|
@ -1084,12 +1084,12 @@ def has_internet(host="8.8.8.8", port=53, timeout=3):
|
|||
import psutil
|
||||
|
||||
|
||||
def _get_statistics(statistics=None, force_download=True):
|
||||
def _get_statistics(statistics = None, force_download = True):
|
||||
# We log some basic stats about which environment is being used.
|
||||
# We simply download a README.md file from HF - all data is made public.
|
||||
# This is simply so we can check if some envs are broken or not.
|
||||
# You can disable this by commenting the below out
|
||||
n_cpus = psutil.cpu_count(logical=False)
|
||||
n_cpus = psutil.cpu_count(logical = False)
|
||||
keynames = "\n" + "\n".join(os.environ.keys())
|
||||
# Check modelscope for down detection
|
||||
global USE_MODELSCOPE
|
||||
|
|
@ -1150,12 +1150,12 @@ def _get_statistics(statistics=None, force_download=True):
|
|||
if has_internet():
|
||||
|
||||
def stats_check():
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as f:
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
|
||||
snapshot_download(
|
||||
f"unslothai/{statistics}",
|
||||
force_download=True,
|
||||
cache_dir=f,
|
||||
local_dir=f,
|
||||
force_download = True,
|
||||
cache_dir = f,
|
||||
local_dir = f,
|
||||
)
|
||||
|
||||
time_limited_stats_check = execute_with_time_limit(120)(stats_check)
|
||||
|
|
@ -1178,7 +1178,7 @@ def _get_statistics(statistics=None, force_download=True):
|
|||
stats_check()
|
||||
|
||||
|
||||
def get_statistics(local_files_only=False):
|
||||
def get_statistics(local_files_only = False):
|
||||
# We log some basic stats about which environment is being used.
|
||||
# This is also to check if HuggingFace is down or not!
|
||||
# We simply download a README.md file from HF - all data is made public.
|
||||
|
|
@ -1201,7 +1201,7 @@ def get_statistics(local_files_only=False):
|
|||
disable_progress_bars()
|
||||
disabled = True
|
||||
_get_statistics(None)
|
||||
_get_statistics("repeat", force_download=False)
|
||||
_get_statistics("repeat", force_download = False)
|
||||
total_memory = (
|
||||
torch.xpu.get_device_properties(0).total_memory
|
||||
if DEVICE_TYPE == "xpu"
|
||||
|
|
@ -1242,7 +1242,7 @@ BitsAndBytesConfig__init__ = re.sub(
|
|||
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
|
||||
"",
|
||||
BitsAndBytesConfig__init__,
|
||||
flags=re.MULTILINE,
|
||||
flags = re.MULTILINE,
|
||||
)
|
||||
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
|
||||
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
|
||||
|
|
@ -1322,12 +1322,12 @@ def offload_to_disk(
|
|||
torch.save(
|
||||
W,
|
||||
filename,
|
||||
pickle_module=pickle,
|
||||
pickle_protocol=pickle.HIGHEST_PROTOCOL,
|
||||
pickle_module = pickle,
|
||||
pickle_protocol = pickle.HIGHEST_PROTOCOL,
|
||||
)
|
||||
# We must use weights_only = False due to pickling
|
||||
offloaded_W = torch.load(
|
||||
filename, map_location="cpu", mmap=True, weights_only=False
|
||||
filename, map_location = "cpu", mmap = True, weights_only = False
|
||||
)
|
||||
offloaded_W._offloaded_file_location = filename
|
||||
return offloaded_W
|
||||
|
|
@ -1352,7 +1352,7 @@ def offload_output_embeddings(
|
|||
model.get_output_embeddings(), model, "output_embeddings", temporary_location
|
||||
)
|
||||
|
||||
new_output_embeddings = torch.nn.Linear(1, 1, bias=None)
|
||||
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
|
||||
del new_output_embeddings.weight
|
||||
new_output_embeddings.weight = offloaded_W
|
||||
new_output_embeddings.in_features = offloaded_W.shape[1]
|
||||
|
|
@ -1376,10 +1376,10 @@ def is_vLLM_available():
|
|||
|
||||
# Patches models to add RoPE Scaling
|
||||
def patch_linear_scaling(
|
||||
model_name="gemma2",
|
||||
rope_module=None,
|
||||
scaled_rope_module=None,
|
||||
attention_module=None,
|
||||
model_name = "gemma2",
|
||||
rope_module = None,
|
||||
scaled_rope_module = None,
|
||||
attention_module = None,
|
||||
):
|
||||
assert rope_module is not None and scaled_rope_module is not None
|
||||
assert attention_module is not None
|
||||
|
|
@ -1430,13 +1430,13 @@ def patch_linear_scaling(
|
|||
pass
|
||||
"""
|
||||
fix_rope_function = fix_rope_function.format(
|
||||
rope_function=rope_module.__name__,
|
||||
scaled_rope_function=scaled_rope_module.__name__,
|
||||
rope_function = rope_module.__name__,
|
||||
scaled_rope_function = scaled_rope_module.__name__,
|
||||
)
|
||||
rotary_emb = re.findall(
|
||||
r"self\.rotary\_emb \= .+?\)",
|
||||
function,
|
||||
flags=re.DOTALL | re.MULTILINE,
|
||||
flags = re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if len(rotary_emb) == 0:
|
||||
return None, exec_code + "\n\n" + function
|
||||
|
|
@ -1449,12 +1449,12 @@ def patch_linear_scaling(
|
|||
|
||||
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
|
||||
def patch_llama_rope_scaling(
|
||||
model_name="llama",
|
||||
rope_module=None,
|
||||
scaled_rope_module=None,
|
||||
extended_rope_module=None,
|
||||
attention_module=None,
|
||||
longrope_module=None,
|
||||
model_name = "llama",
|
||||
rope_module = None,
|
||||
scaled_rope_module = None,
|
||||
extended_rope_module = None,
|
||||
attention_module = None,
|
||||
longrope_module = None,
|
||||
):
|
||||
assert (
|
||||
rope_module is not None
|
||||
|
|
@ -1528,17 +1528,17 @@ def patch_llama_rope_scaling(
|
|||
"""
|
||||
|
||||
fix_rope_function = fix_rope_function.format(
|
||||
rope_function=rope_module.__name__,
|
||||
scaled_rope_function=scaled_rope_module.__name__,
|
||||
extended_rope_function=extended_rope_module.__name__,
|
||||
longrope_rope_function=(
|
||||
rope_function = rope_module.__name__,
|
||||
scaled_rope_function = scaled_rope_module.__name__,
|
||||
extended_rope_function = extended_rope_module.__name__,
|
||||
longrope_rope_function = (
|
||||
longrope_module if longrope_module is not None else rope_module
|
||||
).__name__,
|
||||
)
|
||||
rotary_emb = re.findall(
|
||||
r"self\.rotary\_emb \= .+?\)",
|
||||
function,
|
||||
flags=re.DOTALL | re.MULTILINE,
|
||||
flags = re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
if len(rotary_emb) == 0:
|
||||
return None, function
|
||||
|
|
@ -1548,15 +1548,15 @@ def patch_llama_rope_scaling(
|
|||
return init_name, function
|
||||
|
||||
|
||||
def create_boolean_mask(n=4096, sliding_window=2048):
|
||||
def create_boolean_mask(n = 4096, sliding_window = 2048):
|
||||
# Creates a boolean mask for attention
|
||||
mask = torch.ones(n, n, dtype=torch.bool)
|
||||
mask = torch.ones(n, n, dtype = torch.bool)
|
||||
if sliding_window == 0:
|
||||
return torch.triu(mask, diagonal=1, out=mask)
|
||||
torch.triu(mask, diagonal=0, out=mask)
|
||||
torch.triu(mask.T, diagonal=-sliding_window, out=mask.T)
|
||||
return torch.triu(mask, diagonal = 1, out = mask)
|
||||
torch.triu(mask, diagonal = 0, out = mask)
|
||||
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
|
||||
mask = mask.T
|
||||
torch.logical_not(mask, out=mask)
|
||||
torch.logical_not(mask, out = mask)
|
||||
return mask
|
||||
|
||||
|
||||
|
|
@ -1567,37 +1567,37 @@ def test_mask_creation():
|
|||
for s in range(1, 23):
|
||||
correct_mask = (
|
||||
AttentionMaskConverter(
|
||||
is_causal=True,
|
||||
sliding_window=s,
|
||||
is_causal = True,
|
||||
sliding_window = s,
|
||||
)
|
||||
.to_causal_4d(
|
||||
1,
|
||||
n,
|
||||
n,
|
||||
dtype=torch.float16,
|
||||
dtype = torch.float16,
|
||||
)
|
||||
.squeeze(0)
|
||||
.squeeze(0)
|
||||
)
|
||||
correct_mask = correct_mask == correct_mask.min()
|
||||
our_mask = create_boolean_mask(n=n, sliding_window=s)
|
||||
our_mask = create_boolean_mask(n = n, sliding_window = s)
|
||||
assert torch.all(correct_mask == our_mask)
|
||||
correct_mask = (
|
||||
AttentionMaskConverter(
|
||||
is_causal=True,
|
||||
sliding_window=None,
|
||||
is_causal = True,
|
||||
sliding_window = None,
|
||||
)
|
||||
.to_causal_4d(
|
||||
1,
|
||||
n,
|
||||
n,
|
||||
dtype=torch.float16,
|
||||
dtype = torch.float16,
|
||||
)
|
||||
.squeeze(0)
|
||||
.squeeze(0)
|
||||
)
|
||||
correct_mask = correct_mask == correct_mask.min()
|
||||
our_mask = create_boolean_mask(n=n, sliding_window=0)
|
||||
our_mask = create_boolean_mask(n = n, sliding_window = 0)
|
||||
assert torch.all(correct_mask == our_mask)
|
||||
|
||||
|
||||
|
|
@ -1812,34 +1812,34 @@ def unsloth_compile_transformers(
|
|||
dtype,
|
||||
model_name,
|
||||
model_types,
|
||||
token=None,
|
||||
revision=None,
|
||||
trust_remote_code=False,
|
||||
sdpa_dynamic_mask=True,
|
||||
sdpa_bool_masks=True,
|
||||
sdpa_gqa_replace=True,
|
||||
sdpa_dynamic_compile=True,
|
||||
compile_attention=True,
|
||||
disable_causal_masks=True,
|
||||
compile_torch_modules=True,
|
||||
compile_custom_modules=True,
|
||||
compile_function_calls=True,
|
||||
fuse_lm_head=True,
|
||||
gradient_checkpointing=True,
|
||||
manual_replacements=True,
|
||||
fast_lora_forwards=True,
|
||||
fast_residual_stream=True,
|
||||
accurate_accumulation=True,
|
||||
epilogue_fusion=True,
|
||||
max_autotune=False,
|
||||
shape_padding=True,
|
||||
cudagraphs=False,
|
||||
debug=False,
|
||||
fullgraph=True,
|
||||
import_from_cache=False,
|
||||
disable=False,
|
||||
return_logits=False,
|
||||
unsloth_force_compile=False,
|
||||
token = None,
|
||||
revision = None,
|
||||
trust_remote_code = False,
|
||||
sdpa_dynamic_mask = True,
|
||||
sdpa_bool_masks = True,
|
||||
sdpa_gqa_replace = True,
|
||||
sdpa_dynamic_compile = True,
|
||||
compile_attention = True,
|
||||
disable_causal_masks = True,
|
||||
compile_torch_modules = True,
|
||||
compile_custom_modules = True,
|
||||
compile_function_calls = True,
|
||||
fuse_lm_head = True,
|
||||
gradient_checkpointing = True,
|
||||
manual_replacements = True,
|
||||
fast_lora_forwards = True,
|
||||
fast_residual_stream = True,
|
||||
accurate_accumulation = True,
|
||||
epilogue_fusion = True,
|
||||
max_autotune = False,
|
||||
shape_padding = True,
|
||||
cudagraphs = False,
|
||||
debug = False,
|
||||
fullgraph = True,
|
||||
import_from_cache = False,
|
||||
disable = False,
|
||||
return_logits = False,
|
||||
unsloth_force_compile = False,
|
||||
):
|
||||
if Version(torch_version) < Version("2.4.0"):
|
||||
print(
|
||||
|
|
@ -1864,31 +1864,31 @@ def unsloth_compile_transformers(
|
|||
for model_type in model_types:
|
||||
_unsloth_compile_transformers(
|
||||
model_type,
|
||||
sdpa_dynamic_mask=sdpa_dynamic_mask,
|
||||
sdpa_bool_masks=sdpa_bool_masks,
|
||||
sdpa_gqa_replace=sdpa_gqa_replace,
|
||||
sdpa_dynamic_compile=sdpa_dynamic_compile,
|
||||
compile_attention=compile_attention,
|
||||
disable_causal_masks=disable_causal_masks,
|
||||
compile_torch_modules=compile_torch_modules,
|
||||
compile_custom_modules=compile_custom_modules,
|
||||
compile_function_calls=compile_function_calls,
|
||||
fuse_lm_head=fuse_lm_head,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
manual_replacements=manual_replacements,
|
||||
fast_lora_forwards=fast_lora_forwards,
|
||||
fast_residual_stream=fast_residual_stream,
|
||||
accurate_accumulation=accurate_accumulation,
|
||||
epilogue_fusion=epilogue_fusion,
|
||||
max_autotune=max_autotune,
|
||||
shape_padding=shape_padding,
|
||||
cudagraphs=cudagraphs,
|
||||
debug=debug,
|
||||
fullgraph=fullgraph,
|
||||
import_from_cache=import_from_cache,
|
||||
disable=disable,
|
||||
return_logits=return_logits,
|
||||
supports_sdpa=supports_sdpa,
|
||||
sdpa_dynamic_mask = sdpa_dynamic_mask,
|
||||
sdpa_bool_masks = sdpa_bool_masks,
|
||||
sdpa_gqa_replace = sdpa_gqa_replace,
|
||||
sdpa_dynamic_compile = sdpa_dynamic_compile,
|
||||
compile_attention = compile_attention,
|
||||
disable_causal_masks = disable_causal_masks,
|
||||
compile_torch_modules = compile_torch_modules,
|
||||
compile_custom_modules = compile_custom_modules,
|
||||
compile_function_calls = compile_function_calls,
|
||||
fuse_lm_head = fuse_lm_head,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
manual_replacements = manual_replacements,
|
||||
fast_lora_forwards = fast_lora_forwards,
|
||||
fast_residual_stream = fast_residual_stream,
|
||||
accurate_accumulation = accurate_accumulation,
|
||||
epilogue_fusion = epilogue_fusion,
|
||||
max_autotune = max_autotune,
|
||||
shape_padding = shape_padding,
|
||||
cudagraphs = cudagraphs,
|
||||
debug = debug,
|
||||
fullgraph = fullgraph,
|
||||
import_from_cache = import_from_cache,
|
||||
disable = disable,
|
||||
return_logits = return_logits,
|
||||
supports_sdpa = supports_sdpa,
|
||||
)
|
||||
# Redo patches which override compiler
|
||||
for temporary_patch in TEMPORARY_PATCHES:
|
||||
|
|
@ -1993,7 +1993,7 @@ def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, m
|
|||
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
|
||||
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
|
||||
)
|
||||
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
|
||||
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
|
||||
|
||||
if hasattr(model.config, "quantization_config"):
|
||||
raise ValueError(
|
||||
|
|
@ -2069,9 +2069,9 @@ class TorchAOConfig:
|
|||
base_config_and_filter_fns: List[
|
||||
Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
|
||||
] = field(
|
||||
default_factory=lambda: [
|
||||
default_factory = lambda: [
|
||||
(
|
||||
Int4WeightOnlyConfig(group_size=128),
|
||||
Int4WeightOnlyConfig(group_size = 128),
|
||||
lambda m, _: isinstance(m, torch.nn.Linear)
|
||||
and getattr(m, "in_features", 0) >= 128,
|
||||
),
|
||||
|
|
@ -2143,7 +2143,7 @@ def _convert_torchao_model(model):
|
|||
|
||||
module_to_fqn_dict = {}
|
||||
for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
|
||||
quantize_(model, QATConfig(base_config, step="convert"), filter_fn=filter_fn)
|
||||
quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
|
||||
|
||||
# Default filter function used for quantize_
|
||||
if filter_fn is None:
|
||||
|
|
@ -2166,7 +2166,7 @@ def _convert_torchao_model(model):
|
|||
kwargs["modules_to_not_convert"] = []
|
||||
|
||||
quant_config = ModuleFqnToConfig(module_to_fqn_dict)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, **kwargs)
|
||||
quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
|
||||
model.config.quantization_config = quantization_config
|
||||
|
||||
|
||||
|
|
@ -2199,17 +2199,17 @@ def _prepare_model_for_qat(
|
|||
and m.in_features >= group_size
|
||||
)
|
||||
torchao_config = TorchAOConfig(
|
||||
qat_scheme=qat_scheme,
|
||||
base_config_and_filter_fns=[(base_config, filter_fn)],
|
||||
qat_scheme = qat_scheme,
|
||||
base_config_and_filter_fns = [(base_config, filter_fn)],
|
||||
)
|
||||
elif qat_scheme == "fp8-fp8":
|
||||
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
||||
|
||||
base_config = Float8DynamicActivationFloat8WeightConfig(
|
||||
granularity=PerRow()
|
||||
granularity = PerRow()
|
||||
)
|
||||
torchao_config = TorchAOConfig(
|
||||
qat_scheme=qat_scheme, base_config_and_filter_fns=[(base_config, None)]
|
||||
qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
|
||||
)
|
||||
elif qat_scheme == "int8-int4":
|
||||
from torchao.quantization import (
|
||||
|
|
@ -2218,35 +2218,35 @@ def _prepare_model_for_qat(
|
|||
)
|
||||
|
||||
torchao_config = TorchAOConfig(
|
||||
qat_scheme=qat_scheme,
|
||||
base_config_and_filter_fns=[
|
||||
qat_scheme = qat_scheme,
|
||||
base_config_and_filter_fns = [
|
||||
(
|
||||
IntxWeightOnlyConfig(
|
||||
weight_dtype=torch.int8, granularity=PerAxis(0)
|
||||
weight_dtype = torch.int8, granularity = PerAxis(0)
|
||||
),
|
||||
lambda m, fqn: isinstance(m, torch.nn.Embedding),
|
||||
),
|
||||
(
|
||||
Int8DynamicActivationIntxWeightConfig(
|
||||
weight_dtype=torch.int4, weight_granularity=PerGroup(32)
|
||||
weight_dtype = torch.int4, weight_granularity = PerGroup(32)
|
||||
),
|
||||
None,
|
||||
),
|
||||
],
|
||||
prequantization_transform=_untie_input_output_embeddings,
|
||||
prequantization_transform = _untie_input_output_embeddings,
|
||||
)
|
||||
elif qat_scheme == "int4":
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
group_size = 128
|
||||
base_config = Int4WeightOnlyConfig(group_size=group_size)
|
||||
base_config = Int4WeightOnlyConfig(group_size = group_size)
|
||||
filter_fn = (
|
||||
lambda m, _: isinstance(m, torch.nn.Linear)
|
||||
and m.in_features >= group_size
|
||||
)
|
||||
torchao_config = TorchAOConfig(
|
||||
qat_scheme=qat_scheme,
|
||||
base_config_and_filter_fns=[(base_config, filter_fn)],
|
||||
qat_scheme = qat_scheme,
|
||||
base_config_and_filter_fns = [(base_config, filter_fn)],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
|
||||
|
|
@ -2264,7 +2264,7 @@ def _prepare_model_for_qat(
|
|||
if torchao_config.prequantization_transform is not None:
|
||||
torchao_config.prequantization_transform(model)
|
||||
for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
|
||||
quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=filter_fn)
|
||||
quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -113,8 +113,8 @@ def Gemma2Attention_fast_forward(
|
|||
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Only enable if the attention_mask is True
|
||||
|
|
@ -139,10 +139,10 @@ def Gemma2Attention_fast_forward(
|
|||
Q,
|
||||
K,
|
||||
V,
|
||||
causal=True,
|
||||
softcap=self.config.attn_logit_softcapping,
|
||||
softmax_scale=self._flash_attention_softmax_scale,
|
||||
window_size=window,
|
||||
causal = True,
|
||||
softcap = self.config.attn_logit_softcapping,
|
||||
softmax_scale = self._flash_attention_softmax_scale,
|
||||
window_size = window,
|
||||
)
|
||||
A = A.reshape(bsz, q_len, n_heads * head_dim)
|
||||
else:
|
||||
|
|
@ -174,7 +174,7 @@ def Gemma2DecoderLayer_fast_forward(
|
|||
self, "_flag_for_generation"
|
||||
): # past_key_value is not None:
|
||||
out_weight = torch.empty(
|
||||
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
|
||||
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
|
|
@ -183,15 +183,15 @@ def Gemma2DecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states, out_weight
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
_flag_for_generation=self._flag_for_generation,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
_flag_for_generation = self._flag_for_generation,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(
|
||||
self.post_attention_layernorm, hidden_states, out_weight
|
||||
|
|
@ -211,31 +211,31 @@ def Gemma2DecoderLayer_fast_forward(
|
|||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.input_layernorm, hidden_states, gemma=True
|
||||
self.input_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.post_attention_layernorm, hidden_states, gemma=True
|
||||
self.post_attention_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.pre_feedforward_layernorm, hidden_states, gemma=True
|
||||
self.pre_feedforward_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = fast_rms_layernorm(
|
||||
self.post_feedforward_layernorm, hidden_states, gemma=True
|
||||
self.post_feedforward_layernorm, hidden_states, gemma = True
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
|
|
@ -260,9 +260,9 @@ def Gemma2Attention_fast_forward_inference(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill=False,
|
||||
attention_mask=None,
|
||||
use_sliding_window=False,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
use_sliding_window = False,
|
||||
):
|
||||
Xn = hidden_states
|
||||
bsz, _, hd = hidden_states.size()
|
||||
|
|
@ -286,24 +286,24 @@ def Gemma2Attention_fast_forward_inference(
|
|||
if do_prefill:
|
||||
self.paged_attention = torch.empty(
|
||||
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
self.paged_attention_K = self.paged_attention[:, 0]
|
||||
self.paged_attention_V = self.paged_attention[:, 1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty(
|
||||
(2, bsz, 1, attention_size), dtype=dtype, device=device
|
||||
(2, bsz, 1, attention_size), dtype = dtype, device = device
|
||||
)
|
||||
self.temp_KV = torch.empty(
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
|
||||
)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
|
||||
# Only for Gemma2
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
|
||||
self.attention = torch.empty(
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
|
||||
)
|
||||
|
||||
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
||||
|
|
@ -331,9 +331,9 @@ def Gemma2Attention_fast_forward_inference(
|
|||
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
|
||||
)
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
|
@ -348,7 +348,7 @@ def Gemma2Attention_fast_forward_inference(
|
|||
RH_Q = self.RH_Q
|
||||
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
|
||||
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
|
||||
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
|
||||
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
|
||||
Qn *= cos
|
||||
Qn.addcmul_(RH_Q, sin)
|
||||
|
||||
|
|
@ -357,7 +357,7 @@ def Gemma2Attention_fast_forward_inference(
|
|||
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
|
||||
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
|
||||
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
|
||||
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
|
||||
Kn *= cos
|
||||
Kn.addcmul_(RH_K, sin)
|
||||
|
||||
|
|
@ -400,21 +400,21 @@ def Gemma2Attention_fast_forward_inference(
|
|||
self.scalar
|
||||
) # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
|
||||
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
|
||||
A = torch_matmul(Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
|
||||
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
|
||||
A *= self.reciprocal_t
|
||||
torch_tanh(A, out=A)
|
||||
torch_tanh(A, out = A)
|
||||
A *= self.t # Logit softcapping
|
||||
|
||||
A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vnn, out=Qn)
|
||||
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vnn, out = Qn)
|
||||
# else:
|
||||
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
|
||||
# pass
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
|
||||
return A, (Kn, Vn)
|
||||
|
||||
|
||||
|
|
@ -425,13 +425,13 @@ def Gemma2Model_fast_forward_inference(
|
|||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
attention_mask=None,
|
||||
attention_mask = None,
|
||||
):
|
||||
out_weights = tuple(
|
||||
torch.empty_like(
|
||||
self.model.layers[0].input_layernorm.weight,
|
||||
dtype=torch.float32,
|
||||
device=torch.device(x),
|
||||
dtype = torch.float32,
|
||||
device = torch.device(x),
|
||||
)
|
||||
for x in range(DEVICE_COUNT)
|
||||
)
|
||||
|
|
@ -441,7 +441,7 @@ def Gemma2Model_fast_forward_inference(
|
|||
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
|
||||
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
|
||||
hidden_states *= torch.tensor(
|
||||
math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype
|
||||
math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
|
||||
)
|
||||
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
|
|
@ -456,7 +456,7 @@ def Gemma2Model_fast_forward_inference(
|
|||
(bsz, q_len),
|
||||
hidden_states,
|
||||
seq_len,
|
||||
sliding_window=self.config.sliding_window,
|
||||
sliding_window = self.config.sliding_window,
|
||||
)
|
||||
GA = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
|
|
@ -484,12 +484,12 @@ def Gemma2Model_fast_forward_inference(
|
|||
)
|
||||
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_values[idx],
|
||||
position_ids=position_ids,
|
||||
attention_mask=SWA if use_sliding_window else GA,
|
||||
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
use_sliding_window=use_sliding_window,
|
||||
hidden_states = hidden_states,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = SWA if use_sliding_window else GA,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
use_sliding_window = use_sliding_window,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(
|
||||
decoder_layer.post_attention_layernorm,
|
||||
|
|
@ -518,10 +518,10 @@ def Gemma2Model_fast_forward_inference(
|
|||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=[],
|
||||
attentions=[],
|
||||
last_hidden_state = hidden_states,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -529,10 +529,10 @@ class FastGemma2Model(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="gemma2",
|
||||
rope_module=GemmaFixedRotaryEmbedding,
|
||||
scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding,
|
||||
attention_module=Gemma2Attention,
|
||||
model_name = "gemma2",
|
||||
rope_module = GemmaFixedRotaryEmbedding,
|
||||
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
|
||||
attention_module = Gemma2Attention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -564,7 +564,7 @@ class FastGemma2Model(FastLlamaModel):
|
|||
def post_patch(model, tokenizer):
|
||||
# Gemma does not downcast RoPE
|
||||
model, tokenizer = patch_model_and_tokenizer(
|
||||
model, tokenizer, downcast_rope=False
|
||||
model, tokenizer, downcast_rope = False
|
||||
)
|
||||
|
||||
# Add 1 to weight
|
||||
|
|
|
|||
|
|
@ -109,8 +109,8 @@ def GraniteAttention_fast_forward(
|
|||
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -135,7 +135,7 @@ def GraniteAttention_fast_forward(
|
|||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
A = xformers_attention(
|
||||
Q, K, V, attn_bias=causal_mask, scale=self.scaling, p=dropout_p
|
||||
Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p
|
||||
)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
|
|
@ -148,10 +148,10 @@ def GraniteAttention_fast_forward(
|
|||
Q,
|
||||
K,
|
||||
V,
|
||||
causal=True,
|
||||
window_size=window,
|
||||
softmax_scale=self.scaling,
|
||||
dropout_p=dropout_p,
|
||||
causal = True,
|
||||
window_size = window,
|
||||
softmax_scale = self.scaling,
|
||||
dropout_p = dropout_p,
|
||||
)
|
||||
else:
|
||||
# Grouped query attention
|
||||
|
|
@ -170,10 +170,10 @@ def GraniteAttention_fast_forward(
|
|||
Q,
|
||||
K,
|
||||
V,
|
||||
attn_mask=attention_mask,
|
||||
scale=self.scaling,
|
||||
is_causal=False,
|
||||
dropout_p=dropout_p,
|
||||
attn_mask = attention_mask,
|
||||
scale = self.scaling,
|
||||
is_causal = False,
|
||||
dropout_p = dropout_p,
|
||||
)
|
||||
# Go back to (batch_size, seq_len, n_heads, head_dim)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
|
|
@ -212,18 +212,18 @@ def GraniteDecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
_flag_for_generation=self._flag_for_generation,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
_flag_for_generation = self._flag_for_generation,
|
||||
)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
|
|
@ -231,28 +231,28 @@ def GraniteDecoderLayer_fast_forward(
|
|||
self.post_attention_layernorm, hidden_states
|
||||
)
|
||||
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
|
|
@ -275,9 +275,9 @@ def GraniteAttention_fast_forward_inference(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill=False,
|
||||
attention_mask=None,
|
||||
use_sliding_window=False,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
use_sliding_window = False,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
assert (
|
||||
|
|
@ -306,24 +306,24 @@ def GraniteAttention_fast_forward_inference(
|
|||
if do_prefill:
|
||||
self.paged_attention = torch.empty(
|
||||
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
self.paged_attention_K = self.paged_attention[:, 0]
|
||||
self.paged_attention_V = self.paged_attention[:, 1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty(
|
||||
(2, bsz, 1, attention_size), dtype=dtype, device=device
|
||||
(2, bsz, 1, attention_size), dtype = dtype, device = device
|
||||
)
|
||||
self.temp_KV = torch.empty(
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
|
||||
)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
|
||||
# Only for Gemma2
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
|
||||
self.attention = torch.empty(
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
|
||||
)
|
||||
|
||||
self.half_head_dim = head_dim // 2
|
||||
|
|
@ -343,9 +343,9 @@ def GraniteAttention_fast_forward_inference(
|
|||
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
|
||||
)
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
|
@ -359,7 +359,7 @@ def GraniteAttention_fast_forward_inference(
|
|||
RH_Q = self.RH_Q
|
||||
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
|
||||
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
|
||||
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
|
||||
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
|
||||
Qn *= cos
|
||||
Qn.addcmul_(RH_Q, sin)
|
||||
|
||||
|
|
@ -368,7 +368,7 @@ def GraniteAttention_fast_forward_inference(
|
|||
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
|
||||
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
|
||||
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
|
||||
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
|
||||
Kn *= cos
|
||||
Kn.addcmul_(RH_K, sin)
|
||||
|
||||
|
|
@ -396,18 +396,18 @@ def GraniteAttention_fast_forward_inference(
|
|||
# pass
|
||||
|
||||
Qn *= self.scaling
|
||||
A = torch_matmul(Qn, Kn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
|
||||
A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
|
||||
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
|
||||
A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vn, out=Qn)
|
||||
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vn, out = Qn)
|
||||
# else:
|
||||
# A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False)
|
||||
# pass
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
|
||||
return A, (Kn, Vn)
|
||||
|
||||
|
||||
|
|
@ -418,7 +418,7 @@ def GraniteModel_fast_forward_inference(
|
|||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
attention_mask=None,
|
||||
attention_mask = None,
|
||||
):
|
||||
input_ids = input_ids[:, : self.max_seq_length]
|
||||
hidden_states = self.model.embed_tokens(input_ids)
|
||||
|
|
@ -459,37 +459,37 @@ def GraniteModel_fast_forward_inference(
|
|||
)
|
||||
hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_values[idx],
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
position_embeddings=position_embeddings,
|
||||
hidden_states = hidden_states,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = attention_mask,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(
|
||||
decoder_layer.post_attention_layernorm, hidden_states
|
||||
)
|
||||
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
next_decoder_cache.append(present_key_value)
|
||||
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=[],
|
||||
attentions=[],
|
||||
last_hidden_state = hidden_states,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
|
||||
|
||||
class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
super().__init__(config = config)
|
||||
|
||||
|
||||
def patched_init(original_init):
|
||||
|
|
@ -510,10 +510,10 @@ class FastGraniteModel(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="granite",
|
||||
rope_module=GraniteRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=GraniteAttention,
|
||||
model_name = "granite",
|
||||
rope_module = GraniteRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = GraniteAttention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -548,7 +548,7 @@ class FastGraniteModel(FastLlamaModel):
|
|||
model.config.update({"unsloth_version": __version__})
|
||||
|
||||
# We also do this for the lm_head
|
||||
lm_head = torch.nn.Linear(1, 1, bias=None)
|
||||
lm_head = torch.nn.Linear(1, 1, bias = None)
|
||||
del lm_head.weight
|
||||
lm_head.weight = model.lm_head.weight
|
||||
lm_head.in_features = lm_head.weight.shape[1]
|
||||
|
|
@ -560,7 +560,7 @@ class FastGraniteModel(FastLlamaModel):
|
|||
model.model.embed_tokens.weight.data_ptr()
|
||||
!= model.lm_head.weight.data_ptr()
|
||||
):
|
||||
lm_head = torch.nn.Linear(1, 1, bias=None)
|
||||
lm_head = torch.nn.Linear(1, 1, bias = None)
|
||||
del lm_head.weight
|
||||
lm_head.weight = model.model.embed_tokens.weight
|
||||
lm_head.in_features = lm_head.weight.shape[1]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -120,33 +120,33 @@ DISABLE_SDPA_MODEL_NAMES = [
|
|||
class FastLanguageModel(FastLlamaModel):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length=2048,
|
||||
dtype=None,
|
||||
load_in_4bit=True, # 4bit QLoRA
|
||||
load_in_8bit=False, # 8bit LoRA
|
||||
load_in_16bit=False, # 16bit LoRA
|
||||
full_finetuning=False,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None,
|
||||
fix_tokenizer=True,
|
||||
trust_remote_code=False,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
resize_model_vocab=None,
|
||||
revision=None,
|
||||
use_exact_model_name=False,
|
||||
offload_embedding=False,
|
||||
float32_mixed_precision=None, # Forces float32 mixed precision
|
||||
fast_inference=False, # uses vLLM
|
||||
gpu_memory_utilization=0.5,
|
||||
float8_kv_cache=False,
|
||||
random_state=3407,
|
||||
max_lora_rank=64,
|
||||
disable_log_stats=True,
|
||||
qat_scheme=None,
|
||||
load_in_fp8=False, # fp8 LoRA (True, False, 'block')
|
||||
unsloth_tiled_mlp=False,
|
||||
model_name = "unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length = 2048,
|
||||
dtype = None,
|
||||
load_in_4bit = True, # 4bit QLoRA
|
||||
load_in_8bit = False, # 8bit LoRA
|
||||
load_in_16bit = False, # 16bit LoRA
|
||||
full_finetuning = False,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None,
|
||||
fix_tokenizer = True,
|
||||
trust_remote_code = False,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
resize_model_vocab = None,
|
||||
revision = None,
|
||||
use_exact_model_name = False,
|
||||
offload_embedding = False,
|
||||
float32_mixed_precision = None, # Forces float32 mixed precision
|
||||
fast_inference = False, # uses vLLM
|
||||
gpu_memory_utilization = 0.5,
|
||||
float8_kv_cache = False,
|
||||
random_state = 3407,
|
||||
max_lora_rank = 64,
|
||||
disable_log_stats = True,
|
||||
qat_scheme = None,
|
||||
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
|
||||
unsloth_tiled_mlp = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
|
@ -157,40 +157,40 @@ class FastLanguageModel(FastLlamaModel):
|
|||
try:
|
||||
from huggingface_hub import login
|
||||
|
||||
login(token=token)
|
||||
login(token = token)
|
||||
except:
|
||||
pass
|
||||
if load_in_8bit or full_finetuning or qat_scheme is not None:
|
||||
return FastModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_16bit=load_in_16bit,
|
||||
full_finetuning=full_finetuning,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling, # [TODO] No effect
|
||||
fix_tokenizer=fix_tokenizer, # [TODO] No effect
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
resize_model_vocab=resize_model_vocab, # [TODO] No effect
|
||||
revision=revision,
|
||||
return_logits=False, # Return logits
|
||||
fullgraph=True, # No graph breaks
|
||||
use_exact_model_name=use_exact_model_name,
|
||||
offload_embedding=offload_embedding,
|
||||
float32_mixed_precision=float32_mixed_precision,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
load_in_16bit = load_in_16bit,
|
||||
full_finetuning = full_finetuning,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling, # [TODO] No effect
|
||||
fix_tokenizer = fix_tokenizer, # [TODO] No effect
|
||||
trust_remote_code = trust_remote_code,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
resize_model_vocab = resize_model_vocab, # [TODO] No effect
|
||||
revision = revision,
|
||||
return_logits = False, # Return logits
|
||||
fullgraph = True, # No graph breaks
|
||||
use_exact_model_name = use_exact_model_name,
|
||||
offload_embedding = offload_embedding,
|
||||
float32_mixed_precision = float32_mixed_precision,
|
||||
# Pass vLLM/inference parameters
|
||||
fast_inference=fast_inference,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
random_state=random_state,
|
||||
max_lora_rank=max_lora_rank,
|
||||
disable_log_stats=disable_log_stats,
|
||||
qat_scheme=qat_scheme,
|
||||
load_in_fp8=load_in_fp8,
|
||||
fast_inference = fast_inference,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
random_state = random_state,
|
||||
max_lora_rank = max_lora_rank,
|
||||
disable_log_stats = disable_log_stats,
|
||||
qat_scheme = qat_scheme,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -231,7 +231,7 @@ class FastLanguageModel(FastLlamaModel):
|
|||
fp8_mode = None
|
||||
if not use_exact_model_name:
|
||||
new_model_name = get_model_name(
|
||||
model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
|
||||
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
|
||||
)
|
||||
if new_model_name is None and load_in_fp8 != False:
|
||||
fp8_mode = _get_fp8_mode_and_check_settings(
|
||||
|
|
@ -284,9 +284,9 @@ class FastLanguageModel(FastLlamaModel):
|
|||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
is_model = True
|
||||
except Exception as error:
|
||||
|
|
@ -300,9 +300,9 @@ class FastLanguageModel(FastLlamaModel):
|
|||
try:
|
||||
peft_config = PeftConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
is_peft = True
|
||||
except Exception as error:
|
||||
|
|
@ -345,7 +345,7 @@ class FastLanguageModel(FastLlamaModel):
|
|||
both_exist = exist_adapter_config and exist_config
|
||||
else:
|
||||
# Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
|
||||
files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
|
||||
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
|
||||
files = list(os.path.split(x)[-1] for x in files)
|
||||
if (
|
||||
sum(x == "adapter_config.json" or x == "config.json" for x in files)
|
||||
|
|
@ -393,8 +393,8 @@ class FastLanguageModel(FastLlamaModel):
|
|||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
|
||||
if not was_disabled:
|
||||
|
|
@ -484,42 +484,42 @@ class FastLanguageModel(FastLlamaModel):
|
|||
# dispatch_model = FastGraniteModel
|
||||
else:
|
||||
return FastModel.from_pretrained(
|
||||
model_name=old_model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_16bit=load_in_16bit,
|
||||
full_finetuning=full_finetuning,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling, # [TODO] No effect
|
||||
fix_tokenizer=fix_tokenizer, # [TODO] No effect
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
resize_model_vocab=resize_model_vocab, # [TODO] No effect
|
||||
revision=revision,
|
||||
return_logits=False, # Return logits
|
||||
fullgraph=True, # No graph breaks
|
||||
use_exact_model_name=use_exact_model_name,
|
||||
offload_embedding=offload_embedding,
|
||||
float32_mixed_precision=float32_mixed_precision,
|
||||
model_name = old_model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
load_in_16bit = load_in_16bit,
|
||||
full_finetuning = full_finetuning,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling, # [TODO] No effect
|
||||
fix_tokenizer = fix_tokenizer, # [TODO] No effect
|
||||
trust_remote_code = trust_remote_code,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
resize_model_vocab = resize_model_vocab, # [TODO] No effect
|
||||
revision = revision,
|
||||
return_logits = False, # Return logits
|
||||
fullgraph = True, # No graph breaks
|
||||
use_exact_model_name = use_exact_model_name,
|
||||
offload_embedding = offload_embedding,
|
||||
float32_mixed_precision = float32_mixed_precision,
|
||||
# Pass vLLM/inference parameters
|
||||
fast_inference=fast_inference,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
random_state=random_state,
|
||||
max_lora_rank=max_lora_rank,
|
||||
disable_log_stats=disable_log_stats,
|
||||
qat_scheme=qat_scheme,
|
||||
load_in_fp8=load_in_fp8,
|
||||
unsloth_tiled_mlp=unsloth_tiled_mlp,
|
||||
fast_inference = fast_inference,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
random_state = random_state,
|
||||
max_lora_rank = max_lora_rank,
|
||||
disable_log_stats = disable_log_stats,
|
||||
qat_scheme = qat_scheme,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
unsloth_tiled_mlp = unsloth_tiled_mlp,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if use_gradient_checkpointing == "unsloth":
|
||||
patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
|
||||
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
|
||||
|
||||
# Check if this is local model since the tokenizer gets overwritten
|
||||
if (
|
||||
|
|
@ -535,24 +535,24 @@ class FastLanguageModel(FastLlamaModel):
|
|||
fast_inference, model_name = fast_inference_setup(model_name, model_config)
|
||||
|
||||
model, tokenizer = dispatch_model.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=_get_dtype(dtype),
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=dispatch_model,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision if not is_peft else None,
|
||||
fast_inference=fast_inference,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
random_state=random_state,
|
||||
max_lora_rank=max_lora_rank,
|
||||
disable_log_stats=disable_log_stats,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = _get_dtype(dtype),
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = dispatch_model,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
revision = revision if not is_peft else None,
|
||||
fast_inference = fast_inference,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
random_state = random_state,
|
||||
max_lora_rank = max_lora_rank,
|
||||
disable_log_stats = disable_log_stats,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -601,10 +601,10 @@ class FastLanguageModel(FastLlamaModel):
|
|||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
old_model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
is_trainable=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
is_trainable = True,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
# Patch it as well!
|
||||
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
|
||||
|
|
@ -615,7 +615,7 @@ class FastLanguageModel(FastLlamaModel):
|
|||
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
|
||||
)
|
||||
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
|
||||
patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
|
||||
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
|
@ -645,40 +645,40 @@ class FastModel(FastBaseModel):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
|
||||
max_seq_length=2048,
|
||||
dtype=None,
|
||||
load_in_4bit=True, # 4bit QLoRA
|
||||
load_in_8bit=False, # 8bit LoRA
|
||||
load_in_16bit=False, # 16bit LoRA
|
||||
full_finetuning=False,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None, # [TODO] No effect
|
||||
fix_tokenizer=True, # [TODO] No effect
|
||||
trust_remote_code=False,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
resize_model_vocab=None, # [TODO] No effect
|
||||
revision=None,
|
||||
return_logits=False, # Return logits
|
||||
fullgraph=True, # No graph breaks
|
||||
use_exact_model_name=False,
|
||||
auto_model=None,
|
||||
whisper_language=None,
|
||||
whisper_task=None,
|
||||
unsloth_force_compile=False,
|
||||
offload_embedding=False,
|
||||
float32_mixed_precision=None, # Forces float32 mixed precision
|
||||
model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
|
||||
max_seq_length = 2048,
|
||||
dtype = None,
|
||||
load_in_4bit = True, # 4bit QLoRA
|
||||
load_in_8bit = False, # 8bit LoRA
|
||||
load_in_16bit = False, # 16bit LoRA
|
||||
full_finetuning = False,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None, # [TODO] No effect
|
||||
fix_tokenizer = True, # [TODO] No effect
|
||||
trust_remote_code = False,
|
||||
use_gradient_checkpointing = "unsloth",
|
||||
resize_model_vocab = None, # [TODO] No effect
|
||||
revision = None,
|
||||
return_logits = False, # Return logits
|
||||
fullgraph = True, # No graph breaks
|
||||
use_exact_model_name = False,
|
||||
auto_model = None,
|
||||
whisper_language = None,
|
||||
whisper_task = None,
|
||||
unsloth_force_compile = False,
|
||||
offload_embedding = False,
|
||||
float32_mixed_precision = None, # Forces float32 mixed precision
|
||||
# Add the missing vLLM/inference parameters
|
||||
fast_inference=False, # uses vLLM
|
||||
gpu_memory_utilization=0.5,
|
||||
float8_kv_cache=False,
|
||||
random_state=3407,
|
||||
max_lora_rank=64,
|
||||
disable_log_stats=True,
|
||||
qat_scheme=None,
|
||||
load_in_fp8=False, # fp8 LoRA (True, False, 'block')
|
||||
unsloth_tiled_mlp=False,
|
||||
fast_inference = False, # uses vLLM
|
||||
gpu_memory_utilization = 0.5,
|
||||
float8_kv_cache = False,
|
||||
random_state = 3407,
|
||||
max_lora_rank = 64,
|
||||
disable_log_stats = True,
|
||||
qat_scheme = None,
|
||||
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
|
||||
unsloth_tiled_mlp = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
|
@ -689,7 +689,7 @@ class FastModel(FastBaseModel):
|
|||
try:
|
||||
from huggingface_hub import login
|
||||
|
||||
login(token=token)
|
||||
login(token = token)
|
||||
except:
|
||||
pass
|
||||
if whisper_language is not None:
|
||||
|
|
@ -765,7 +765,7 @@ class FastModel(FastBaseModel):
|
|||
fp8_mode = None
|
||||
if not use_exact_model_name:
|
||||
new_model_name = get_model_name(
|
||||
model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
|
||||
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
|
||||
)
|
||||
if new_model_name is None and load_in_fp8 != False:
|
||||
fp8_mode = _get_fp8_mode_and_check_settings(
|
||||
|
|
@ -819,9 +819,9 @@ class FastModel(FastBaseModel):
|
|||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
is_model = True
|
||||
except Exception as error:
|
||||
|
|
@ -835,9 +835,9 @@ class FastModel(FastBaseModel):
|
|||
try:
|
||||
peft_config = PeftConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
is_peft = True
|
||||
except Exception as error:
|
||||
|
|
@ -1011,7 +1011,7 @@ class FastModel(FastBaseModel):
|
|||
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
|
||||
both_exist = exist_adapter_config and exist_config
|
||||
else:
|
||||
files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
|
||||
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
|
||||
files = list(os.path.split(x)[-1] for x in files)
|
||||
if (
|
||||
sum(x == "adapter_config.json" or x == "config.json" for x in files)
|
||||
|
|
@ -1059,8 +1059,8 @@ class FastModel(FastBaseModel):
|
|||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
|
||||
if not was_disabled:
|
||||
|
|
@ -1092,40 +1092,40 @@ class FastModel(FastBaseModel):
|
|||
break
|
||||
# Patch gradient checkpointing
|
||||
if use_gradient_checkpointing == "unsloth":
|
||||
patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
|
||||
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
|
||||
with redirector:
|
||||
patch_loss_functions(torch_compile=False)
|
||||
patch_loss_functions(torch_compile = False)
|
||||
model_types, supports_sdpa = unsloth_compile_transformers(
|
||||
dtype=dtype,
|
||||
model_name=model_name,
|
||||
model_types=model_types,
|
||||
token=token,
|
||||
sdpa_dynamic_mask=True,
|
||||
sdpa_bool_masks=True,
|
||||
sdpa_gqa_replace=True,
|
||||
sdpa_dynamic_compile=True,
|
||||
compile_attention=True,
|
||||
disable_causal_masks=True,
|
||||
compile_torch_modules=True,
|
||||
compile_custom_modules=True,
|
||||
compile_function_calls=True,
|
||||
fuse_lm_head=True,
|
||||
gradient_checkpointing=True,
|
||||
manual_replacements=True,
|
||||
fast_lora_forwards=True,
|
||||
fast_residual_stream=False,
|
||||
accurate_accumulation=True,
|
||||
epilogue_fusion=True,
|
||||
max_autotune=False,
|
||||
shape_padding=True,
|
||||
cudagraphs=False,
|
||||
debug=False,
|
||||
fullgraph=fullgraph,
|
||||
import_from_cache=False,
|
||||
disable=False,
|
||||
return_logits=return_logits,
|
||||
trust_remote_code=trust_remote_code,
|
||||
unsloth_force_compile=unsloth_force_compile,
|
||||
dtype = dtype,
|
||||
model_name = model_name,
|
||||
model_types = model_types,
|
||||
token = token,
|
||||
sdpa_dynamic_mask = True,
|
||||
sdpa_bool_masks = True,
|
||||
sdpa_gqa_replace = True,
|
||||
sdpa_dynamic_compile = True,
|
||||
compile_attention = True,
|
||||
disable_causal_masks = True,
|
||||
compile_torch_modules = True,
|
||||
compile_custom_modules = True,
|
||||
compile_function_calls = True,
|
||||
fuse_lm_head = True,
|
||||
gradient_checkpointing = True,
|
||||
manual_replacements = True,
|
||||
fast_lora_forwards = True,
|
||||
fast_residual_stream = False,
|
||||
accurate_accumulation = True,
|
||||
epilogue_fusion = True,
|
||||
max_autotune = False,
|
||||
shape_padding = True,
|
||||
cudagraphs = False,
|
||||
debug = False,
|
||||
fullgraph = fullgraph,
|
||||
import_from_cache = False,
|
||||
disable = False,
|
||||
return_logits = return_logits,
|
||||
trust_remote_code = trust_remote_code,
|
||||
unsloth_force_compile = unsloth_force_compile,
|
||||
)
|
||||
# Fix SDPA issues
|
||||
for model_type in DISABLE_SDPA_MODEL_NAMES:
|
||||
|
|
@ -1152,34 +1152,34 @@ class FastModel(FastBaseModel):
|
|||
auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
|
||||
|
||||
model, tokenizer = FastBaseModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=_get_dtype(dtype),
|
||||
load_in_4bit=load_in_4bit,
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_16bit=load_in_16bit,
|
||||
full_finetuning=full_finetuning,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision if not is_peft else None,
|
||||
model_types=model_types,
|
||||
tokenizer_name=tokenizer_name,
|
||||
auto_model=auto_model,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
supports_sdpa=supports_sdpa,
|
||||
whisper_language=whisper_language,
|
||||
whisper_task=whisper_task,
|
||||
auto_config=model_config,
|
||||
offload_embedding=offload_embedding,
|
||||
float32_mixed_precision=float32_mixed_precision,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = _get_dtype(dtype),
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_8bit = load_in_8bit,
|
||||
load_in_16bit = load_in_16bit,
|
||||
full_finetuning = full_finetuning,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
trust_remote_code = trust_remote_code,
|
||||
revision = revision if not is_peft else None,
|
||||
model_types = model_types,
|
||||
tokenizer_name = tokenizer_name,
|
||||
auto_model = auto_model,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
supports_sdpa = supports_sdpa,
|
||||
whisper_language = whisper_language,
|
||||
whisper_task = whisper_task,
|
||||
auto_config = model_config,
|
||||
offload_embedding = offload_embedding,
|
||||
float32_mixed_precision = float32_mixed_precision,
|
||||
# Pass vLLM/inference parameters
|
||||
fast_inference=fast_inference,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
float8_kv_cache=float8_kv_cache,
|
||||
random_state=random_state,
|
||||
max_lora_rank=max_lora_rank,
|
||||
disable_log_stats=disable_log_stats,
|
||||
fast_inference = fast_inference,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
float8_kv_cache = float8_kv_cache,
|
||||
random_state = random_state,
|
||||
max_lora_rank = max_lora_rank,
|
||||
disable_log_stats = disable_log_stats,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -1228,14 +1228,14 @@ class FastModel(FastBaseModel):
|
|||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
old_model_name,
|
||||
token=token,
|
||||
revision=revision,
|
||||
is_trainable=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token = token,
|
||||
revision = revision,
|
||||
is_trainable = True,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
# Patch it as well!
|
||||
model = FastBaseModel.post_patch_model(
|
||||
model, use_gradient_checkpointing, trust_remote_code=trust_remote_code
|
||||
model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
|
||||
)
|
||||
|
||||
# Apply QAT if specified
|
||||
|
|
@ -1249,7 +1249,7 @@ class FastModel(FastBaseModel):
|
|||
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
|
||||
)
|
||||
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
|
||||
patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
|
||||
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
|
|
|||
|
|
@ -39,10 +39,10 @@ class FastQwen2Model(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="qwen2",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=Qwen2Attention,
|
||||
model_name = "qwen2",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = Qwen2Attention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -72,30 +72,30 @@ class FastQwen2Model(FastLlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name="Qwen/Qwen2-7B",
|
||||
max_seq_length=4096,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None, # Qwen2 does not support RoPE scaling
|
||||
fix_tokenizer=True,
|
||||
model_patcher=None,
|
||||
tokenizer_name=None,
|
||||
trust_remote_code=False,
|
||||
model_name = "Qwen/Qwen2-7B",
|
||||
max_seq_length = 4096,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None, # Qwen2 does not support RoPE scaling
|
||||
fix_tokenizer = True,
|
||||
model_patcher = None,
|
||||
tokenizer_name = None,
|
||||
trust_remote_code = False,
|
||||
**kwargs,
|
||||
):
|
||||
return FastLlamaModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=FastQwen2Model,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = FastQwen2Model,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def Qwen3Attention_fast_forward(
|
|||
else:
|
||||
# Extend RoPE dynamically to fit in VRA
|
||||
rotary_emb = self.rotary_emb
|
||||
rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
|
||||
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
|
||||
device_index = Q.device.index
|
||||
|
||||
if position_ids is None:
|
||||
|
|
@ -126,8 +126,8 @@ def Qwen3Attention_fast_forward(
|
|||
Q, K = fast_rope_embedding(Q, K, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim=2)
|
||||
V = torch.cat([past_key_value[1], V], dim=2)
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -163,7 +163,7 @@ def Qwen3Attention_fast_forward(
|
|||
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
|
||||
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
|
||||
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION and attention_mask is None:
|
||||
|
|
@ -172,7 +172,7 @@ def Qwen3Attention_fast_forward(
|
|||
V = V.transpose(1, 2)
|
||||
sw = kv_seq_len
|
||||
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
# Grouped query attention
|
||||
# if n_groups != 1:
|
||||
|
|
@ -195,7 +195,7 @@ def Qwen3Attention_fast_forward(
|
|||
is_causal = False
|
||||
|
||||
A = scaled_dot_product_attention(
|
||||
Q, K, V, attn_mask=attention_mask, is_causal=is_causal
|
||||
Q, K, V, attn_mask = attention_mask, is_causal = is_causal
|
||||
)
|
||||
# Go back to (batch_size, seq_len, n_heads, head_dim)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
|
|
@ -214,8 +214,8 @@ def Qwen3Attention_fast_forward_inference(
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill=False,
|
||||
attention_mask=None,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
):
|
||||
"""
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
|
||||
|
|
@ -267,29 +267,29 @@ def Qwen3Attention_fast_forward_inference(
|
|||
if do_prefill:
|
||||
self.paged_attention = torch.empty(
|
||||
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
dtype = dtype,
|
||||
device = device,
|
||||
)
|
||||
self.paged_attention_K = self.paged_attention[:, 0]
|
||||
self.paged_attention_V = self.paged_attention[:, 1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty(
|
||||
(2, bsz, 1, attention_size), dtype=dtype, device=device
|
||||
(2, bsz, 1, attention_size), dtype = dtype, device = device
|
||||
)
|
||||
self.temp_KV = torch.empty(
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
|
||||
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
|
||||
)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
|
||||
|
||||
# Mistral Nemo 12b has weird dimensions
|
||||
if attention_size != hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
|
||||
else:
|
||||
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
|
||||
|
||||
self.attention = torch.empty(
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
|
||||
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
|
||||
)
|
||||
self.scalar = 1.0 / math_sqrt(self.head_dim)
|
||||
self.half_head_dim = head_dim // 2
|
||||
|
|
@ -309,9 +309,9 @@ def Qwen3Attention_fast_forward_inference(
|
|||
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
|
||||
)
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(
|
||||
bsz, 1, n_heads, head_dim
|
||||
) # .transpose(1, 2) # we will transpose after normalisation
|
||||
|
|
@ -399,30 +399,30 @@ def Qwen3Attention_fast_forward_inference(
|
|||
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
|
||||
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
|
||||
A = torch_matmul(
|
||||
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
|
||||
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
|
||||
)
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
A[:] = torch_nn_functional_softmax(
|
||||
A, dim=-1, dtype=torch.float32
|
||||
A, dim = -1, dtype = torch.float32
|
||||
) # .to(A.dtype)
|
||||
A = torch_matmul(A, Vnn, out=Qn)
|
||||
A = torch_matmul(A, Vnn, out = Qn)
|
||||
else:
|
||||
if SDPA_HAS_GQA:
|
||||
A = scaled_dot_product_attention(
|
||||
Qn,
|
||||
Knn,
|
||||
Vnn,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=is_causal,
|
||||
enable_gqa=True,
|
||||
attn_mask = attention_mask,
|
||||
is_causal = is_causal,
|
||||
enable_gqa = True,
|
||||
)
|
||||
else:
|
||||
A = scaled_dot_product_attention(
|
||||
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal
|
||||
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
|
||||
)
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
|
||||
return A, (Kn, Vn)
|
||||
|
||||
|
||||
|
|
@ -430,10 +430,10 @@ class FastQwen3Model(FastLlamaModel):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="Qwen3",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=Qwen3Attention,
|
||||
model_name = "Qwen3",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = Qwen3Attention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -463,30 +463,30 @@ class FastQwen3Model(FastLlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained( # TODO: Change after release
|
||||
model_name="Qwen/Qwen3-7B",
|
||||
max_seq_length=4096,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None,
|
||||
fix_tokenizer=True,
|
||||
model_patcher=None,
|
||||
tokenizer_name=None,
|
||||
trust_remote_code=False,
|
||||
model_name = "Qwen/Qwen3-7B",
|
||||
max_seq_length = 4096,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None,
|
||||
fix_tokenizer = True,
|
||||
model_patcher = None,
|
||||
tokenizer_name = None,
|
||||
trust_remote_code = False,
|
||||
**kwargs,
|
||||
):
|
||||
return FastLlamaModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=FastQwen3Model,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = FastQwen3Model,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,29 +49,29 @@ from unsloth_zoo.utils import Version, _get_dtype
|
|||
torch_nn_functional_softmax = torch.nn.functional.softmax
|
||||
|
||||
|
||||
def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate=None, temp_up=None):
|
||||
def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):
|
||||
# adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370
|
||||
|
||||
bsz, seq_len, hd = X.shape
|
||||
X = X.view(-1, hd)
|
||||
|
||||
router_logits = fast_linear_forward(
|
||||
self.gate_proj, X, out=temp_gate
|
||||
self.gate_proj, X, out = temp_gate
|
||||
) # pretty much the only change from transformers implementation.
|
||||
|
||||
routing_weights = torch_nn_functional_softmax(
|
||||
router_logits, dim=-1, dtype=torch.float32
|
||||
router_logits, dim = -1, dtype = torch.float32
|
||||
)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
|
||||
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(X.dtype)
|
||||
final_X = torch.zeros((bsz * seq_len, hd), dtype=torch.float32, device=X.device)
|
||||
final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device)
|
||||
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.num_experts
|
||||
selected_experts, num_classes = self.num_experts
|
||||
).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
|
|
@ -119,16 +119,16 @@ def Qwen3MoeDecoderLayer_fast_forward(
|
|||
self.input_layernorm, hidden_states
|
||||
)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
_flag_for_generation=self._flag_for_generation,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
_flag_for_generation = self._flag_for_generation,
|
||||
)
|
||||
hidden_states += residual
|
||||
|
||||
|
|
@ -145,15 +145,15 @@ def Qwen3MoeDecoderLayer_fast_forward(
|
|||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
hidden_states = hidden_states,
|
||||
causal_mask = causal_mask,
|
||||
attention_mask = attention_mask,
|
||||
position_ids = position_ids,
|
||||
past_key_value = past_key_value,
|
||||
output_attentions = output_attentions,
|
||||
use_cache = use_cache,
|
||||
padding_mask = padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
|
|
@ -177,10 +177,10 @@ class FastQwen3MoeModel(FastQwen3Model):
|
|||
@staticmethod
|
||||
def pre_patch():
|
||||
init_name, function = patch_linear_scaling(
|
||||
model_name="Qwen3Moe",
|
||||
rope_module=LlamaRotaryEmbedding,
|
||||
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module=Qwen3MoeAttention,
|
||||
model_name = "Qwen3Moe",
|
||||
rope_module = LlamaRotaryEmbedding,
|
||||
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
|
||||
attention_module = Qwen3MoeAttention,
|
||||
)
|
||||
if init_name is not None:
|
||||
exec(function, globals())
|
||||
|
|
@ -214,30 +214,30 @@ class FastQwen3MoeModel(FastQwen3Model):
|
|||
|
||||
@staticmethod
|
||||
def from_pretrained( # TODO: Change after release
|
||||
model_name="Qwen/Qwen3-7B",
|
||||
max_seq_length=4096,
|
||||
dtype=None,
|
||||
load_in_4bit=True,
|
||||
token=None,
|
||||
device_map="sequential",
|
||||
rope_scaling=None,
|
||||
fix_tokenizer=True,
|
||||
model_patcher=None,
|
||||
tokenizer_name=None,
|
||||
trust_remote_code=False,
|
||||
model_name = "Qwen/Qwen3-7B",
|
||||
max_seq_length = 4096,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
token = None,
|
||||
device_map = "sequential",
|
||||
rope_scaling = None,
|
||||
fix_tokenizer = True,
|
||||
model_patcher = None,
|
||||
tokenizer_name = None,
|
||||
trust_remote_code = False,
|
||||
**kwargs,
|
||||
):
|
||||
return FastLlamaModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=dtype,
|
||||
load_in_4bit=load_in_4bit,
|
||||
token=token,
|
||||
device_map=device_map,
|
||||
rope_scaling=rope_scaling,
|
||||
fix_tokenizer=fix_tokenizer,
|
||||
model_patcher=FastQwen3Model,
|
||||
tokenizer_name=tokenizer_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
load_in_4bit = load_in_4bit,
|
||||
token = token,
|
||||
device_map = device_map,
|
||||
rope_scaling = rope_scaling,
|
||||
fix_tokenizer = fix_tokenizer,
|
||||
model_patcher = FastQwen3Model,
|
||||
tokenizer_name = tokenizer_name,
|
||||
trust_remote_code = trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,70 +30,70 @@ class DeepseekR1ModelInfo(ModelInfo):
|
|||
|
||||
# Deepseek V3 Model Meta
|
||||
DeepseekV3Meta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek",
|
||||
instruct_tags=[None],
|
||||
model_version="3",
|
||||
model_sizes=[""],
|
||||
model_info_cls=DeepseekV3ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BF16],
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek",
|
||||
instruct_tags = [None],
|
||||
model_version = "3",
|
||||
model_sizes = [""],
|
||||
model_info_cls = DeepseekV3ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BF16],
|
||||
)
|
||||
|
||||
DeepseekV3_0324Meta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek",
|
||||
instruct_tags=[None],
|
||||
model_version="3-0324",
|
||||
model_sizes=[""],
|
||||
model_info_cls=DeepseekV3ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.GGUF],
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek",
|
||||
instruct_tags = [None],
|
||||
model_version = "3-0324",
|
||||
model_sizes = [""],
|
||||
model_info_cls = DeepseekV3ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.GGUF],
|
||||
)
|
||||
|
||||
DeepseekR1Meta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek-R1",
|
||||
instruct_tags=[None],
|
||||
model_version="",
|
||||
model_sizes=[""],
|
||||
model_info_cls=DeepseekR1ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF],
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek-R1",
|
||||
instruct_tags = [None],
|
||||
model_version = "",
|
||||
model_sizes = [""],
|
||||
model_info_cls = DeepseekR1ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],
|
||||
)
|
||||
|
||||
DeepseekR1ZeroMeta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek-R1",
|
||||
instruct_tags=[None],
|
||||
model_version="Zero",
|
||||
model_sizes=[""],
|
||||
model_info_cls=DeepseekR1ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.GGUF],
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek-R1",
|
||||
instruct_tags = [None],
|
||||
model_version = "Zero",
|
||||
model_sizes = [""],
|
||||
model_info_cls = DeepseekR1ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.GGUF],
|
||||
)
|
||||
|
||||
DeepseekR1DistillLlamaMeta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek-R1-Distill",
|
||||
instruct_tags=[None],
|
||||
model_version="Llama",
|
||||
model_sizes=["8", "70"],
|
||||
model_info_cls=DeepseekR1ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek-R1-Distill",
|
||||
instruct_tags = [None],
|
||||
model_version = "Llama",
|
||||
model_sizes = ["8", "70"],
|
||||
model_info_cls = DeepseekR1ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
|
||||
)
|
||||
|
||||
# Deepseek R1 Distill Qwen Model Meta
|
||||
DeepseekR1DistillQwenMeta = ModelMeta(
|
||||
org="deepseek-ai",
|
||||
base_name="DeepSeek-R1-Distill",
|
||||
instruct_tags=[None],
|
||||
model_version="Qwen",
|
||||
model_sizes=["1.5", "7", "14", "32"],
|
||||
model_info_cls=DeepseekR1ModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types={
|
||||
org = "deepseek-ai",
|
||||
base_name = "DeepSeek-R1-Distill",
|
||||
instruct_tags = [None],
|
||||
model_version = "Qwen",
|
||||
model_sizes = ["1.5", "7", "14", "32"],
|
||||
model_info_cls = DeepseekR1ModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = {
|
||||
"1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
|
||||
"7": [QuantType.UNSLOTH, QuantType.BNB],
|
||||
"14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
|
||||
|
|
@ -106,7 +106,7 @@ def register_deepseek_v3_models(include_original_model: bool = False):
|
|||
global _IS_DEEPSEEK_V3_REGISTERED
|
||||
if _IS_DEEPSEEK_V3_REGISTERED:
|
||||
return
|
||||
_register_models(DeepseekV3Meta, include_original_model=include_original_model)
|
||||
_register_models(DeepseekV3Meta, include_original_model = include_original_model)
|
||||
_IS_DEEPSEEK_V3_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -114,7 +114,7 @@ def register_deepseek_v3_0324_models(include_original_model: bool = False):
|
|||
global _IS_DEEPSEEK_V3_0324_REGISTERED
|
||||
if _IS_DEEPSEEK_V3_0324_REGISTERED:
|
||||
return
|
||||
_register_models(DeepseekV3_0324Meta, include_original_model=include_original_model)
|
||||
_register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)
|
||||
_IS_DEEPSEEK_V3_0324_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ def register_deepseek_r1_models(include_original_model: bool = False):
|
|||
global _IS_DEEPSEEK_R1_REGISTERED
|
||||
if _IS_DEEPSEEK_R1_REGISTERED:
|
||||
return
|
||||
_register_models(DeepseekR1Meta, include_original_model=include_original_model)
|
||||
_register_models(DeepseekR1Meta, include_original_model = include_original_model)
|
||||
_IS_DEEPSEEK_R1_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ def register_deepseek_r1_zero_models(include_original_model: bool = False):
|
|||
global _IS_DEEPSEEK_R1_ZERO_REGISTERED
|
||||
if _IS_DEEPSEEK_R1_ZERO_REGISTERED:
|
||||
return
|
||||
_register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model)
|
||||
_register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)
|
||||
_IS_DEEPSEEK_R1_ZERO_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ def register_deepseek_r1_distill_llama_models(include_original_model: bool = Fal
|
|||
if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:
|
||||
return
|
||||
_register_models(
|
||||
DeepseekR1DistillLlamaMeta, include_original_model=include_original_model
|
||||
DeepseekR1DistillLlamaMeta, include_original_model = include_original_model
|
||||
)
|
||||
_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True
|
||||
|
||||
|
|
@ -149,21 +149,21 @@ def register_deepseek_r1_distill_qwen_models(include_original_model: bool = Fals
|
|||
if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:
|
||||
return
|
||||
_register_models(
|
||||
DeepseekR1DistillQwenMeta, include_original_model=include_original_model
|
||||
DeepseekR1DistillQwenMeta, include_original_model = include_original_model
|
||||
)
|
||||
_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True
|
||||
|
||||
|
||||
def register_deepseek_models(include_original_model: bool = False):
|
||||
register_deepseek_v3_models(include_original_model=include_original_model)
|
||||
register_deepseek_v3_0324_models(include_original_model=include_original_model)
|
||||
register_deepseek_r1_models(include_original_model=include_original_model)
|
||||
register_deepseek_r1_zero_models(include_original_model=include_original_model)
|
||||
register_deepseek_v3_models(include_original_model = include_original_model)
|
||||
register_deepseek_v3_0324_models(include_original_model = include_original_model)
|
||||
register_deepseek_r1_models(include_original_model = include_original_model)
|
||||
register_deepseek_r1_zero_models(include_original_model = include_original_model)
|
||||
register_deepseek_r1_distill_llama_models(
|
||||
include_original_model=include_original_model
|
||||
include_original_model = include_original_model
|
||||
)
|
||||
register_deepseek_r1_distill_qwen_models(
|
||||
include_original_model=include_original_model
|
||||
include_original_model = include_original_model
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ def _list_deepseek_r1_distill_models():
|
|||
from unsloth.utils.hf_hub import list_models
|
||||
|
||||
models: list[HfModelInfo] = list_models(
|
||||
author="unsloth", search="Distill", limit=1000
|
||||
author = "unsloth", search = "Distill", limit = 1000
|
||||
)
|
||||
distill_models = []
|
||||
for model in models:
|
||||
|
|
@ -185,14 +185,14 @@ def _list_deepseek_r1_distill_models():
|
|||
return distill_models
|
||||
|
||||
|
||||
register_deepseek_models(include_original_model=True)
|
||||
register_deepseek_models(include_original_model = True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
|
||||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_deepseek_models(include_original_model=True)
|
||||
register_deepseek_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
|
|
@ -15,26 +15,26 @@ class GemmaModelInfo(ModelInfo):
|
|||
|
||||
# Gemma3 Base Model Meta
|
||||
GemmaMeta3Base = ModelMeta(
|
||||
org="google",
|
||||
base_name="gemma",
|
||||
instruct_tags=["pt"], # pt = base
|
||||
model_version="3",
|
||||
model_sizes=["1", "4", "12", "27"],
|
||||
model_info_cls=GemmaModelInfo,
|
||||
is_multimodal=True,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "google",
|
||||
base_name = "gemma",
|
||||
instruct_tags = ["pt"], # pt = base
|
||||
model_version = "3",
|
||||
model_sizes = ["1", "4", "12", "27"],
|
||||
model_info_cls = GemmaModelInfo,
|
||||
is_multimodal = True,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Gemma3 Instruct Model Meta
|
||||
GemmaMeta3Instruct = ModelMeta(
|
||||
org="google",
|
||||
base_name="gemma",
|
||||
instruct_tags=["it"], # it = instruction tuned
|
||||
model_version="3",
|
||||
model_sizes=["1", "4", "12", "27"],
|
||||
model_info_cls=GemmaModelInfo,
|
||||
is_multimodal=True,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
org = "google",
|
||||
base_name = "gemma",
|
||||
instruct_tags = ["it"], # it = instruction tuned
|
||||
model_version = "3",
|
||||
model_sizes = ["1", "4", "12", "27"],
|
||||
model_info_cls = GemmaModelInfo,
|
||||
is_multimodal = True,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ def register_gemma_3_base_models(include_original_model: bool = False):
|
|||
global _IS_GEMMA_3_BASE_REGISTERED
|
||||
if _IS_GEMMA_3_BASE_REGISTERED:
|
||||
return
|
||||
_register_models(GemmaMeta3Base, include_original_model=include_original_model)
|
||||
_register_models(GemmaMeta3Base, include_original_model = include_original_model)
|
||||
_IS_GEMMA_3_BASE_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -50,13 +50,13 @@ def register_gemma_3_instruct_models(include_original_model: bool = False):
|
|||
global _IS_GEMMA_3_INSTRUCT_REGISTERED
|
||||
if _IS_GEMMA_3_INSTRUCT_REGISTERED:
|
||||
return
|
||||
_register_models(GemmaMeta3Instruct, include_original_model=include_original_model)
|
||||
_register_models(GemmaMeta3Instruct, include_original_model = include_original_model)
|
||||
_IS_GEMMA_3_INSTRUCT_REGISTERED = True
|
||||
|
||||
|
||||
def register_gemma_models(include_original_model: bool = False):
|
||||
register_gemma_3_base_models(include_original_model=include_original_model)
|
||||
register_gemma_3_instruct_models(include_original_model=include_original_model)
|
||||
register_gemma_3_base_models(include_original_model = include_original_model)
|
||||
register_gemma_3_instruct_models(include_original_model = include_original_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -64,7 +64,7 @@ if __name__ == "__main__":
|
|||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_gemma_models(include_original_model=True)
|
||||
register_gemma_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ class MistralSmallModelInfo(ModelInfo):
|
|||
|
||||
|
||||
MistralSmall_2503_Base_Meta = ModelMeta(
|
||||
org="mistralai",
|
||||
base_name="Mistral-Small",
|
||||
instruct_tags=["Base"],
|
||||
model_version=_MISTRAL_SMALL_03_25_VERSION,
|
||||
model_sizes=["24"],
|
||||
model_info_cls=MistralSmallModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
|
||||
org = "mistralai",
|
||||
base_name = "Mistral-Small",
|
||||
instruct_tags = ["Base"],
|
||||
model_version = _MISTRAL_SMALL_03_25_VERSION,
|
||||
model_sizes = ["24"],
|
||||
model_info_cls = MistralSmallModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
|
||||
)
|
||||
|
||||
MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
|
||||
|
|
@ -54,23 +54,23 @@ def register_mistral_small_models(include_original_model: bool = False):
|
|||
if _IS_MISTRAL_SMALL_REGISTERED:
|
||||
return
|
||||
_register_models(
|
||||
MistralSmall_2503_Base_Meta, include_original_model=include_original_model
|
||||
MistralSmall_2503_Base_Meta, include_original_model = include_original_model
|
||||
)
|
||||
_register_models(
|
||||
MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model
|
||||
MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model
|
||||
)
|
||||
_register_models(
|
||||
MistralSmall_2501_Base_Meta, include_original_model=include_original_model
|
||||
MistralSmall_2501_Base_Meta, include_original_model = include_original_model
|
||||
)
|
||||
_register_models(
|
||||
MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model
|
||||
MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model
|
||||
)
|
||||
|
||||
_IS_MISTRAL_SMALL_REGISTERED = True
|
||||
|
||||
|
||||
def register_mistral_models(include_original_model: bool = False):
|
||||
register_mistral_small_models(include_original_model=include_original_model)
|
||||
register_mistral_small_models(include_original_model = include_original_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -78,7 +78,7 @@ if __name__ == "__main__":
|
|||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_mistral_models(include_original_model=True)
|
||||
register_mistral_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
|
|
@ -43,50 +43,50 @@ class QwenQVQPreviewModelInfo(ModelInfo):
|
|||
|
||||
# Qwen2.5 Model Meta
|
||||
Qwen_2_5_Meta = ModelMeta(
|
||||
org="Qwen",
|
||||
base_name="Qwen",
|
||||
instruct_tags=[None, "Instruct"],
|
||||
model_version="2.5",
|
||||
model_sizes=["3", "7"],
|
||||
model_info_cls=QwenModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "Qwen",
|
||||
base_name = "Qwen",
|
||||
instruct_tags = [None, "Instruct"],
|
||||
model_version = "2.5",
|
||||
model_sizes = ["3", "7"],
|
||||
model_info_cls = QwenModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Qwen2.5 VL Model Meta
|
||||
Qwen_2_5_VLMeta = ModelMeta(
|
||||
org="Qwen",
|
||||
base_name="Qwen",
|
||||
instruct_tags=["Instruct"], # No base, only instruction tuned
|
||||
model_version="2.5",
|
||||
model_sizes=["3", "7", "32", "72"],
|
||||
model_info_cls=QwenVLModelInfo,
|
||||
is_multimodal=True,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
org = "Qwen",
|
||||
base_name = "Qwen",
|
||||
instruct_tags = ["Instruct"], # No base, only instruction tuned
|
||||
model_version = "2.5",
|
||||
model_sizes = ["3", "7", "32", "72"],
|
||||
model_info_cls = QwenVLModelInfo,
|
||||
is_multimodal = True,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
|
||||
)
|
||||
|
||||
# Qwen QwQ Model Meta
|
||||
QwenQwQMeta = ModelMeta(
|
||||
org="Qwen",
|
||||
base_name="QwQ",
|
||||
instruct_tags=[None],
|
||||
model_version="",
|
||||
model_sizes=["32"],
|
||||
model_info_cls=QwenQwQModelInfo,
|
||||
is_multimodal=False,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
org = "Qwen",
|
||||
base_name = "QwQ",
|
||||
instruct_tags = [None],
|
||||
model_version = "",
|
||||
model_sizes = ["32"],
|
||||
model_info_cls = QwenQwQModelInfo,
|
||||
is_multimodal = False,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
|
||||
)
|
||||
|
||||
# Qwen QVQ Preview Model Meta
|
||||
QwenQVQPreviewMeta = ModelMeta(
|
||||
org="Qwen",
|
||||
base_name="QVQ",
|
||||
instruct_tags=[None],
|
||||
model_version="",
|
||||
model_sizes=["72"],
|
||||
model_info_cls=QwenQVQPreviewModelInfo,
|
||||
is_multimodal=True,
|
||||
quant_types=[QuantType.NONE, QuantType.BNB],
|
||||
org = "Qwen",
|
||||
base_name = "QVQ",
|
||||
instruct_tags = [None],
|
||||
model_version = "",
|
||||
model_sizes = ["72"],
|
||||
model_info_cls = QwenQVQPreviewModelInfo,
|
||||
is_multimodal = True,
|
||||
quant_types = [QuantType.NONE, QuantType.BNB],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ def register_qwen_2_5_models(include_original_model: bool = False):
|
|||
global _IS_QWEN_2_5_REGISTERED
|
||||
if _IS_QWEN_2_5_REGISTERED:
|
||||
return
|
||||
_register_models(Qwen_2_5_Meta, include_original_model=include_original_model)
|
||||
_register_models(Qwen_2_5_Meta, include_original_model = include_original_model)
|
||||
_IS_QWEN_2_5_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ def register_qwen_2_5_vl_models(include_original_model: bool = False):
|
|||
global _IS_QWEN_2_5_VL_REGISTERED
|
||||
if _IS_QWEN_2_5_VL_REGISTERED:
|
||||
return
|
||||
_register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model)
|
||||
_register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)
|
||||
_IS_QWEN_2_5_VL_REGISTERED = True
|
||||
|
||||
|
||||
|
|
@ -110,15 +110,15 @@ def register_qwen_qwq_models(include_original_model: bool = False):
|
|||
global _IS_QWEN_QWQ_REGISTERED
|
||||
if _IS_QWEN_QWQ_REGISTERED:
|
||||
return
|
||||
_register_models(QwenQwQMeta, include_original_model=include_original_model)
|
||||
_register_models(QwenQVQPreviewMeta, include_original_model=include_original_model)
|
||||
_register_models(QwenQwQMeta, include_original_model = include_original_model)
|
||||
_register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)
|
||||
_IS_QWEN_QWQ_REGISTERED = True
|
||||
|
||||
|
||||
def register_qwen_models(include_original_model: bool = False):
|
||||
register_qwen_2_5_models(include_original_model=include_original_model)
|
||||
register_qwen_2_5_vl_models(include_original_model=include_original_model)
|
||||
register_qwen_qwq_models(include_original_model=include_original_model)
|
||||
register_qwen_2_5_models(include_original_model = include_original_model)
|
||||
register_qwen_2_5_vl_models(include_original_model = include_original_model)
|
||||
register_qwen_qwq_models(include_original_model = include_original_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -126,7 +126,7 @@ if __name__ == "__main__":
|
|||
|
||||
MODEL_REGISTRY.clear()
|
||||
|
||||
register_qwen_models(include_original_model=True)
|
||||
register_qwen_models(include_original_model = True)
|
||||
|
||||
for model_id, model_info in MODEL_REGISTRY.items():
|
||||
model_info = _check_model_info(model_id)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class ModelInfo:
|
|||
|
||||
@classmethod
|
||||
def construct_model_name(
|
||||
cls, base_name, version, size, quant_type, instruct_tag, key=""
|
||||
cls, base_name, version, size, quant_type, instruct_tag, key = ""
|
||||
):
|
||||
key = cls.append_instruct_tag(key, instruct_tag)
|
||||
key = cls.append_quant_type(key, quant_type)
|
||||
|
|
@ -81,10 +81,10 @@ class ModelMeta:
|
|||
base_name: str
|
||||
model_version: str
|
||||
model_info_cls: type[ModelInfo]
|
||||
model_sizes: list[str] = field(default_factory=list)
|
||||
instruct_tags: list[str] = field(default_factory=list)
|
||||
model_sizes: list[str] = field(default_factory = list)
|
||||
instruct_tags: list[str] = field(default_factory = list)
|
||||
quant_types: list[QuantType] | dict[str, list[QuantType]] = field(
|
||||
default_factory=list
|
||||
default_factory = list
|
||||
)
|
||||
is_multimodal: bool = False
|
||||
|
||||
|
|
@ -104,11 +104,11 @@ def register_model(
|
|||
name: str = None,
|
||||
):
|
||||
name = name or model_info_cls.construct_model_name(
|
||||
base_name=base_name,
|
||||
version=version,
|
||||
size=size,
|
||||
quant_type=quant_type,
|
||||
instruct_tag=instruct_tag,
|
||||
base_name = base_name,
|
||||
version = version,
|
||||
size = size,
|
||||
quant_type = quant_type,
|
||||
instruct_tag = instruct_tag,
|
||||
)
|
||||
key = f"{org}/{name}"
|
||||
|
||||
|
|
@ -118,14 +118,14 @@ def register_model(
|
|||
)
|
||||
|
||||
MODEL_REGISTRY[key] = model_info_cls(
|
||||
org=org,
|
||||
base_name=base_name,
|
||||
version=version,
|
||||
size=size,
|
||||
is_multimodal=is_multimodal,
|
||||
instruct_tag=instruct_tag,
|
||||
quant_type=quant_type,
|
||||
name=name,
|
||||
org = org,
|
||||
base_name = base_name,
|
||||
version = version,
|
||||
size = size,
|
||||
is_multimodal = is_multimodal,
|
||||
instruct_tag = instruct_tag,
|
||||
quant_type = quant_type,
|
||||
name = name,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -137,7 +137,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]):
|
|||
api = HfApi()
|
||||
|
||||
try:
|
||||
model_info: HfModelInfo = api.model_info(model_id, expand=properties)
|
||||
model_info: HfModelInfo = api.model_info(model_id, expand = properties)
|
||||
except Exception as e:
|
||||
if isinstance(e, RepositoryNotFoundError):
|
||||
warnings.warn(f"{model_id} not found on Hugging Face")
|
||||
|
|
@ -168,24 +168,24 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False
|
|||
# NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
|
||||
_org = "unsloth" # unsloth models -- these are all quantized versions of the original model
|
||||
register_model(
|
||||
model_info_cls=model_info_cls,
|
||||
org=_org,
|
||||
base_name=base_name,
|
||||
version=model_version,
|
||||
size=size,
|
||||
instruct_tag=instruct_tag,
|
||||
quant_type=quant_type,
|
||||
is_multimodal=is_multimodal,
|
||||
model_info_cls = model_info_cls,
|
||||
org = _org,
|
||||
base_name = base_name,
|
||||
version = model_version,
|
||||
size = size,
|
||||
instruct_tag = instruct_tag,
|
||||
quant_type = quant_type,
|
||||
is_multimodal = is_multimodal,
|
||||
)
|
||||
# include original model from releasing organization
|
||||
if include_original_model:
|
||||
register_model(
|
||||
model_info_cls=model_info_cls,
|
||||
org=org,
|
||||
base_name=base_name,
|
||||
version=model_version,
|
||||
size=size,
|
||||
instruct_tag=instruct_tag,
|
||||
quant_type=QuantType.NONE,
|
||||
is_multimodal=is_multimodal,
|
||||
model_info_cls = model_info_cls,
|
||||
org = org,
|
||||
base_name = base_name,
|
||||
version = model_version,
|
||||
size = size,
|
||||
instruct_tag = instruct_tag,
|
||||
quant_type = QuantType.NONE,
|
||||
is_multimodal = is_multimodal,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ def _create_unsloth_optimizer(
|
|||
model,
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
embedding_lr=5e-5,
|
||||
embedding_lr = 5e-5,
|
||||
):
|
||||
lr = optimizer_kwargs["lr"]
|
||||
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def get_model_info(
|
|||
if _HFAPI is None:
|
||||
_HFAPI = HfApi()
|
||||
try:
|
||||
model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties)
|
||||
model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)
|
||||
except Exception as e:
|
||||
print(f"Error getting model info for {model_id}: {e}")
|
||||
model_info = None
|
||||
|
|
@ -68,11 +68,11 @@ def list_models(
|
|||
properties = None
|
||||
|
||||
models: list[ModelInfo] = _HFAPI.list_models(
|
||||
author=author,
|
||||
search=search,
|
||||
sort=sort,
|
||||
limit=limit,
|
||||
expand=properties,
|
||||
full=full,
|
||||
author = author,
|
||||
search = search,
|
||||
sort = sort,
|
||||
limit = limit,
|
||||
expand = properties,
|
||||
full = full,
|
||||
)
|
||||
return models
|
||||
|
|
|
|||
Loading…
Reference in a new issue