mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
* Rename cli/ to unsloth_cli/ to fix namespace collision with stringzilla. stringzilla installs a namespace package at cli/ (cli/split.py, cli/wc.py) in site-packages without an __init__.py. When unsloth is installed as an editable package (pip install -e .), the entry point script does `from cli import app`, which finds stringzilla's namespace cli/ first and fails with `ImportError: cannot import name 'app' from 'cli'`. Non-editable installs happened to work because unsloth's cli/__init__.py overwrites the namespace directory, but this is fragile and breaks if stringzilla is installed after unsloth. Renaming to unsloth_cli/ avoids the collision entirely and fixes both editable and non-editable install paths. * [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci * Update stale cli/ references in comments and license files --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
|
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
|
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import typer
|
|
|
|
from unsloth_cli.config import Config, load_config
|
|
from unsloth_cli.options import add_options_from_config
|
|
|
|
|
|
@add_options_from_config(Config)
def train(
    config: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        help = "Path to YAML/JSON config file. CLI flags override config values.",
    ),
    hf_token: Optional[str] = typer.Option(
        None, "--hf-token", envvar = "HF_TOKEN", help = "Hugging Face token if needed."
    ),
    wandb_token: Optional[str] = typer.Option(
        None, "--wandb-token", envvar = "WANDB_API_KEY", help = "Weights & Biases API key."
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help = "Show resolved config and exit without training.",
    ),
    config_overrides: dict = None,
):
    """Launch training using the existing Unsloth training backend."""
    # NOTE: the docstring above is kept to one line on purpose; Typer uses the
    # command function's docstring as the command's help text.
    #
    # Flow:
    #   1. Load the config file (if any) and apply `config_overrides` on top.
    #   2. Resolve tokens: CLI/env value first, config value as fallback.
    #   3. With --dry-run, print the resolved config as YAML and exit 0.
    #   4. Validate model/dataset settings (exit code 2 on user error).
    #   5. Load the model, prepare it, load the dataset, start training
    #      (exit code 1 on any backend failure).
    #   6. Block until the training thread finishes; Ctrl+C requests a stop.
    #
    # `config_overrides` is presumably injected by the
    # `add_options_from_config` decorator -- TODO confirm against
    # unsloth_cli.options.
    try:
        cfg = load_config(config)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err = True)
        raise typer.Exit(code = 2)

    # `config_overrides` defaults to None; unpacking None with ** raises
    # TypeError, so fall back to an empty mapping when nothing was supplied.
    cfg.apply_overrides(**(config_overrides or {}))

    # CLI/env tokens take precedence over config
    # Handle case where typer.Option isn't resolved (decorator interaction)
    from typer.models import OptionInfo

    if isinstance(hf_token, OptionInfo):
        hf_token = None
    if isinstance(wandb_token, OptionInfo):
        wandb_token = None
    hf_token = hf_token or cfg.logging.hf_token
    wandb_token = wandb_token or cfg.logging.wandb_token

    if dry_run:
        import yaml

        data = cfg.model_dump()
        # Stringify output_dir before dumping so YAML receives a plain string.
        data["training"]["output_dir"] = str(data["training"]["output_dir"])
        typer.echo(yaml.dump(data, default_flow_style = False, sort_keys = False))
        raise typer.Exit(code = 0)

    if not cfg.model:
        typer.echo("Error: provide --model or set model in --config", err = True)
        raise typer.Exit(code = 2)

    if not cfg.data.dataset and not cfg.data.local_dataset:
        typer.echo(
            "Error: provide --dataset or --local-dataset (or via --config)", err = True
        )
        raise typer.Exit(code = 2)

    # Check if the model path is a LoRA adapter (has adapter_config.json).
    # cfg.model is known to be truthy here because of the guard above.
    model_path = Path(cfg.model)
    model_is_lora = (
        model_path.is_dir() and (model_path / "adapter_config.json").exists()
    )
    use_lora = cfg.training.training_type.lower() == "lora"

    if model_is_lora and not use_lora:
        typer.echo(
            "Error: Cannot do full finetuning on a LoRA adapter. "
            "Use --training-type lora or provide a base model.",
            err = True,
        )
        raise typer.Exit(code = 2)

    # Imported here rather than at module level, so --dry-run and the
    # validation errors above exit before this import runs.
    from studio.backend.core.training.trainer import UnslothTrainer

    trainer = UnslothTrainer()

    # Load model (trainer.is_vlm is set after this)
    if not trainer.load_model(
        model_name = cfg.model,
        max_seq_length = cfg.training.max_seq_length,
        # 4-bit loading is only requested for LoRA runs.
        load_in_4bit = cfg.training.load_in_4bit if use_lora else False,
        hf_token = hf_token,
    ):
        typer.echo("Model load failed", err = True)
        raise typer.Exit(code = 1)

    is_vision = trainer.is_vlm

    if not trainer.prepare_model_for_training(**cfg.model_kwargs(use_lora, is_vision)):
        typer.echo("Model preparation failed", err = True)
        raise typer.Exit(code = 1)

    result = trainer.load_and_format_dataset(
        dataset_source = cfg.data.dataset or "",
        format_type = cfg.data.format_type,
        local_datasets = cfg.data.local_dataset,
    )
    if result is None:
        typer.echo("Dataset load failed", err = True)
        raise typer.Exit(code = 1)

    ds, eval_ds = result

    training_kwargs = cfg.training_kwargs()
    training_kwargs["wandb_token"] = wandb_token  # CLI/env takes precedence
    started = trainer.start_training(
        dataset = ds, eval_dataset = eval_ds, **training_kwargs
    )

    if not started:
        typer.echo("Training failed to start", err = True)
        raise typer.Exit(code = 1)

    # Poll once per second while the training thread is alive, so Ctrl+C is
    # handled here and can be turned into a stop request.
    try:
        while trainer.training_thread and trainer.training_thread.is_alive():
            time.sleep(1)
    except KeyboardInterrupt:
        typer.echo("Stopping training (Ctrl+C detected)...")
        trainer.stop_training()
    finally:
        # Always wait for the thread to finish before reading final progress.
        if trainer.training_thread:
            trainer.training_thread.join()

    final = trainer.get_training_progress()
    if getattr(final, "error", None):
        typer.echo(f"Training error: {final.error}", err = True)
        raise typer.Exit(code = 1)
|