unsloth/unsloth_cli/options.py
Daniel Han 0c8d407793
Rename cli/ to unsloth_cli/ to fix namespace collision with stringzilla (#4393)
* Rename cli/ to unsloth_cli/ to fix namespace collision with stringzilla

stringzilla installs a namespace package at cli/ (cli/split.py, cli/wc.py)
in site-packages without an __init__.py. When unsloth is installed as an
editable package (pip install -e .), the entry point script does
`from cli import app` which finds stringzilla's namespace cli/ first and
fails with `ImportError: cannot import name 'app' from 'cli'`.

Non-editable installs happened to work because unsloth's cli/__init__.py
overwrites the namespace directory, but this is fragile and breaks if
stringzilla is installed after unsloth.

Renaming to unsloth_cli/ avoids the collision entirely and fixes both
editable and non-editable install paths.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update stale cli/ references in comments and license files

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-17 20:40:21 -07:00

153 lines
5.4 KiB
Python

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Generate Typer CLI options from Pydantic models."""
import functools
import inspect
from pathlib import Path
from typing import Any, Callable, Optional, get_args, get_origin
import typer
from pydantic import BaseModel
def _python_name_to_cli_flag(name: str) -> str:
    """Turn a snake_case Python identifier into a long CLI flag (``foo_bar`` -> ``--foo-bar``)."""
    return f"--{name.replace('_', '-')}"
def _unwrap_optional(annotation: Any) -> Any:
    """Strip ``None`` from an ``Optional[X]``-style annotation.

    When *annotation* is a parameterized type whose arguments include
    ``NoneType``, the first non-``None`` argument is returned (so
    ``Union[A, B, None]`` yields ``A``). Any other annotation is returned
    unchanged.
    """
    if get_origin(annotation) is None:
        return annotation
    none_type = type(None)
    members = get_args(annotation)
    if none_type not in members:
        return annotation
    for member in members:
        if member is not none_type:
            return member
    return annotation
def _is_bool_field(annotation: Any) -> bool:
    """Return True if *annotation* is ``bool`` once ``None`` is stripped by ``_unwrap_optional``."""
    base_type = _unwrap_optional(annotation)
    return base_type is bool
def _is_list_type(annotation: Any) -> bool:
    """Return True when *annotation* is a parameterized list (``list[X]`` or ``List[X]``).

    Only the outermost origin is inspected, so ``Optional[list[X]]`` is not
    reported as a list, and neither is the bare ``list`` class (its origin is
    ``None``).
    """
    origin = get_origin(annotation)
    return origin is list
def _get_python_type(annotation: Any) -> type:
    """Choose the scalar type a Typer option should parse for *annotation*.

    ``None`` is stripped first via ``_unwrap_optional``. ``str``, ``int``,
    ``float``, ``bool`` and ``Path`` pass through unchanged; any other type
    falls back to ``str``.
    """
    candidate = _unwrap_optional(annotation)
    supported = (str, int, float, bool, Path)
    return candidate if candidate in supported else str
def _collect_config_fields(config_class: type[BaseModel]) -> list[tuple[str, Any]]:
    """
    Collect all fields from a config class, flattening nested models one level deep.

    A field whose annotation is itself a Pydantic model is replaced by that
    model's own fields. Every other field is kept as-is. Declaration order of
    ``model_fields`` is preserved.

    Returns:
        List of ``(name, field_info)`` tuples.

    Raises:
        ValueError: if two collected fields share the same name.
    """
    fields: list[tuple[str, Any]] = []
    seen_names: set[str] = set()

    def _add(name: str, field_info: Any) -> None:
        # Flattened names share one namespace, so collisions are rejected.
        if name in seen_names:
            raise ValueError(f"Duplicate field name '{name}' in config")
        seen_names.add(name)
        fields.append((name, field_info))

    for name, field_info in config_class.model_fields.items():
        annotation = field_info.annotation
        # Parameterized generics such as ``list[str]`` report
        # ``isinstance(..., type) is True`` on Python 3.9/3.10, and handing them
        # to ``issubclass`` raises TypeError there. ``get_origin`` is non-None
        # for such aliases, so rule them out first.
        is_nested_model = (
            get_origin(annotation) is None
            and isinstance(annotation, type)
            and issubclass(annotation, BaseModel)
        )
        if is_nested_model:
            # Expose the nested model's fields at the top level. This is one
            # level only; models nested deeper are not recursed into.
            for nested_name, nested_field in annotation.model_fields.items():
                _add(nested_name, nested_field)
        else:
            _add(name, field_info)
    return fields
def _build_option_parameter(field_name: str, field_info: Any) -> inspect.Parameter:
    """Build the ``inspect.Parameter`` (defaulting to a Typer option) for one config field."""
    flag_name = _python_name_to_cli_flag(field_name)
    help_text = field_info.description or ""
    if _is_bool_field(field_info.annotation):
        # Booleans get a paired --flag/--no-flag switch. None means "not given".
        default = typer.Option(
            None,
            f"{flag_name}/--no-{field_name.replace('_', '-')}",
            help = help_text,
        )
        annotation = Optional[bool]
    else:
        default = typer.Option(None, flag_name, help = help_text)
        annotation = Optional[_get_python_type(field_info.annotation)]
    return inspect.Parameter(
        field_name,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        default = default,
        annotation = annotation,
    )


def add_options_from_config(config_class: type[BaseModel]) -> Callable:
    """
    Decorator that adds CLI options for all fields in a Pydantic config model.

    The decorated function should declare a `config_overrides: dict = None` parameter
    which will receive a dict of all CLI-provided config values.

    The function may already declare some fields (e.g. to attach an envvar).
    Every other non-list field from ``_collect_config_fields`` becomes an
    optional ``--field-name`` option; booleans get ``--field-name/--no-field-name``.
    At call time, non-None values for all non-list fields are gathered into
    ``config_overrides``. Generated parameters are removed from the kwargs
    forwarded to the function; explicitly declared ones are still forwarded.

    Raises:
        ValueError: propagated from ``_collect_config_fields`` on duplicate field names.
    """
    fields = _collect_config_fields(config_class)
    # Names routed into config_overrides; list-typed fields get no CLI option.
    field_names = {
        name for name, field_info in fields if not _is_list_type(field_info.annotation)
    }

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)
        original_params = list(sig.parameters.values())
        original_param_names = {p.name for p in original_params}

        # Generated options cover fields the function does not already declare.
        option_params = [
            _build_option_parameter(field_name, field_info)
            for field_name, field_info in fields
            if field_name in field_names and field_name not in original_param_names
        ]

        # config_overrides is injected by the wrapper, so it is hidden from the signature.
        kept_params = [p for p in original_params if p.name != "config_overrides"]

        # Config options go first, but they all carry defaults. Any leading
        # positional-only or required positional parameters must therefore
        # stay in front of them. Otherwise inspect.Signature rejects the new
        # parameter list with ValueError.
        split = 0
        while split < len(kept_params) and (
            kept_params[split].kind is inspect.Parameter.POSITIONAL_ONLY
            or (
                kept_params[split].kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
                and kept_params[split].default is inspect.Parameter.empty
            )
        ):
            split += 1
        new_params = kept_params[:split] + option_params + kept_params[split:]
        new_sig = sig.replace(parameters = new_params)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            config_overrides = {}
            for key in list(kwargs):
                if key not in field_names:
                    continue
                if kwargs[key] is not None:
                    config_overrides[key] = kwargs[key]
                # Only generated parameters are dropped; declared ones reach func.
                if key not in original_param_names:
                    del kwargs[key]
            kwargs["config_overrides"] = config_overrides
            return func(*args, **kwargs)

        # Advertise the expanded signature so introspection (e.g. Typer) sees the options.
        wrapper.__signature__ = new_sig
        return wrapper

    return decorator