mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
EOL LF (unix line endings) normalization (#3478)
This commit is contained in:
parent
f62c454a86
commit
f845cf964f
2 changed files with 230 additions and 228 deletions
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
# Normalize Python files to LF line endings
|
||||||
|
*.py text eol=lf
|
||||||
456
unsloth-cli.py
456
unsloth-cli.py
|
|
@ -1,228 +1,228 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
"""
|
"""
|
||||||
🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
|
🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
|
||||||
|
|
||||||
This script is designed as a starting point for fine-tuning your models using unsloth.
|
This script is designed as a starting point for fine-tuning your models using unsloth.
|
||||||
It includes configurable options for model loading, PEFT parameters, training arguments,
|
It includes configurable options for model loading, PEFT parameters, training arguments,
|
||||||
and model saving/pushing functionalities.
|
and model saving/pushing functionalities.
|
||||||
|
|
||||||
You will likely want to customize this script to suit your specific use case
|
You will likely want to customize this script to suit your specific use case
|
||||||
and requirements.
|
and requirements.
|
||||||
|
|
||||||
Here are a few suggestions for customization:
|
Here are a few suggestions for customization:
|
||||||
- Modify the dataset loading and preprocessing steps to match your data.
|
- Modify the dataset loading and preprocessing steps to match your data.
|
||||||
- Customize the model saving and pushing configurations.
|
- Customize the model saving and pushing configurations.
|
||||||
|
|
||||||
Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
|
Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
|
||||||
python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
|
python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
|
||||||
--r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
|
--r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
|
||||||
--random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
|
--random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
|
||||||
--warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
|
--warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
|
||||||
--weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
|
--weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
|
||||||
--report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
|
--report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
|
||||||
--push_model --hub_path "hf/model" --hub_token "your_hf_token"
|
--push_model --hub_path "hf/model" --hub_token "your_hf_token"
|
||||||
|
|
||||||
To see a full list of configurable options, use:
|
To see a full list of configurable options, use:
|
||||||
python unsloth-cli.py --help
|
python unsloth-cli.py --help
|
||||||
|
|
||||||
Happy fine-tuning!
|
Happy fine-tuning!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
import torch
|
import torch
|
||||||
from unsloth import FastLanguageModel
|
from unsloth import FastLanguageModel
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from transformers.utils import strtobool
|
from transformers.utils import strtobool
|
||||||
from trl import SFTTrainer, SFTConfig
|
from trl import SFTTrainer, SFTConfig
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from unsloth import is_bfloat16_supported
|
from unsloth import is_bfloat16_supported
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
|
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
load_in_4bit=args.load_in_4bit,
|
load_in_4bit=args.load_in_4bit,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure PEFT model
|
# Configure PEFT model
|
||||||
model = FastLanguageModel.get_peft_model(
|
model = FastLanguageModel.get_peft_model(
|
||||||
model,
|
model,
|
||||||
r=args.r,
|
r=args.r,
|
||||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||||
"gate_proj", "up_proj", "down_proj"],
|
"gate_proj", "up_proj", "down_proj"],
|
||||||
lora_alpha=args.lora_alpha,
|
lora_alpha=args.lora_alpha,
|
||||||
lora_dropout=args.lora_dropout,
|
lora_dropout=args.lora_dropout,
|
||||||
bias=args.bias,
|
bias=args.bias,
|
||||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
random_state=args.random_state,
|
random_state=args.random_state,
|
||||||
use_rslora=args.use_rslora,
|
use_rslora=args.use_rslora,
|
||||||
loftq_config=args.loftq_config,
|
loftq_config=args.loftq_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||||
|
|
||||||
### Instruction:
|
### Instruction:
|
||||||
{}
|
{}
|
||||||
|
|
||||||
### Input:
|
### Input:
|
||||||
{}
|
{}
|
||||||
|
|
||||||
### Response:
|
### Response:
|
||||||
{}"""
|
{}"""
|
||||||
|
|
||||||
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
|
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
|
||||||
def formatting_prompts_func(examples):
|
def formatting_prompts_func(examples):
|
||||||
instructions = examples["instruction"]
|
instructions = examples["instruction"]
|
||||||
inputs = examples["input"]
|
inputs = examples["input"]
|
||||||
outputs = examples["output"]
|
outputs = examples["output"]
|
||||||
texts = []
|
texts = []
|
||||||
for instruction, input, output in zip(instructions, inputs, outputs):
|
for instruction, input, output in zip(instructions, inputs, outputs):
|
||||||
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
|
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
return {"text": texts}
|
return {"text": texts}
|
||||||
|
|
||||||
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
|
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
|
||||||
if use_modelscope:
|
if use_modelscope:
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset
|
||||||
dataset = MsDataset.load(args.dataset, split="train")
|
dataset = MsDataset.load(args.dataset, split="train")
|
||||||
else:
|
else:
|
||||||
# Load and format dataset
|
# Load and format dataset
|
||||||
dataset = load_dataset(args.dataset, split="train")
|
dataset = load_dataset(args.dataset, split="train")
|
||||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||||
print("Data is formatted and ready!")
|
print("Data is formatted and ready!")
|
||||||
|
|
||||||
# Configure training arguments
|
# Configure training arguments
|
||||||
training_args = SFTConfig(
|
training_args = SFTConfig(
|
||||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
warmup_steps=args.warmup_steps,
|
warmup_steps=args.warmup_steps,
|
||||||
max_steps=args.max_steps,
|
max_steps=args.max_steps,
|
||||||
learning_rate=args.learning_rate,
|
learning_rate=args.learning_rate,
|
||||||
fp16=not is_bfloat16_supported(),
|
fp16=not is_bfloat16_supported(),
|
||||||
bf16=is_bfloat16_supported(),
|
bf16=is_bfloat16_supported(),
|
||||||
logging_steps=args.logging_steps,
|
logging_steps=args.logging_steps,
|
||||||
optim=args.optim,
|
optim=args.optim,
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
lr_scheduler_type=args.lr_scheduler_type,
|
lr_scheduler_type=args.lr_scheduler_type,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
report_to=args.report_to,
|
report_to=args.report_to,
|
||||||
max_length=args.max_seq_length,
|
max_length=args.max_seq_length,
|
||||||
dataset_num_proc=2,
|
dataset_num_proc=2,
|
||||||
packing=False,
|
packing=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize trainer
|
# Initialize trainer
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
processing_class=tokenizer,
|
processing_class=tokenizer,
|
||||||
train_dataset=dataset,
|
train_dataset=dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
trainer_stats = trainer.train()
|
trainer_stats = trainer.train()
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
if args.save_model:
|
if args.save_model:
|
||||||
# if args.quantization_method is a list, we will save the model for each quantization method
|
# if args.quantization_method is a list, we will save the model for each quantization method
|
||||||
if args.save_gguf:
|
if args.save_gguf:
|
||||||
if isinstance(args.quantization, list):
|
if isinstance(args.quantization, list):
|
||||||
for quantization_method in args.quantization:
|
for quantization_method in args.quantization:
|
||||||
print(f"Saving model with quantization method: {quantization_method}")
|
print(f"Saving model with quantization method: {quantization_method}")
|
||||||
model.save_pretrained_gguf(
|
model.save_pretrained_gguf(
|
||||||
args.save_path,
|
args.save_path,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
quantization_method=quantization_method,
|
quantization_method=quantization_method,
|
||||||
)
|
)
|
||||||
if args.push_model:
|
if args.push_model:
|
||||||
model.push_to_hub_gguf(
|
model.push_to_hub_gguf(
|
||||||
hub_path=args.hub_path,
|
hub_path=args.hub_path,
|
||||||
hub_token=args.hub_token,
|
hub_token=args.hub_token,
|
||||||
quantization_method=quantization_method,
|
quantization_method=quantization_method,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Saving model with quantization method: {args.quantization}")
|
print(f"Saving model with quantization method: {args.quantization}")
|
||||||
model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
|
model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
|
||||||
if args.push_model:
|
if args.push_model:
|
||||||
model.push_to_hub_gguf(
|
model.push_to_hub_gguf(
|
||||||
hub_path=args.hub_path,
|
hub_path=args.hub_path,
|
||||||
hub_token=args.hub_token,
|
hub_token=args.hub_token,
|
||||||
quantization_method=quantization_method,
|
quantization_method=quantization_method,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
|
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
|
||||||
if args.push_model:
|
if args.push_model:
|
||||||
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
|
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
|
||||||
else:
|
else:
|
||||||
print("Warning: The model is not saved!")
|
print("Warning: The model is not saved!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Define argument parser
|
# Define argument parser
|
||||||
parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
|
parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
|
||||||
|
|
||||||
model_group = parser.add_argument_group("🤖 Model Options")
|
model_group = parser.add_argument_group("🤖 Model Options")
|
||||||
model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
|
model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
|
||||||
model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
|
model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
|
||||||
model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
|
model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
|
||||||
model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
|
model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
|
||||||
model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
|
model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
|
||||||
|
|
||||||
lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
|
lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
|
||||||
lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
|
lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
|
||||||
lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
|
lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
|
||||||
lora_group.add_argument('--lora_dropout', type=float, default=0.0, help="LoRA dropout rate, default is 0.0 which is optimized.")
|
lora_group.add_argument('--lora_dropout', type=float, default=0.0, help="LoRA dropout rate, default is 0.0 which is optimized.")
|
||||||
lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
|
lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
|
||||||
lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
|
lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
|
||||||
lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
|
lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
|
||||||
lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
|
lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
|
||||||
lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
|
lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
|
||||||
|
|
||||||
|
|
||||||
training_group = parser.add_argument_group("🎓 Training Options")
|
training_group = parser.add_argument_group("🎓 Training Options")
|
||||||
training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
|
training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
|
||||||
training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
|
training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
|
||||||
training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
|
training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
|
||||||
training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
|
training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
|
||||||
training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
|
training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
|
||||||
training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
|
training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
|
||||||
training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
|
training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
|
||||||
training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
|
training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
|
||||||
training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
|
training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
|
||||||
|
|
||||||
|
|
||||||
# Report/Logging arguments
|
# Report/Logging arguments
|
||||||
report_group = parser.add_argument_group("📊 Report Options")
|
report_group = parser.add_argument_group("📊 Report Options")
|
||||||
report_group.add_argument('--report_to', type=str, default="tensorboard",
|
report_group.add_argument('--report_to', type=str, default="tensorboard",
|
||||||
choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
|
choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
|
||||||
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
|
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
|
||||||
report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
|
report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
|
||||||
|
|
||||||
# Saving and pushing arguments
|
# Saving and pushing arguments
|
||||||
save_group = parser.add_argument_group('💾 Save Model Options')
|
save_group = parser.add_argument_group('💾 Save Model Options')
|
||||||
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
|
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
|
||||||
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
|
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
|
||||||
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
|
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
|
||||||
save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
|
save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
|
||||||
save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
|
save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
|
||||||
save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
|
save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
|
||||||
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
|
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
|
||||||
|
|
||||||
push_group = parser.add_argument_group('🚀 Push Model Options')
|
push_group = parser.add_argument_group('🚀 Push Model Options')
|
||||||
push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
|
push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
|
||||||
push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
|
push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
|
||||||
push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
|
push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
|
||||||
push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
|
push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run(args)
|
run(args)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue