unsloth/studio/backend/utils/models/model_config.py
Daniel Han 05ec0f110b
Studio: Ollama support, recommended folders, Custom Folders UX polish (#5050)
* Studio: Ollama support, recommended folders, Custom Folders UX polish

Backend:
- Add _scan_ollama_dir that reads manifests/registry.ollama.ai/library/*
  and creates .gguf symlinks under <ollama_dir>/.studio_links/ pointing
  at the content-addressable blobs, so detect_gguf_model and llama-server
  -m work unchanged for Ollama models
- Filter entries under .studio_links from the generic models/hf/lmstudio
  scanners to avoid duplicate rows and leaked internal paths in the UI
- New GET /api/models/recommended-folders endpoint returning LM Studio
  and Ollama model directories that currently exist on the machine
  (OLLAMA_MODELS env var + standard paths, ~/.lmstudio/models, legacy
  LM Studio cache), used by the Custom Folders quick-add chips
- detect_gguf_model now uses os.path.abspath instead of Path.resolve so
  the readable symlink name is preserved as display_name (e.g.
  qwen2.5-0.5b-Q4_K_M.gguf instead of sha256-abc...)
- llama-server failure with a path under .studio_links or .cache/ollama
  surfaces a friendlier message ("Some Ollama models do not work with
  llama.cpp. Try a different model, or use this model directly through
  Ollama instead.") instead of the generic validation error
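
The display-name point above can be illustrated with a minimal sketch. It assumes a readable symlink pointing at a content-addressed blob; `shown_name` and `resolved_name` are hypothetical helpers for illustration, not functions from the Studio code base.

```python
import os
from pathlib import Path


def shown_name(path: str) -> str:
    """File name after making the path absolute WITHOUT following symlinks."""
    return Path(os.path.abspath(path)).name


def resolved_name(path: str) -> str:
    """File name after Path.resolve(), which follows symlinks to the target."""
    return Path(path).resolve().name
```

For a link `qwen2.5-0.5b-Q4_K_M.gguf` that points at `sha256-abc`, the first helper keeps the link name and the second returns the blob name, which is why the readable name would otherwise be lost.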

Frontend:
- ListLabel supports an optional leading icon and collapse toggle; used
  for Downloaded (download icon), Custom Folders (folder icon), and
  Recommended (star icon)
- Custom Folders header gets folder icon on the left, and +, search,
  and chevron buttons on the right; chevron uses ml-auto so it aligns
  with the Downloaded and Recommended chevrons
- New recommended folder chips render below the registered scan folders
  when there are unregistered well-known paths; one click adds them as
  a scan folder
- Custom folder rows that are direct .gguf files (Ollama symlinks) load
  immediately via onSelect instead of opening the GGUF variant expander
  (which is for repos containing multiple quants, not single files)
- When loading a direct .gguf file path, send max_seq_length = 0 so the
  backend uses the model's native context instead of the 4096 chat
  default (qwen2.5:0.5b now loads at 32768 instead of 4096)
- New listRecommendedFolders() helper on the chat API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: log silent exceptions and support read-only Ollama dirs

Replace silent except blocks in _scan_ollama_dir and the
recommended-folders endpoint with narrower exception types plus debug
or warning logs, so failures are diagnosable without hiding signal.

Add _ollama_links_dir helper that falls back to a per-ollama-dir hashed
namespace under Studio's own cache (~/.unsloth/studio/cache/ollama_links)
when the Ollama models directory is read-only. Common for system installs
at /usr/share/ollama/.ollama/models and /var/lib/ollama/.ollama/models
where the Studio process has read but not write access. Previously the
scanner returned an empty list in that case and Ollama models would
silently not appear.

The fallback preserves the .gguf suffix on symlink names so
detect_gguf_model keeps recognising them. The prior "raw sha256 blob
path" fallback would have missed the suffix check and failed to load.
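
A rough sketch of the fallback described above, under stated assumptions: the function name `links_dir_for`, the 16-character digest length, and the injectable `is_writable` check are illustrative choices, not the actual `_ollama_links_dir` implementation.

```python
import hashlib
import os
from pathlib import Path
from typing import Callable


def _default_is_writable(directory: Path) -> bool:
    # Creating entries in a directory needs write and search permission.
    return os.access(directory, os.W_OK | os.X_OK)


def links_dir_for(
    ollama_dir: Path,
    cache_root: Path,
    is_writable: Callable[[Path], bool] = _default_is_writable,
) -> Path:
    """Pick where link files for one Ollama models directory should live."""
    if is_writable(ollama_dir):
        return ollama_dir / ".studio_links"
    # Read-only install: use a per-directory namespace under Studio's cache,
    # keyed by a hash of the Ollama path so two installs never share links.
    digest = hashlib.sha256(str(ollama_dir).encode("utf-8")).hexdigest()[:16]
    return cache_root / "ollama_links" / digest
```

Hashing the directory path keeps the fallback location stable across scans while separating, say, a system install from a per-user one.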

* Address review: detect mmproj next to symlink target for vision GGUFs

Codex P1 on model_config.py:1012: when detect_gguf_model returns the
symlink path (to preserve readable display names), detect_mmproj_file
searched the symlink's parent directory instead of the target's. For
vision GGUFs surfaced via Ollama's .studio_links/ -- where the weight
file is symlinked but any mmproj sidecar lives next to the real blob
-- mmproj was no longer detected, so the model was misclassified as
text-only and llama-server would start without --mmproj.

detect_mmproj_file now adds the resolved target's parent to the scan
order when path is a symlink. Direct (non-symlink) .gguf paths are
unchanged, so LM Studio and HF cache layouts keep working exactly as
before. Verified with a fake layout reproducing the bug plus a
regression check on a non-symlink LM Studio model.
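
The scan-order change can be sketched as follows. `mmproj_search_dirs` is a hypothetical helper that only computes which directories would be searched; the real `detect_mmproj_file` may be structured differently.

```python
from pathlib import Path


def mmproj_search_dirs(model_path: Path) -> list[Path]:
    """Directories to search for an mmproj sidecar, in priority order."""
    dirs = [model_path.parent]
    if model_path.is_symlink():
        # The sidecar may live next to the real blob, not next to the link.
        target_parent = model_path.resolve().parent
        if target_parent not in dirs:
            dirs.append(target_parent)
    return dirs
```

A direct (non-symlink) path yields only its own parent, which matches the statement that LM Studio and HF cache layouts are unaffected.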

* Address review: support all Ollama namespaces and vision projector layers

- Iterate over all directories under registry.ollama.ai/ instead of
  hardcoding the "library" namespace. Custom namespaces like
  "mradermacher/llama3" now get scanned and include the namespace
  prefix in display names, model IDs, and symlink names to avoid
  collisions.

- Create companion -mmproj.gguf symlinks for Ollama vision models
  that have an "application/vnd.ollama.image.projector" layer, so
  detect_mmproj_file can find the projector alongside the model.

- Extract symlink creation into _make_symlink helper to reduce
  duplication between model and projector paths.
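
As an illustration of the projector handling, the sketch below picks the model and projector blobs out of a manifest. It assumes the manifest is JSON with a `layers` list whose entries carry `mediaType` and `digest`, that the weights layer uses `application/vnd.ollama.image.model`, and that blob file names are the digest with `:` replaced by `-`. Only the projector media type is taken from the commit text; `blob_names` is a hypothetical helper.

```python
from typing import Optional, Tuple

MODEL_LAYER = "application/vnd.ollama.image.model"  # assumed media type
PROJECTOR_LAYER = "application/vnd.ollama.image.projector"  # from the commit


def blob_names(manifest: dict) -> Tuple[Optional[str], Optional[str]]:
    """Return (model_blob, projector_blob) file names, or None when absent."""
    model = projector = None
    for layer in manifest.get("layers", []):
        digest = layer.get("digest", "")
        # Assumed layout: "sha256:<hex>" is stored on disk as "sha256-<hex>".
        name = digest.replace(":", "-", 1) if digest else None
        media_type = layer.get("mediaType")
        if media_type == MODEL_LAYER and model is None:
            model = name
        elif media_type == PROJECTOR_LAYER and projector is None:
            projector = name
    return model, projector
```

A text-only model simply yields `None` for the projector, so no companion `-mmproj.gguf` link would be created for it.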

* Address review: move imports to top level and add scan limit

- Move hashlib and json imports to the top of the file (PEP 8).
- Remove inline `import json as _json` and `import hashlib` from
  function bodies, use the top-level imports directly.
- Add `limit` parameter to `_scan_ollama_dir()` with early exit
  when the threshold is reached.
- Pass `_MAX_MODELS_PER_FOLDER` into the scanner so it stops
  traversing once enough models are found.

* Address review: Windows fallback, all registry hosts, collision safety

_make_link (formerly _make_symlink):
- Falls back to os.link() hardlink when symlink_to() fails (Windows
  without Developer Mode), then to shutil.copy2 as last resort
- Uses atomic os.replace via tmp file to avoid race window where the
  .gguf path is missing during rescan
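
A minimal sketch of the link strategy, assuming a temporary sibling file that is swapped in with `os.replace`. The name `make_link` and the temp-file naming are illustrative, and the sketch stops at the hardlink step rather than copying.

```python
import os
import uuid
from pathlib import Path


def make_link(target: Path, link: Path) -> bool:
    """Create or refresh `link` so it refers to `target`; True on success."""
    link.parent.mkdir(parents=True, exist_ok=True)
    # Build the new link under a unique temporary name first.
    tmp = link.with_name(f".{link.name}.{uuid.uuid4().hex}.tmp")
    try:
        try:
            tmp.symlink_to(target)
        except OSError:
            # e.g. Windows without Developer Mode: try a hardlink instead.
            os.link(target, tmp)
        # Atomic swap: readers see either the old link or the new one.
        os.replace(tmp, link)
        return True
    except OSError:
        try:
            tmp.unlink()
        except OSError:
            pass  # nothing was created, or cleanup is not possible
        return False
```

Because the final step is a rename, a concurrent rescan never observes a moment where the `.gguf` path is missing.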

Scanner now handles all Ollama registry layouts:
- Uses rglob over manifests/ instead of hardcoding registry.ollama.ai
- Discovers hf.co/org/repo:tag and any other host, not just library/
- Filenames include a stable sha1 hash of the manifest path to prevent
  collisions between models that normalize to the same stem
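
The collision problem can be seen in a small sketch. The normalization rule, the 10-character digest, and the name `link_stem` are assumptions for illustration; the sketch uses sha256 where this revision of the commit mentions sha1.

```python
import hashlib
import re


def link_stem(manifest_rel_path: str) -> str:
    """Readable, collision-resistant stem for a link file name."""
    # Different manifest paths can collapse to the same readable stem.
    readable = re.sub(r"[^A-Za-z0-9._-]+", "-", manifest_rel_path).strip("-")
    # A short hash of the exact path keeps such stems distinct.
    digest = hashlib.sha256(manifest_rel_path.encode("utf-8")).hexdigest()[:10]
    return f"{readable}-{digest}"
```

For example, `hf.co/org/repo/tag` and `hf.co/org-repo/tag` share the readable part `hf.co-org-repo-tag` but end in different digests, so their links do not overwrite each other.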

Per-model subdirectories under .studio_links/:
- Each model's links live in their own hash-keyed subdirectory
- detect_mmproj_file only sees the projector for that specific model,
  not siblings from other Ollama models

Friendly Ollama error detection:
- Now also matches ollama_links/ (the read-only fallback cache path)
  and model_identifier starting with "ollama/"

Recommended folders:
- Added os.access(R_OK | X_OK) check so unreadable system directories
  like /var/lib/ollama/.ollama/models are not advertised as chips
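
The readability check amounts to a filter like the one below. `advertised_folders` is a hypothetical name and the candidate list is supplied by the caller; the endpoint's real candidate discovery is not reproduced here.

```python
import os
from pathlib import Path
from typing import Iterable, List


def advertised_folders(candidates: Iterable[Path]) -> List[Path]:
    """Keep only directories that exist and can be listed by this process."""
    return [
        path
        for path in candidates
        # R_OK allows listing entries, X_OK allows traversing into them.
        if path.is_dir() and os.access(path, os.R_OK | os.X_OK)
    ]
```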

* Address review: filter ollama_links from generic scanners

The generic scanners (models_dir, hf_cache, lmstudio) already filter
out .studio_links to avoid duplicate Ollama entries, but missed the
ollama_links fallback cache directory used for read-only Ollama
installs. Add it to the filter.

* Address review: idempotent link creation and path-component filter

_make_link:
- Skip recreation when a valid link/copy already exists (samefile or
  matching size check). Prevents blocking the model-list API with
  multi-GB copies on repeated scans.
- Use uuid4 instead of os.getpid() for tmp file names to avoid race
  conditions from concurrent scans.
- Log cleanup errors instead of silently swallowing them.

Path filter:
- Use os.sep-bounded checks instead of bare substring match to avoid
  false positives on paths like "my.studio_links.backup/model.gguf".

* Address review: drop copy fallback, targeted glob, robust path filter

_make_link:
- Drop shutil.copy2 fallback -- copying multi-GB GGUFs inside a sync
  API request would block the backend. Log a warning and skip the
  model when both symlink and hardlink fail.

Scanner:
- Replace rglob("*") with targeted glob patterns (*/*/* and */*/*/*)
  to avoid traversing unrelated subdirectories in large custom folders.

Path filter:
- Use Path.parts membership check instead of os.sep substring matching
  for robustness across platforms.
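
A sketch of the component-based filter, assuming the two internal directory names mentioned in this commit; `is_internal_link_path` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Union

_INTERNAL_DIR_NAMES = frozenset({".studio_links", "ollama_links"})


def is_internal_link_path(path: Union[str, Path]) -> bool:
    """True when any whole path component is an internal link directory."""
    return any(part in _INTERNAL_DIR_NAMES for part in Path(path).parts)
```

Matching whole components means a path such as `my.studio_links.backup/model.gguf` is not filtered, which a bare substring test would get wrong.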

Scan limit:
- Skip _scan_ollama_dir when _generic already fills the per-folder cap.

* Address review: sha256, top-level uuid import, Path.absolute()

- Switch hashlib.sha1 to hashlib.sha256 for path hashing consistency.
- Move uuid import to the top of the file instead of inside _make_link.
- Replace os.path.abspath with Path.absolute() in detect_gguf_model
  to match the pathlib style used throughout the codebase.

* Address review: fix stale comments (sha1, rglob, copy fallback)

Update three docstrings/comments that still referenced the old
implementation after recent changes:
- sha1 comment now says "not a security boundary" (no hash name)
- "rglob" -> "targeted glob patterns"
- "file copies as a last resort" -> removed (copy fallback was dropped)

* Address review: fix stale links, support all manifest depths, scope error

_make_link:
- Drop size-based idempotency shortcut that kept stale links after
  ollama pull updates a tag to a same-sized blob. Only samefile()
  is used now -- if the link doesn't point at the exact same inode,
  it gets replaced.
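
The idempotency rule reduces to a single identity check, sketched here with a hypothetical `link_is_current` helper.

```python
import os
from pathlib import Path


def link_is_current(link: Path, target: Path) -> bool:
    """True only when `link` already refers to the same file as `target`."""
    try:
        # samefile follows symlinks and compares device and inode numbers.
        return os.path.samefile(link, target)
    except OSError:
        # Missing or dangling link: treat as stale so it gets recreated.
        return False
```

Two blobs of equal size are still different files, so a tag that moves to a same-sized blob is detected as stale.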

Scanner:
- Revert targeted glob back to rglob so deeper OCI-style repo names
  (5+ path segments) are not silently skipped.

Ollama error:
- Only show "Some Ollama models do not work with llama.cpp" when the
  server output contains GGUF compatibility hints (key not found,
  unknown architecture, failed to load). Unrelated failures like
  OOM or missing binaries now show the generic error instead of
  being misdiagnosed.
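
A sketch of the scoped error mapping. The hint strings and the message come from this commit; the function name, the case-insensitive match, and the path test are assumptions, and the `ollama/` model-identifier check is left out for brevity.

```python
from pathlib import Path
from typing import Optional

OLLAMA_MESSAGE = (
    "Some Ollama models do not work with llama.cpp. Try a different model, "
    "or use this model directly through Ollama instead."
)
_OLLAMA_DIR_NAMES = {".studio_links", "ollama_links"}
_GGUF_HINTS = ("key not found", "unknown architecture", "failed to load")


def friendly_ollama_error(model_path: str, server_output: str) -> Optional[str]:
    """Return the Ollama-specific message, or None to use the generic error."""
    from_ollama = any(
        part in _OLLAMA_DIR_NAMES for part in Path(model_path).parts
    )
    lowered = server_output.lower()
    if from_ollama and any(hint in lowered for hint in _GGUF_HINTS):
        return OLLAMA_MESSAGE
    return None
```

Returning `None` for an out-of-memory log keeps that failure on the generic path instead of blaming the model format.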

---------

Co-authored-by: Daniel Han <info@unsloth.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
2026-04-16 08:24:08 -07:00

2254 lines
82 KiB
Python

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Model and LoRA configuration handling
"""
from dataclasses import dataclass
from typing import Optional, Dict, Any
from utils.paths import (
normalize_path,
is_local_path,
is_model_cached,
get_cache_path,
resolve_cached_repo_id_case,
outputs_root,
exports_root,
resolve_output_dir,
resolve_export_dir,
)
from utils.utils import without_hf_auth
import structlog
from loggers import get_logger
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple
import hashlib
import json
import threading
import yaml
logger = get_logger(__name__)
# ── Model size extraction ────────────────────────────────────
import re as _re
_MODEL_SIZE_RE = _re.compile(
r"(?:^|[-_/])(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
# MoE active-parameter pattern: matches "A3B", "A3.5B", etc.
_ACTIVE_SIZE_RE = _re.compile(
r"(?:^|[-_/])a(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
def extract_model_size_b(model_id: str) -> float | None:
    """Extract model size in billions from a model identifier.
    Prefers MoE active-parameter notation (e.g. ``A3B`` in
    ``Qwen3.5-35B-A3B``) over the total parameter count.
    Handles both ``B`` (billions) and ``M`` (millions) suffixes.
    """
    mid = (model_id or "").lower()
    active = _ACTIVE_SIZE_RE.search(mid)
    if active:
        val = float(active.group(1))
        return val / 1000.0 if active.group(2).lower() == "m" else val
    size = _MODEL_SIZE_RE.search(mid)
    if not size:
        return None
    val = float(size.group(1))
    return val / 1000.0 if size.group(2).lower() == "m" else val
# Model name mapping: maps all equivalent model names to their canonical YAML config file
# Format: "canonical_model_name.yaml": [list of all equivalent model names]
# Based on the model mapper provided - canonical filename is based on the first model name in the mapper
MODEL_NAME_MAPPING = {
# ── Embedding models ──
"unsloth_all-MiniLM-L6-v2.yaml": [
"unsloth/all-MiniLM-L6-v2",
"sentence-transformers/all-MiniLM-L6-v2",
],
"unsloth_bge-m3.yaml": [
"unsloth/bge-m3",
"BAAI/bge-m3",
],
"unsloth_embeddinggemma-300m.yaml": [
"unsloth/embeddinggemma-300m",
"google/embeddinggemma-300m",
],
"unsloth_gte-modernbert-base.yaml": [
"unsloth/gte-modernbert-base",
"Alibaba-NLP/gte-modernbert-base",
],
"unsloth_Qwen3-Embedding-0.6B.yaml": [
"unsloth/Qwen3-Embedding-0.6B",
"Qwen/Qwen3-Embedding-0.6B",
"unsloth/Qwen3-Embedding-4B",
"Qwen/Qwen3-Embedding-4B",
],
# ── Other models ──
"unsloth_answerdotai_ModernBERT-large.yaml": [
"answerdotai/ModernBERT-large",
],
"unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml": [
"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
"unsloth/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
],
"unsloth_codegemma-7b-bnb-4bit.yaml": [
"unsloth/codegemma-7b-bnb-4bit",
"unsloth/codegemma-7b",
"google/codegemma-7b",
],
"unsloth_ERNIE-4.5-21B-A3B-PT.yaml": [
"unsloth/ERNIE-4.5-21B-A3B-PT",
],
"unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml": [
"unsloth/ERNIE-4.5-VL-28B-A3B-PT",
],
"tiiuae_Falcon-H1-0.5B-Instruct.yaml": [
"tiiuae/Falcon-H1-0.5B-Instruct",
"unsloth/Falcon-H1-0.5B-Instruct",
],
"unsloth_functiongemma-270m-it.yaml": [
"unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
"google/functiongemma-270m-it",
"unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
],
"unsloth_gemma-2-2b.yaml": [
"unsloth/gemma-2-2b-bnb-4bit",
"google/gemma-2-2b",
],
"unsloth_gemma-2-27b-bnb-4bit.yaml": [
"unsloth/gemma-2-9b-bnb-4bit",
"unsloth/gemma-2-9b",
"google/gemma-2-9b",
"unsloth/gemma-2-27b",
"google/gemma-2-27b",
],
"unsloth_gemma-3-4b-pt.yaml": [
"unsloth/gemma-3-4b-pt-unsloth-bnb-4bit",
"google/gemma-3-4b-pt",
"unsloth/gemma-3-4b-pt-bnb-4bit",
],
"unsloth_gemma-3-4b-it.yaml": [
"unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
"google/gemma-3-4b-it",
"unsloth/gemma-3-4b-it-bnb-4bit",
],
"unsloth_gemma-3-27b-it.yaml": [
"unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
"google/gemma-3-27b-it",
"unsloth/gemma-3-27b-it-bnb-4bit",
],
"unsloth_gemma-3-270m-it.yaml": [
"unsloth/gemma-3-270m-it-unsloth-bnb-4bit",
"google/gemma-3-270m-it",
"unsloth/gemma-3-270m-it-bnb-4bit",
],
"unsloth_gemma-3n-E4B-it.yaml": [
"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
"google/gemma-3n-E4B-it",
"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
],
"unsloth_gemma-3n-E4B.yaml": [
"unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
"google/gemma-3n-E4B",
],
"unsloth_gemma-4-31B-it.yaml": [
"unsloth/gemma-4-31B-it",
"google/gemma-4-31B-it",
],
"unsloth_gemma-4-26B-A4B-it.yaml": [
"unsloth/gemma-4-26B-A4B-it",
"google/gemma-4-26B-A4B-it",
],
"unsloth_gemma-4-E2B-it.yaml": [
"unsloth/gemma-4-E2B-it",
"google/gemma-4-E2B-it",
],
"unsloth_gemma-4-E4B-it.yaml": [
"unsloth/gemma-4-E4B-it",
"google/gemma-4-E4B-it",
],
"unsloth_gemma-4-31B.yaml": [
"unsloth/gemma-4-31B",
"google/gemma-4-31B",
],
"unsloth_gemma-4-26B-A4B.yaml": [
"unsloth/gemma-4-26B-A4B",
"google/gemma-4-26B-A4B",
],
"unsloth_gemma-4-E2B.yaml": [
"unsloth/gemma-4-E2B",
"google/gemma-4-E2B",
],
"unsloth_gemma-4-E4B.yaml": [
"unsloth/gemma-4-E4B",
"google/gemma-4-E4B",
],
"unsloth_gpt-oss-20b.yaml": [
"openai/gpt-oss-20b",
"unsloth/gpt-oss-20b-unsloth-bnb-4bit",
"unsloth/gpt-oss-20b-BF16",
],
"unsloth_gpt-oss-120b.yaml": [
"openai/gpt-oss-120b",
"unsloth/gpt-oss-120b-unsloth-bnb-4bit",
],
"unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml": [
"unsloth/granite-4.0-350m",
"ibm-granite/granite-4.0-350m",
"unsloth/granite-4.0-350m-bnb-4bit",
],
"unsloth_granite-4.0-h-micro.yaml": [
"ibm-granite/granite-4.0-h-micro",
"unsloth/granite-4.0-h-micro-bnb-4bit",
"unsloth/granite-4.0-h-micro-unsloth-bnb-4bit",
],
"unsloth_LFM2-1.2B.yaml": [
"unsloth/LFM2-1.2B",
],
"unsloth_llama-3-8b-bnb-4bit.yaml": [
"unsloth/llama-3-8b",
"meta-llama/Meta-Llama-3-8B",
],
"unsloth_llama-3-8b-Instruct-bnb-4bit.yaml": [
"unsloth/llama-3-8b-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
],
"unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml": [
"unsloth/Meta-Llama-3.1-8B-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit",
"meta-llama/Meta-Llama-3.1-8B",
"unsloth/Meta-Llama-3.1-70B-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B",
"unsloth/Meta-Llama-3.1-70B",
"meta-llama/Meta-Llama-3.1-70B",
"unsloth/Meta-Llama-3.1-405B-bnb-4bit",
"meta-llama/Meta-Llama-3.1-405B",
],
"unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml": [
"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"unsloth/Meta-Llama-3.1-8B-Instruct",
"RedHatAI/Llama-3.1-8B-Instruct-FP8",
"unsloth/Llama-3.1-8B-Instruct-FP8-Block",
"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-3B-Instruct.yaml": [
"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-3B-Instruct",
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.2-3B-Instruct-FP8",
"unsloth/Llama-3.2-3B-Instruct-FP8-Block",
"unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-1B-Instruct.yaml": [
"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-1B-Instruct",
"unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.2-1B-Instruct-FP8",
"unsloth/Llama-3.2-1B-Instruct-FP8-Block",
"unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-11B-Vision-Instruct.yaml": [
"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
],
"unsloth_Llama-3.3-70B-Instruct.yaml": [
"unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.3-70B-Instruct",
"unsloth/Llama-3.3-70B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.3-70B-Instruct-FP8",
"unsloth/Llama-3.3-70B-Instruct-FP8-Block",
"unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic",
],
"unsloth_Llasa-3B.yaml": [
"HKUSTAudio/Llasa-1B",
"unsloth/Llasa-3B",
],
"unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml": [
"unsloth/Magistral-Small-2509",
"mistralai/Magistral-Small-2509",
"unsloth/Magistral-Small-2509-bnb-4bit",
],
"unsloth_Ministral-3-3B-Instruct-2512.yaml": [
"unsloth/Ministral-3-3B-Instruct-2512",
],
"unsloth_mistral-7b-v0.3-bnb-4bit.yaml": [
"unsloth/mistral-7b-v0.3-bnb-4bit",
"unsloth/mistral-7b-v0.3",
"mistralai/Mistral-7B-v0.3",
],
"unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml": [
"unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
"unsloth/Mistral-Nemo-Base-2407",
"mistralai/Mistral-Nemo-Base-2407",
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
"unsloth/Mistral-Nemo-Instruct-2407",
"mistralai/Mistral-Nemo-Instruct-2407",
],
"unsloth_Mistral-Small-Instruct-2409.yaml": [
"unsloth/Mistral-Small-Instruct-2409-bnb-4bit",
"mistralai/Mistral-Small-Instruct-2409",
],
"unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml": [
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
"unsloth/mistral-7b-instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.3",
],
"unsloth_Qwen2.5-1.5B-Instruct.yaml": [
"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit",
"Qwen/Qwen2.5-1.5B-Instruct",
"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
],
"unsloth_Nemotron-3-Nano-30B-A3B.yaml": [
"unsloth/Nemotron-3-Nano-30B-A3B",
],
"unsloth_orpheus-3b-0.1-ft.yaml": [
"unsloth/orpheus-3b-0.1-ft",
"unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",
"canopylabs/orpheus-3b-0.1-ft",
"unsloth/orpheus-3b-0.1-ft-bnb-4bit",
],
"OuteAI_Llama-OuteTTS-1.0-1B.yaml": [
"OuteAI/Llama-OuteTTS-1.0-1B",
"unsloth/Llama-OuteTTS-1.0-1B",
"unsloth/llama-outetts-1.0-1b",
"OuteAI/OuteTTS-1.0-0.6B",
"unsloth/OuteTTS-1.0-0.6B",
"unsloth/outetts-1.0-0.6b",
],
"unsloth_PaddleOCR-VL.yaml": [
"unsloth/PaddleOCR-VL",
],
"unsloth_Phi-3-medium-4k-instruct.yaml": [
"unsloth/Phi-3-medium-4k-instruct-bnb-4bit",
"microsoft/Phi-3-medium-4k-instruct",
],
"unsloth_Phi-3.5-mini-instruct.yaml": [
"unsloth/Phi-3.5-mini-instruct-bnb-4bit",
"microsoft/Phi-3.5-mini-instruct",
],
"unsloth_Phi-4.yaml": [
"unsloth/phi-4-unsloth-bnb-4bit",
"microsoft/phi-4",
"unsloth/phi-4-bnb-4bit",
],
"unsloth_Pixtral-12B-2409.yaml": [
"unsloth/Pixtral-12B-2409-unsloth-bnb-4bit",
"mistralai/Pixtral-12B-2409",
"unsloth/Pixtral-12B-2409-bnb-4bit",
],
"unsloth_Qwen2-7B.yaml": [
"unsloth/Qwen2-7B-bnb-4bit",
"Qwen/Qwen2-7B",
],
"unsloth_Qwen2-VL-7B-Instruct.yaml": [
"unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit",
"Qwen/Qwen2-VL-7B-Instruct",
"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
],
"unsloth_Qwen2.5-7B.yaml": [
"unsloth/Qwen2.5-7B-unsloth-bnb-4bit",
"Qwen/Qwen2.5-7B",
"unsloth/Qwen2.5-7B-bnb-4bit",
],
"unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml": [
"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit",
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
],
"unsloth_Qwen2.5-Coder-14B-Instruct.yaml": [
"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit",
"Qwen/Qwen2.5-Coder-14B-Instruct",
],
"unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml": [
"unsloth/Qwen2.5-VL-7B-Instruct",
"Qwen/Qwen2.5-VL-7B-Instruct",
"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
],
"unsloth_Qwen3-0.6B.yaml": [
"unsloth/Qwen3-0.6B-unsloth-bnb-4bit",
"Qwen/Qwen3-0.6B",
"unsloth/Qwen3-0.6B-bnb-4bit",
"Qwen/Qwen3-0.6B-FP8",
"unsloth/Qwen3-0.6B-FP8",
],
"unsloth_Qwen3-4B-Instruct-2507.yaml": [
"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
"Qwen/Qwen3-4B-Instruct-2507",
"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
"Qwen/Qwen3-4B-Instruct-2507-FP8",
"unsloth/Qwen3-4B-Instruct-2507-FP8",
],
"unsloth_Qwen3-4B-Thinking-2507.yaml": [
"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
"Qwen/Qwen3-4B-Thinking-2507",
"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
"Qwen/Qwen3-4B-Thinking-2507-FP8",
"unsloth/Qwen3-4B-Thinking-2507-FP8",
],
"unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml": [
"unsloth/Qwen3-14B-Base",
"Qwen/Qwen3-14B-Base",
"unsloth/Qwen3-14B-Base-bnb-4bit",
],
"unsloth_Qwen3-14B.yaml": [
"unsloth/Qwen3-14B-unsloth-bnb-4bit",
"Qwen/Qwen3-14B",
"unsloth/Qwen3-14B-bnb-4bit",
"Qwen/Qwen3-14B-FP8",
"unsloth/Qwen3-14B-FP8",
],
"unsloth_Qwen3-32B.yaml": [
"unsloth/Qwen3-32B-unsloth-bnb-4bit",
"Qwen/Qwen3-32B",
"unsloth/Qwen3-32B-bnb-4bit",
"Qwen/Qwen3-32B-FP8",
"unsloth/Qwen3-32B-FP8",
],
"unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml": [
"Qwen/Qwen3-VL-8B-Instruct-FP8",
"unsloth/Qwen3-VL-8B-Instruct-FP8",
"unsloth/Qwen3-VL-8B-Instruct",
"Qwen/Qwen3-VL-8B-Instruct",
"unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",
],
"sesame_csm-1b.yaml": [
"sesame/csm-1b",
"unsloth/csm-1b",
],
"Spark-TTS-0.5B_LLM.yaml": [
"Spark-TTS-0.5B/LLM",
"unsloth/Spark-TTS-0.5B",
],
"unsloth_tinyllama-bnb-4bit.yaml": [
"unsloth/tinyllama",
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
],
"unsloth_whisper-large-v3.yaml": [
"unsloth/whisper-large-v3",
"openai/whisper-large-v3",
],
}
# Reverse mapping for quick lookup: model_name -> canonical_filename
_REVERSE_MODEL_MAPPING = {}
for canonical_file, model_names in MODEL_NAME_MAPPING.items():
    for model_name in model_names:
        _REVERSE_MODEL_MAPPING[model_name.lower()] = canonical_file
def load_model_config(
    model_name: str,
    use_auth: bool = False,
    token: Optional[str] = None,
    trust_remote_code: bool = True,
):
    """
    Load model config with optional authentication control.
    """
    from transformers import AutoConfig
    if token:
        # Explicit token provided - use it
        return AutoConfig.from_pretrained(
            model_name, trust_remote_code = trust_remote_code, token = token
        )
    if not use_auth:
        # Load without any authentication (for public model checks)
        with without_hf_auth():
            return AutoConfig.from_pretrained(
                model_name,
                trust_remote_code = trust_remote_code,
                token = None,
            )
    # Use default authentication (cached tokens)
    return AutoConfig.from_pretrained(
        model_name,
        trust_remote_code = trust_remote_code,
    )
# VLM architecture suffixes and known VLM model_type values.
_VLM_ARCH_SUFFIXES = ("ForConditionalGeneration", "ForVisionText2Text")
_VLM_MODEL_TYPES = {
"phi3_v",
"llava",
"llava_next",
"llava_onevision",
"internvl_chat",
"cogvlm2",
"minicpmv",
}
# Pre-computed .venv_t5 paths and backend dir for subprocess version switching.
# Vision check uses 5.5.0 (newest, recognizes all architectures).
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550")
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent)
# Inline script executed in a subprocess with transformers 5.x activated.
# Receives model_name and token via argv, prints JSON result to stdout.
_VISION_CHECK_SCRIPT = r"""
import sys, os, json
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Activate transformers 5.x
venv_t5 = sys.argv[1]
backend_dir = sys.argv[2]
model_name = sys.argv[3]
token = sys.argv[4] if len(sys.argv) > 4 and sys.argv[4] != "" else None
sys.path.insert(0, venv_t5)
if backend_dir not in sys.path:
    sys.path.insert(0, backend_dir)
try:
    from transformers import AutoConfig
    kwargs = {"trust_remote_code": True}
    if token:
        kwargs["token"] = token
    config = AutoConfig.from_pretrained(model_name, **kwargs)
    is_vlm = False
    if hasattr(config, "architectures"):
        is_vlm = any(
            x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
            for x in config.architectures
        )
    if not is_vlm and hasattr(config, "vision_config"):
        is_vlm = True
    if not is_vlm and hasattr(config, "img_processor"):
        is_vlm = True
    if not is_vlm and hasattr(config, "image_token_index"):
        is_vlm = True
    if not is_vlm and hasattr(config, "model_type"):
        vlm_types = {"phi3_v","llava","llava_next","llava_onevision",
                     "internvl_chat","cogvlm2","minicpmv"}
        if config.model_type in vlm_types:
            is_vlm = True
    model_type = getattr(config, "model_type", "unknown")
    archs = getattr(config, "architectures", [])
    print(json.dumps({"is_vision": is_vlm, "model_type": model_type,
                      "architectures": archs}))
except Exception as exc:
    print(json.dumps({"error": str(exc)}))
    sys.exit(1)
"""
def _is_vision_model_subprocess(
    model_name: str, hf_token: Optional[str] = None
) -> Optional[bool]:
    """Run is_vision_model check in a subprocess with transformers 5.x.
    Same pattern as training/inference workers: spawn a clean subprocess
    with .venv_t5/ prepended to sys.path so AutoConfig recognizes newer
    architectures (glm4_moe_lite, etc.).
    Returns True/False for definitive results, or None for transient failures
    (timeouts, subprocess errors) so callers can decide whether to cache
    the result. Subprocess failures are treated as transient because they
    can be caused by temporary HF/auth/network issues.
    """
    token_arg = hf_token or ""
    try:
        result = subprocess.run(
            [
                sys.executable,
                "-c",
                _VISION_CHECK_SCRIPT,
                _VENV_T5_DIR,
                _BACKEND_DIR,
                model_name,
                token_arg,
            ],
            capture_output = True,
            text = True,
            timeout = 60,
        )
        if result.returncode != 0:
            stderr = result.stderr.strip()
            logger.warning(
                "Vision check subprocess failed for '%s': %s",
                model_name,
                stderr or result.stdout.strip(),
            )
            return None
        data = json.loads(result.stdout.strip())
        if "error" in data:
            logger.warning(
                "Vision check subprocess error for '%s': %s",
                model_name,
                data["error"],
            )
            return None
        is_vlm = data["is_vision"]
        logger.info(
            "Vision check (subprocess, transformers 5.x) for '%s': "
            "model_type=%s, architectures=%s, is_vision=%s",
            model_name,
            data.get("model_type"),
            data.get("architectures"),
            is_vlm,
        )
        return is_vlm
    except subprocess.TimeoutExpired:
        logger.warning("Vision check subprocess timed out for '%s'", model_name)
        return None
    except Exception as exc:
        logger.warning("Vision check subprocess failed for '%s': %s", model_name, exc)
        return None
def _token_fingerprint(token: Optional[str]) -> Optional[str]:
    """Return a SHA256 digest of the token for use as a cache key.
    Avoids storing the raw bearer token in process memory as a dict key.
    """
    if token is None:
        return None
    return hashlib.sha256(token.encode("utf-8")).hexdigest()
# Cache vision detection results per session to avoid repeated subprocess spawns.
# Keyed by (normalized_model_name, token_fingerprint) to handle gated models correctly.
# Only definitive results (True/False from successful detection) are cached;
# transient failures (network errors, timeouts) are NOT cached so they can be retried.
_vision_detection_cache: Dict[Tuple[str, Optional[str]], bool] = {}
_vision_cache_lock = threading.Lock()
def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:
    """
    Detect vision-language models (VLMs) by checking architecture in config.
    Works for fine-tuned models since they inherit the base architecture.
    For models that require transformers 5.x (e.g. GLM-4.7-Flash), the check
    runs in a subprocess with .venv_t5/ activated -- same pattern as the
    training and inference workers.
    Results are cached per (model_name, token_fingerprint) for the lifetime of
    the process to avoid repeated subprocess spawns and HuggingFace API calls.
    Transient failures are not cached so they can be retried on the next call.
    Args:
        model_name: Model identifier (HF repo or local path)
        hf_token: Optional HF token for accessing gated/private models
    """
    # Normalize model name for cache key to avoid duplicate entries for
    # different casings of the same HF repo (e.g. "Org/Model" vs "org/model").
    try:
        if is_local_path(model_name):
            resolved_name = normalize_path(model_name)
        else:
            resolved_name = resolve_cached_repo_id_case(model_name)
    except Exception as exc:
        logger.debug(
            "Could not normalize model name '%s' for cache key: %s",
            model_name,
            exc,
        )
        resolved_name = model_name
    cache_key = (resolved_name, _token_fingerprint(hf_token))
    # Lock-free fast path for cache hits. Uses a sentinel to distinguish
    # "key not found" from "value is False" in a single atomic dict.get() call.
    _MISS = object()
    cached = _vision_detection_cache.get(cache_key, _MISS)
    if cached is not _MISS:
        return cached
    # Compute outside the lock to avoid serializing long-running detection
    # (subprocess spawns with 60s timeout, HF API calls) across all models.
    # The tradeoff: two concurrent calls for the same uncached model may
    # both run detection, but they produce the same result and the second
    # write is a benign no-op.
    result = _is_vision_model_uncached(resolved_name, hf_token)
    # Only cache definitive results; None means a transient failure occurred
    # and we should retry on the next call instead of locking in a wrong answer.
    if result is not None:
        with _vision_cache_lock:
            _vision_detection_cache[cache_key] = result
        return result
    return False
def _is_vision_model_uncached(
    model_name: str, hf_token: Optional[str] = None
) -> Optional[bool]:
    """Uncached vision model detection -- called by is_vision_model().
    Returns True/False for definitive results, or None when detection failed
    due to a transient error (network, timeout, subprocess failure) so the
    caller knows not to cache the result.
    Do not call directly; use is_vision_model() instead.
    """
    # Models that need transformers 5.x must be checked in a subprocess
    # because AutoConfig in the main process (transformers 4.57.x) doesn't
    # recognize their architectures.
    from utils.transformers_version import needs_transformers_5
    if needs_transformers_5(model_name):
        logger.info(
            "Model '%s' needs transformers 5.x -- checking vision via subprocess",
            model_name,
        )
        return _is_vision_model_subprocess(model_name, hf_token = hf_token)
    try:
        config = load_model_config(model_name, use_auth = True, token = hf_token)
        # Exclude audio-only models that share ForConditionalGeneration suffix
        # (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration)
        _audio_only_model_types = {"csm", "whisper"}
        model_type = getattr(config, "model_type", None)
        if model_type in _audio_only_model_types:
            return False
        # Check 1: Architecture class name patterns
        if hasattr(config, "architectures"):
            is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)
            if is_vlm:
                logger.info(
                    f"Model {model_name} detected as VLM: architecture {config.architectures}"
                )
                return True
        # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
        if hasattr(config, "vision_config"):
            logger.info(f"Model {model_name} detected as VLM: has vision_config")
            return True
        # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
        if hasattr(config, "img_processor"):
            logger.info(f"Model {model_name} detected as VLM: has img_processor")
            return True
        # Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
        if hasattr(config, "image_token_index"):
            logger.info(f"Model {model_name} detected as VLM: has image_token_index")
            return True
        # Check 5: Known VLM model_type values that may not match above checks
        if hasattr(config, "model_type"):
            if config.model_type in _VLM_MODEL_TYPES:
                logger.info(
                    f"Model {model_name} detected as VLM: model_type={config.model_type}"
                )
                return True
        return False
    except Exception as e:
        logger.warning(f"Could not determine if {model_name} is vision model: {e}")
        # Permanent failures (model not found, gated, bad config) should be
        # cached as False. Transient failures (network, timeout) should not.
        try:
            from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError
        except ImportError:
            try:
                from huggingface_hub.utils import (
                    RepositoryNotFoundError,
                    GatedRepoError,
                )
            except ImportError:
                RepositoryNotFoundError = GatedRepoError = None
        if RepositoryNotFoundError is not None and isinstance(
            e, (RepositoryNotFoundError, GatedRepoError)
        ):
            return False
        if isinstance(e, (ValueError, json.JSONDecodeError)):
            return False
        return None
VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")
# Cache detection results per session to avoid repeated API calls
_audio_detection_cache: Dict[str, Optional[str]] = {}
# Tokenizer token patterns → audio_type (all 6 types detected from tokenizer_config.json)
_AUDIO_TOKEN_PATTERNS = {
"csm": lambda tokens: "<|AUDIO|>" in tokens and "<|audio_eos|>" in tokens,
"whisper": lambda tokens: "<|startoftranscript|>" in tokens,
"audio_vlm": lambda tokens: "<audio_soft_token>" in tokens,
"bicodec": lambda tokens: any(t.startswith("<|bicodec_") for t in tokens),
"dac": lambda tokens: "<|audio_start|>" in tokens
and "<|audio_end|>" in tokens
and "<|text_start|>" in tokens
and "<|text_end|>" in tokens,
"snac": lambda tokens: sum(1 for t in tokens if t.startswith("<custom_token_"))
> 10000,
}
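# Illustrative sketch (not part of the detection path): how a pattern table
# like the one above can be applied. The names _demo_patterns and
# _demo_first_match are hypothetical. It assumes, as _check_token_patterns
# does further down, that the first matching entry in dict order wins.
def _demo_first_match(tokens, patterns):
    """Return the key of the first predicate that accepts ``tokens``."""
    for audio_type, check_fn in patterns.items():
        if check_fn(tokens):
            return audio_type
    return None
_demo_patterns = {
    "whisper": lambda tokens: "<|startoftranscript|>" in tokens,
    "bicodec": lambda tokens: any(t.startswith("<|bicodec_") for t in tokens),
}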
def detect_audio_type(model_name: str, hf_token: Optional[str] = None) -> Optional[str]:
"""
Dynamically detect if a model is an audio model and return its type.
Fully dynamic — works for any model, not just known ones.
Uses tokenizer_config.json special tokens to detect all 6 audio types.
Returns: audio_type string ('snac', 'csm', 'bicodec', 'dac', 'whisper', 'audio_vlm') or None.
"""
if model_name in _audio_detection_cache:
return _audio_detection_cache[model_name]
result = _detect_audio_from_tokenizer(model_name, hf_token)
_audio_detection_cache[model_name] = result
if result:
logger.info(f"Model {model_name} detected as audio model: audio_type={result}")
return result
def _detect_audio_from_tokenizer(
model_name: str, hf_token: Optional[str] = None
) -> Optional[str]:
"""Detect audio type from tokenizer special tokens (for LLM-based audio models).
First checks local HF cache, then fetches tokenizer_config.json from HuggingFace.
Checks added_tokens_decoder for distinctive patterns.
"""
def _check_token_patterns(tok_config: dict) -> Optional[str]:
added = tok_config.get("added_tokens_decoder", {})
if not added:
return None
token_contents = [v.get("content", "") for v in added.values()]
for audio_type, check_fn in _AUDIO_TOKEN_PATTERNS.items():
if check_fn(token_contents):
return audio_type
return None
# 1) Check local HF cache first (works for gated/offline models)
try:
repo_dir = get_cache_path(model_name)
if repo_dir is not None and repo_dir.exists():
snapshots_dir = repo_dir / "snapshots"
if snapshots_dir.exists():
for snapshot in snapshots_dir.iterdir():
for tok_path in [
"tokenizer_config.json",
"LLM/tokenizer_config.json",
]:
tok_file = snapshot / tok_path
if tok_file.exists():
tok_config = json.loads(tok_file.read_text())
result = _check_token_patterns(tok_config)
if result:
return result
except Exception as e:
logger.debug(f"Could not check local cache for {model_name}: {e}")
# 2) Fall back to HuggingFace API
try:
import requests
import os
paths_to_try = ["tokenizer_config.json", "LLM/tokenizer_config.json"]
# Use provided token, or fall back to env
token = hf_token or os.environ.get("HF_TOKEN")
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
for tok_path in paths_to_try:
url = f"https://huggingface.co/{model_name}/resolve/main/{tok_path}"
resp = requests.get(url, headers = headers, timeout = 15)
if not resp.ok:
continue
tok_config = resp.json()
result = _check_token_patterns(tok_config)
if result:
return result
return None
except Exception as e:
logger.debug(
f"Could not detect audio type from tokenizer for {model_name}: {e}"
)
return None
def is_audio_input_type(audio_type: Optional[str]) -> bool:
"""Check if an audio_type accepts audio input (ASR/speech understanding).
Whisper (ASR) and audio_vlm (Gemma3n) accept audio input.
"""
return audio_type in ("whisper", "audio_vlm")
def _is_mmproj(filename: str) -> bool:
"""Check if a GGUF filename is a vision projection (mmproj) file."""
return "mmproj" in filename.lower()
def _is_gguf_filename(filename: str) -> bool:
return filename.lower().endswith(".gguf")
def _iter_gguf_files(directory: Path, recursive: bool = False):
if not directory.is_dir():
return
iterator = directory.rglob("*") if recursive else directory.iterdir()
for f in iterator:
if f.is_file() and _is_gguf_filename(f.name):
yield f
def detect_mmproj_file(path: str, search_root: Optional[str] = None) -> Optional[str]:
"""
Find the mmproj (vision projection) GGUF file for a given model.
Args:
path: Directory to search — or a .gguf file (uses its parent dir
as the starting point).
search_root: Optional outer directory that should also be scanned
(and any directory between it and ``path``). This handles
local layouts where the model weights live in a quant-named
subdir (``snapshot/BF16/foo.gguf``) but the mmproj sits at
the snapshot root (``snapshot/mmproj-BF16.gguf``). When
``None``, only the immediate parent dir is scanned, matching
the historical behavior.
Returns:
Full path to the mmproj .gguf file, or None if not found.
"""
p = Path(path)
start_dir = p.parent if p.is_file() else p
if not start_dir.is_dir():
return None
# Build the list of dirs to scan: immediate dir first, then walk up
# to (and including) ``search_root`` if it is an ancestor. We walk
# incrementally rather than recursing into ``search_root`` so we
# don't accidentally pick up an mmproj from a sibling subdir
# belonging to a different model variant.
seen: set[Path] = set()
scan_order: list[Path] = []
def _add(d: Path) -> None:
try:
resolved = d.resolve()
except OSError:
return
if resolved in seen or not resolved.is_dir():
return
seen.add(resolved)
scan_order.append(resolved)
_add(start_dir)
# When ``path`` is a symlink (e.g. Ollama's ``.studio_links/...gguf``
# -> ``blobs/sha256-...``), the symlink's parent directory rarely
# contains the mmproj sibling; the real mmproj file lives next to
# the symlink target. Add the target's parent to the scan so vision
# GGUFs that are surfaced via symlinks are still recognised as
# vision models.
try:
if p.is_symlink() and p.is_file():
target_parent = p.resolve().parent
if target_parent.is_dir():
_add(target_parent)
except OSError:
pass
if search_root is not None:
try:
root_resolved = Path(search_root).resolve()
start_resolved = start_dir.resolve()
# Only walk if start_dir is inside (or equal to) search_root.
if root_resolved == start_resolved or (
start_resolved.is_relative_to(root_resolved)
if hasattr(start_resolved, "is_relative_to")
else str(start_resolved).startswith(str(root_resolved) + os.sep)
):
cur = start_resolved
# Walk up from start_dir to (and including) root_resolved.
while cur != root_resolved and cur.parent != cur:
cur = cur.parent
_add(cur)
if cur == root_resolved:
break
except OSError:
pass
for d in scan_order:
for f in _iter_gguf_files(d):
if _is_mmproj(f.name):
return str(f.resolve())
return None
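# Illustrative sketch of the walk-up order built by detect_mmproj_file above,
# using pure paths so it touches no filesystem. _demo_scan_order is a
# hypothetical helper, not used by the real function; it assumes POSIX-style
# paths and skips the symlink-target step.
from pathlib import PurePosixPath
def _demo_scan_order(start_dir: str, search_root: str) -> list:
    """List start_dir, then each ancestor up to and including search_root."""
    start = PurePosixPath(start_dir)
    root = PurePosixPath(search_root)
    order = [str(start)]
    # Only walk when start is inside (or equal to) root.
    if start != root and root not in start.parents:
        return order
    cur = start
    while cur != root and cur.parent != cur:
        cur = cur.parent
        order.append(str(cur))
    return order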
def detect_gguf_model(path: str) -> Optional[str]:
"""
Check if the given local path is or contains a GGUF model file.
Handles two cases:
1. path is a direct .gguf file path
2. path is a directory containing .gguf files
Skips mmproj (vision projection) files — those must be passed via
``--mmproj``, not ``-m``. Use :func:`detect_mmproj_file` instead.
Returns the full path to the .gguf file if found, None otherwise.
For HuggingFace repo detection, use detect_gguf_model_remote() instead.
"""
p = Path(path)
# Case 1: direct .gguf file
if p.suffix.lower() == ".gguf" and p.is_file():
if _is_mmproj(p.name):
return None
# Use absolute (not resolve) to preserve symlink names -- e.g.
# Ollama .studio_links/model.gguf -> blobs/sha256-... should
# keep the readable symlink name, not the opaque blob hash.
return str(p.absolute())
# Case 2: directory containing .gguf files (skip mmproj)
if p.is_dir():
gguf_files = sorted(
(f for f in _iter_gguf_files(p) if not _is_mmproj(f.name)),
key = lambda f: f.stat().st_size,
reverse = True,
)
if gguf_files:
return str(gguf_files[0].resolve())
return None
# Preferred GGUF quantization levels, in descending priority.
# Q4_K_M is a good default: small, fast, acceptable quality.
# UD (Unsloth Dynamic) variants are always preferred over standard quants
# because they provide better quality per bit. If the repo has no UD variants
# (e.g., bartowski repos), the standard quants are used as fallback.
# Ordered by best size/quality tradeoff, not raw quality.
_GGUF_QUANT_PREFERENCE = [
# UD variants (best quality per bit) -- Q4 is the sweet spot
"UD-Q4_K_XL",
"UD-Q4_K_L",
"UD-Q5_K_XL",
"UD-Q3_K_XL",
"UD-Q6_K_XL",
"UD-Q6_K_S",
"UD-Q8_K_XL",
"UD-Q2_K_XL",
"UD-IQ4_NL",
"UD-IQ4_XS",
"UD-IQ3_S",
"UD-IQ3_XXS",
"UD-IQ2_M",
"UD-IQ2_XXS",
"UD-IQ1_M",
"UD-IQ1_S",
# Standard quants (fallback for non-Unsloth repos)
"Q4_K_M",
"Q4_K_S",
"Q5_K_M",
"Q5_K_S",
"Q6_K",
"Q8_0",
"Q3_K_M",
"Q3_K_L",
"Q3_K_S",
"Q2_K",
"Q2_K_L",
"IQ4_NL",
"IQ4_XS",
"IQ3_M",
"IQ3_XXS",
"IQ2_M",
"IQ1_M",
"F16",
"BF16",
"F32",
]
def _pick_best_gguf(filenames: list[str]) -> Optional[str]:
"""
Pick the best GGUF file from a list of filenames.
Prefers quantization levels in _GGUF_QUANT_PREFERENCE order.
Falls back to the first .gguf file found.
"""
gguf_files = [f for f in filenames if f.lower().endswith(".gguf")]
if not gguf_files:
return None
# Try preferred quantization levels
for quant in _GGUF_QUANT_PREFERENCE:
for f in gguf_files:
if quant in f:
return f
# Fallback: first GGUF file
return gguf_files[0]
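# Illustrative sketch of preference-ordered selection with a shortened,
# hypothetical preference list (_DEMO_PREFERENCE). Matching is by substring,
# as in _pick_best_gguf, so an earlier preference wins regardless of the
# order in which filenames are listed.
_DEMO_PREFERENCE = ["UD-Q4_K_XL", "Q4_K_M", "Q8_0"]
def _demo_pick(filenames):
    """Return the first file matching the highest-priority quant, if any."""
    ggufs = [f for f in filenames if f.lower().endswith(".gguf")]
    for quant in _DEMO_PREFERENCE:
        for f in ggufs:
            if quant in f:
                return f
    # Fallback: first GGUF file, or None when there are no GGUF files.
    return ggufs[0] if ggufs else None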
@dataclass
class GgufVariantInfo:
"""A single GGUF quantization variant from a HuggingFace repo."""
filename: str # e.g., "gemma-3-4b-it-Q4_K_M.gguf"
quant: str # e.g., "Q4_K_M" (extracted from filename)
size_bytes: int # file size
def _extract_quant_label(filename: str) -> str:
"""
Extract quantization label like Q4_K_M, IQ4_XS, BF16 from a GGUF filename.
Examples:
"gemma-3-4b-it-Q4_K_M.gguf" → "Q4_K_M"
"model-IQ4_NL.gguf" → "IQ4_NL"
"model-BF16.gguf" → "BF16"
"model-UD-IQ1_S.gguf" → "UD-IQ1_S"
"model-UD-TQ1_0.gguf" → "UD-TQ1_0"
"MXFP4_MOE/model-MXFP4_MOE-0001.gguf" → "MXFP4_MOE"
"""
import re
# Use only the basename (rfilename may include directory)
basename = filename.rsplit("/", 1)[-1]
# Strip .gguf and any shard suffix (-00001-of-00010)
stem = re.sub(r"-\d{3,}-of-\d{3,}", "", basename.rsplit(".", 1)[0])
# Match known quantization patterns
match = re.search(
r"(UD-)?" # Optional UD- prefix (Unsloth Dynamic)
r"(MXFP[0-9]+(?:_[A-Z0-9]+)*" # MXFP variants: MXFP4, MXFP4_MOE
r"|IQ[0-9]+_[A-Z]+(?:_[A-Z0-9]+)?" # IQ variants: IQ4_XS, IQ4_NL, IQ1_S
r"|TQ[0-9]+_[0-9]+" # Ternary quant: TQ1_0, TQ2_0
r"|Q[0-9]+_K_[A-Z]+" # K-quant: Q4_K_M, Q3_K_S
r"|Q[0-9]+_[0-9]+" # Standard: Q8_0, Q5_1
r"|Q[0-9]+_K" # Short K-quant: Q6_K
r"|BF16|F16|F32)", # Full precision
stem,
re.IGNORECASE,
)
if match:
prefix = match.group(1) or ""
return f"{prefix}{match.group(2)}"
# Fallback: last segment after hyphen
return stem.split("-")[-1]
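# Illustrative sketch of the stem-normalisation step used above, in
# isolation. _demo_strip_shard is a hypothetical name; the regex is the same
# one _extract_quant_label applies before matching quant patterns.
import re
def _demo_strip_shard(filename: str) -> str:
    """Drop any directory, the extension, and a -NNNNN-of-NNNNN shard suffix."""
    basename = filename.rsplit("/", 1)[-1]
    return re.sub(r"-\d{3,}-of-\d{3,}", "", basename.rsplit(".", 1)[0])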
def list_gguf_variants(
repo_id: str,
hf_token: Optional[str] = None,
) -> tuple[list[GgufVariantInfo], bool]:
"""
List all GGUF quantization variants in a HuggingFace repo.
Separates main model files from mmproj (vision projection) files.
The presence of mmproj files indicates a vision-capable model.
Returns:
(variants, has_vision): list of non-mmproj GGUF variants + vision flag.
"""
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(repo_id, token = hf_token, files_metadata = True)
variants: list[GgufVariantInfo] = []
has_vision = False
quant_totals: dict[str, int] = {} # quant -> total bytes
quant_first_file: dict[str, str] = {} # quant -> first filename (for display)
for sibling in info.siblings:
fname = sibling.rfilename
if not fname.lower().endswith(".gguf"):
continue
size = sibling.size or 0
# mmproj files are vision projection models, not main model files
if "mmproj" in fname.lower():
has_vision = True
continue
quant = _extract_quant_label(fname)
quant_totals[quant] = quant_totals.get(quant, 0) + size
if quant not in quant_first_file:
quant_first_file[quant] = fname
for quant, total_size in quant_totals.items():
variants.append(
GgufVariantInfo(
filename = quant_first_file[quant],
quant = quant,
size_bytes = total_size,
)
)
# Sort by size descending (largest = best quality first).
# Recommended pinning and OOM demotion are handled client-side
# where GPU VRAM info is available.
variants.sort(key = lambda v: -v.size_bytes)
return variants, has_vision
def _resolve_gguf_dir(p: Path) -> Optional[Path]:
"""Resolve a path to the directory containing GGUF variants.
If *p* is already a directory, returns it directly. If *p* is a ``.gguf``
file whose parent directory has model metadata (``config.json`` or
``adapter_config.json``), returns the parent -- all GGUFs in that
directory belong to the same model. Returns ``None`` for loose standalone
GGUFs (no config) to avoid cross-wiring unrelated models.
"""
if p.is_dir():
return p
if p.is_file() and p.suffix.lower() == ".gguf":
parent = p.parent
if (parent / "config.json").exists() or (
parent / "adapter_config.json"
).exists():
return parent
return None
def list_local_gguf_variants(
directory: str,
) -> tuple[list[GgufVariantInfo], bool]:
"""List GGUF quantization variants in a local directory.
Mirrors :func:`list_gguf_variants` but reads from the filesystem
instead of the HuggingFace API. Aggregates shard sizes by quant
label so that split GGUFs appear as a single variant.
Returns:
(variants, has_vision): list of non-mmproj GGUF variants + vision flag.
"""
p = _resolve_gguf_dir(Path(directory))
if p is None:
return [], False
quant_totals: dict[str, int] = {}
quant_first_file: dict[str, str] = {}
has_vision = False
# Recurse so variant-specific subdirectories (e.g. ``BF16/...gguf``
# used by some HF GGUF repos for the largest quants) are picked up.
# Filenames in the result preserve the relative subpath so that
# ``_find_local_gguf_by_variant`` can locate the file again.
for f in sorted(_iter_gguf_files(p, recursive = True)):
if _is_mmproj(f.name):
has_vision = True
continue
try:
size = f.stat().st_size
except OSError:
size = 0
quant = _extract_quant_label(f.name)
quant_totals[quant] = quant_totals.get(quant, 0) + size
# Only compute the (potentially expensive) relative path when this
# is the first file we've seen for this quant -- after that we'd
# discard the result anyway. Use posix-style separators so the
# filename matches what ``list_gguf_variants`` (the remote HF
# API path) returns on every platform; otherwise Windows would
# emit ``BF16\foo.gguf`` here.
if quant not in quant_first_file:
quant_first_file[quant] = f.relative_to(p).as_posix()
variants = [
GgufVariantInfo(
filename = quant_first_file[q],
quant = q,
size_bytes = s,
)
for q, s in quant_totals.items()
]
variants.sort(key = lambda v: -v.size_bytes)
return variants, has_vision
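# Illustrative sketch of the shard aggregation above, on in-memory data.
# _demo_aggregate and its (quant, relative_name, size) input tuples are
# hypothetical; the real function derives them from the filesystem.
def _demo_aggregate(files):
    """Sum sizes per quant label, keep the first name seen, sort largest first."""
    totals = {}
    first_name = {}
    for quant, name, size in files:
        totals[quant] = totals.get(quant, 0) + size
        first_name.setdefault(quant, name)
    rows = [(first_name[q], q, s) for q, s in totals.items()]
    rows.sort(key = lambda row: -row[2])
    return rows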
def _find_local_gguf_by_variant(directory: str, variant: str) -> Optional[str]:
"""Find the GGUF file in *directory* matching a quantization *variant*.
For sharded GGUFs (multiple files with the same quant label), returns
the first shard (sorted by name) which is what ``llama-server -m`` expects.
Returns the resolved absolute path, or ``None`` if no match.
"""
p = _resolve_gguf_dir(Path(directory))
if p is None:
return None
# Recurse into subdirectories so variants stored under a quant-named
# subdir (e.g. ``BF16/foo-BF16-00001-of-00002.gguf``) are found.
matches = sorted(
f
for f in _iter_gguf_files(p, recursive = True)
if not _is_mmproj(f.name) and _extract_quant_label(f.name) == variant
)
if matches:
return str(matches[0].resolve())
return None
def detect_gguf_model_remote(
repo_id: str,
hf_token: Optional[str] = None,
) -> Optional[str]:
"""
Check if a HuggingFace repo contains GGUF files.
Returns the filename of the best GGUF file in the repo, or None.
"""
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(repo_id, token = hf_token)
repo_files = [s.rfilename for s in info.siblings]
return _pick_best_gguf(repo_files)
except Exception as e:
logger.debug(f"Could not check GGUF files for '{repo_id}': {e}")
return None
def download_gguf_file(
repo_id: str,
filename: str,
hf_token: Optional[str] = None,
) -> str:
"""
Download a specific GGUF file from a HuggingFace repo.
Returns the local path to the downloaded file.
"""
from huggingface_hub import hf_hub_download
local_path = hf_hub_download(
repo_id = repo_id,
filename = filename,
token = hf_token,
)
return local_path
# Cache embedding detection results per session to avoid repeated HF API calls
_embedding_detection_cache: Dict[tuple, bool] = {}
def is_embedding_model(model_name: str, hf_token: Optional[str] = None) -> bool:
"""
Detect embedding/sentence-transformer models using HuggingFace model metadata.
Uses a belt-and-suspenders approach combining three signals:
1. "sentence-transformers" in model tags
2. "feature-extraction" in model tags
3. pipeline_tag is "sentence-similarity" or "feature-extraction"
This catches all known embedding models including those like gte-modernbert
whose library_name is "transformers" rather than "sentence-transformers".
Args:
model_name: Model identifier (HF repo or local path)
hf_token: Optional HF token for accessing gated/private models
Returns:
True if the model is an embedding model, False otherwise.
For local paths, True only if the directory contains modules.json.
Defaults to False on errors.
"""
cache_key = (model_name, hf_token)
if cache_key in _embedding_detection_cache:
return _embedding_detection_cache[cache_key]
# Local paths: check for sentence-transformer marker file (modules.json)
if is_local_path(model_name):
local_dir = normalize_path(model_name)
is_emb = os.path.isfile(os.path.join(local_dir, "modules.json"))
_embedding_detection_cache[cache_key] = is_emb
return is_emb
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(model_name, token = hf_token)
tags = set(info.tags or [])
pipeline_tag = info.pipeline_tag or ""
is_emb = (
"sentence-transformers" in tags
or "feature-extraction" in tags
or pipeline_tag in ("sentence-similarity", "feature-extraction")
)
_embedding_detection_cache[cache_key] = is_emb
if is_emb:
logger.info(
f"Model {model_name} detected as embedding model: "
f"pipeline_tag={pipeline_tag}, "
f"sentence-transformers in tags={('sentence-transformers' in tags)}, "
f"feature-extraction in tags={('feature-extraction' in tags)}"
)
return is_emb
except Exception as e:
logger.warning(f"Could not determine if {model_name} is embedding model: {e}")
_embedding_detection_cache[cache_key] = False
return False
def _has_model_weight_files(model_dir: Path) -> bool:
"""Return True when a directory contains loadable model weights."""
for item in model_dir.iterdir():
if not item.is_file():
continue
suffix = item.suffix.lower()
if suffix == ".safetensors":
return True
if suffix == ".gguf" and "mmproj" not in item.name.lower():
return True
if suffix == ".bin":
name = item.name.lower()
if (
name.startswith("pytorch_model")
or name.startswith("model")
or name.startswith("adapter_model")
or name.startswith("consolidated")
):
return True
return False
def _detect_training_output_type(model_dir: Path) -> Optional[str]:
"""Classify a Studio training output as LoRA or full finetune."""
adapter_config = model_dir / "adapter_config.json"
adapter_model = model_dir / "adapter_model.safetensors"
if adapter_config.exists() or adapter_model.exists():
return "lora"
config_file = model_dir / "config.json"
if config_file.exists() and _has_model_weight_files(model_dir):
return "merged"
return None
def _looks_like_lora_adapter(model_dir: Path) -> bool:
return model_dir.is_dir() and (
(model_dir / "adapter_config.json").exists()
or any(model_dir.glob("adapter_model*.safetensors"))
or any(model_dir.glob("adapter_model*.bin"))
)
def scan_trained_models(
outputs_dir: str = str(outputs_root()),
) -> List[Tuple[str, str, str]]:
"""
Scan outputs folder for trained Studio models.
Returns:
List of tuples: [(display_name, model_path, model_type), ...]
model_type is "lora" for adapter runs and "merged" for full finetunes.
"""
trained_models = []
outputs_path = resolve_output_dir(outputs_dir)
if not outputs_path.exists():
logger.warning(f"Outputs directory not found: {outputs_dir}")
return trained_models
try:
for item in outputs_path.iterdir():
if item.is_dir():
model_type = _detect_training_output_type(item)
if model_type is None:
continue
display_name = item.name
model_path = str(item)
trained_models.append((display_name, model_path, model_type))
logger.debug("Found trained model: %s (%s)", display_name, model_type)
# Sort by modification time (newest first)
trained_models.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)
logger.info(
"Found %s trained models in %s",
len(trained_models),
outputs_dir,
)
return trained_models
except Exception as e:
logger.error(f"Error scanning outputs folder: {e}")
return []
def scan_exported_models(
exports_dir: str = str(exports_root()),
) -> List[Tuple[str, str, str, Optional[str]]]:
"""
Scan exports folder for exported models (merged, LoRA, GGUF).
Supports two directory layouts:
- Two-level: {run}/{checkpoint}/ (merged & LoRA exports)
- Flat: {name}-finetune-gguf/ (GGUF exports)
Returns:
List of tuples: [(display_name, model_path, export_type, base_model), ...]
export_type: "lora" | "merged" | "gguf"
"""
results = []
exports_path = resolve_export_dir(exports_dir)
if not exports_path.exists():
return results
try:
for run_dir in exports_path.iterdir():
if not run_dir.is_dir():
continue
# Check for flat GGUF export (e.g. exports/gemma-3-4b-it-finetune-gguf/)
# Filter out mmproj (vision projection) files — they aren't loadable as main models
gguf_files = [
f for f in _iter_gguf_files(run_dir) if not _is_mmproj(f.name)
]
if gguf_files:
base_model = None
export_meta = run_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
except Exception:
pass
display_name = run_dir.name
model_path = str(gguf_files[0]) # path to the .gguf file
results.append((display_name, model_path, "gguf", base_model))
logger.debug(f"Found GGUF export: {display_name}")
continue
# Two-level: {run}/{checkpoint}/
for checkpoint_dir in run_dir.iterdir():
if not checkpoint_dir.is_dir():
continue
adapter_config = checkpoint_dir / "adapter_config.json"
config_file = checkpoint_dir / "config.json"
has_weights = any(checkpoint_dir.glob("*.safetensors")) or any(
checkpoint_dir.glob("*.bin")
)
has_gguf = any(_iter_gguf_files(checkpoint_dir))
base_model = None
export_type = None
if adapter_config.exists():
export_type = "lora"
try:
cfg = json.loads(adapter_config.read_text())
base_model = cfg.get("base_model_name_or_path")
except Exception:
pass
elif config_file.exists() and has_weights:
export_type = "merged"
export_meta = checkpoint_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
except Exception:
pass
elif has_gguf:
export_type = "gguf"
gguf_list = list(_iter_gguf_files(checkpoint_dir))
# Check checkpoint_dir first, then fall back to parent run_dir
# (export.py writes metadata to the top-level export directory)
for meta_dir in (checkpoint_dir, run_dir):
export_meta = meta_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
if base_model:
break
except Exception:
pass
display_name = f"{run_dir.name} / {checkpoint_dir.name}"
model_path = str(gguf_list[0]) if gguf_list else str(checkpoint_dir)
results.append((display_name, model_path, export_type, base_model))
logger.debug(f"Found GGUF export: {display_name}")
continue
else:
continue
# Fallback: read base model from the original training run's
# adapter_config.json in ./outputs/{run_name}/
if not base_model:
outputs_adapter_cfg = (
resolve_output_dir(run_dir.name) / "adapter_config.json"
)
try:
if outputs_adapter_cfg.exists():
cfg = json.loads(outputs_adapter_cfg.read_text())
base_model = cfg.get("base_model_name_or_path")
except Exception:
pass
display_name = f"{run_dir.name} / {checkpoint_dir.name}"
model_path = str(checkpoint_dir)
results.append((display_name, model_path, export_type, base_model))
logger.debug(f"Found exported model: {display_name} ({export_type})")
results.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)
logger.info(f"Found {len(results)} exported models in {exports_dir}")
return results
except Exception as e:
logger.error(f"Error scanning exports folder: {e}")
return []
def get_base_model_from_checkpoint(checkpoint_path: str) -> Optional[str]:
"""Read the base model name from a local training or checkpoint directory."""
try:
checkpoint_path_obj = Path(checkpoint_path)
adapter_config_path = checkpoint_path_obj / "adapter_config.json"
if adapter_config_path.exists():
with open(adapter_config_path, "r") as f:
config = json.load(f)
base_model = config.get("base_model_name_or_path")
if base_model:
logger.info(
"Detected base model from adapter_config.json: %s", base_model
)
return base_model
config_path = checkpoint_path_obj / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
for key in ("model_name", "_name_or_path"):
base_model = config.get(key)
if base_model and str(base_model) != str(checkpoint_path_obj):
logger.info(
"Detected base model from config.json (%s): %s",
key,
base_model,
)
return base_model
training_args_path = checkpoint_path_obj / "training_args.bin"
if training_args_path.exists():
try:
import torch
training_args = torch.load(training_args_path)
if hasattr(training_args, "model_name_or_path"):
base_model = training_args.model_name_or_path
logger.info(
"Detected base model from training_args.bin: %s", base_model
)
return base_model
except Exception as e:
logger.warning(f"Could not load training_args.bin: {e}")
dir_name = checkpoint_path_obj.name
if dir_name.startswith("unsloth_"):
parts = dir_name.split("_")
if len(parts) >= 2:
model_parts = parts[1:-1]
base_model = "unsloth/" + "_".join(model_parts)
logger.info("Detected base model from directory name: %s", base_model)
return base_model
logger.warning(f"Could not detect base model for checkpoint: {checkpoint_path}")
return None
except Exception as e:
logger.error(f"Error reading base model from checkpoint config: {e}")
return None
def get_base_model_from_lora(lora_path: str) -> Optional[str]:
"""
Read the base model name from a LoRA adapter's config.
Args:
lora_path: Path to the LoRA adapter directory
Returns:
Base model identifier or None if not found
"""
try:
lora_path_obj = Path(lora_path)
if not _looks_like_lora_adapter(lora_path_obj):
return None
# Try adapter_config.json first
adapter_config_path = lora_path_obj / "adapter_config.json"
if adapter_config_path.exists():
with open(adapter_config_path, "r") as f:
config = json.load(f)
base_model = config.get("base_model_name_or_path")
if base_model:
logger.info(
f"Detected base model from adapter_config.json: {base_model}"
)
return base_model
# Fallback: try training_args.bin (requires torch)
training_args_path = lora_path_obj / "training_args.bin"
if training_args_path.exists():
try:
import torch
training_args = torch.load(training_args_path)
if hasattr(training_args, "model_name_or_path"):
base_model = training_args.model_name_or_path
logger.info(
f"Detected base model from training_args.bin: {base_model}"
)
return base_model
except Exception as e:
logger.warning(f"Could not load training_args.bin: {e}")
# Last resort: parse from directory name
# Format: unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit_timestamp
dir_name = lora_path_obj.name
if dir_name.startswith("unsloth_"):
# Remove timestamp suffix (usually _1234567890)
parts = dir_name.split("_")
# Reconstruct model name
if len(parts) >= 2:
model_parts = parts[1:-1] # Skip "unsloth" and timestamp
base_model = "unsloth/" + "_".join(model_parts)
logger.info(f"Detected base model from directory name: {base_model}")
return base_model
logger.warning(f"Could not detect base model for LoRA: {lora_path}")
return None
except Exception as e:
logger.error(f"Error reading base model from LoRA config: {e}")
return None
# Status indicators that appear in UI dropdowns
UI_STATUS_INDICATORS = [" (Ready)", " (Loading...)", " (Active)", ""]
def load_model_defaults(model_name: str) -> Dict[str, Any]:
"""
Load default training parameters for a model from YAML file.
Args:
model_name: Model identifier (e.g., "unsloth/Meta-Llama-3.1-8B-bnb-4bit")
Returns:
Dictionary with default parameters from YAML file, or empty dict if not found
The function looks for a YAML file in configs/model_defaults/ (including subfolders)
based on the model name or its aliases from MODEL_NAME_MAPPING.
If no specific file exists, it falls back to default.yaml.
"""
try:
# Get the script directory to locate configs
script_dir = Path(__file__).parent.parent.parent
defaults_dir = script_dir / "assets" / "configs" / "model_defaults"
# First, check if model is in the mapping
if model_name.lower() in _REVERSE_MODEL_MAPPING:
canonical_file = _REVERSE_MODEL_MAPPING[model_name.lower()]
# Search in subfolders and root
for config_path in defaults_dir.rglob(canonical_file):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(
f"Loaded model defaults from {config_path} (via mapping)"
)
return config
# If model_name is a local path (e.g. /home/.../Spark-TTS-0.5B/LLM from
# adapter_config.json, or C:\Users\...\model on Windows), try matching
# the last 1-2 path components against the registry
# (e.g. "Spark-TTS-0.5B/LLM").
_is_local_path = is_local_path(model_name)
# Normalize Windows backslash paths so Path().parts splits correctly
# on POSIX/WSL hosts (pathlib treats backslashes as literals on Linux).
_normalized = normalize_path(model_name) if _is_local_path else model_name
if model_name.lower() not in _REVERSE_MODEL_MAPPING and _is_local_path:
parts = Path(_normalized).parts
for depth in [2, 1]:
if len(parts) >= depth:
suffix = "/".join(parts[-depth:])
if suffix.lower() in _REVERSE_MODEL_MAPPING:
canonical_file = _REVERSE_MODEL_MAPPING[suffix.lower()]
for config_path in defaults_dir.rglob(canonical_file):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(
f"Loaded model defaults from {config_path} (via path suffix '{suffix}')"
)
return config
# Try exact model name match (for backward compatibility).
# For local filesystem paths, use only the directory basename to
# avoid passing absolute paths (e.g. C:\...) into rglob which
# raises "Non-relative patterns are unsupported" on Windows.
_lookup_name = Path(_normalized).name if _is_local_path else model_name
model_filename = _lookup_name.replace("/", "_") + ".yaml"
# Search in subfolders and root
for config_path in defaults_dir.rglob(model_filename):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(f"Loaded model defaults from {config_path}")
return config
# Fall back to default.yaml
default_config_path = defaults_dir / "default.yaml"
if default_config_path.exists():
with open(default_config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(f"Loaded default model defaults from {default_config_path}")
return config
logger.warning(f"No default config found for model {model_name}")
return {}
except Exception as e:
logger.error(f"Error loading model defaults for {model_name}: {e}")
return {}
@dataclass
class ModelConfig:
"""Configuration for a model to load"""
identifier: str # Clean model identifier (org/name or path)
display_name: str # Original UI display name
path: str # Normalized filesystem path
is_local: bool # Is this a local file vs HF model?
is_cached: bool # Is this already in HF cache?
is_vision: bool # Is this a vision model?
is_lora: bool # Is this a lora adapter?
is_gguf: bool = False # Is this a GGUF model?
is_audio: bool = False # Is this a TTS audio model?
audio_type: Optional[str] = (
None # Audio codec type: 'snac', 'csm', 'bicodec', 'dac'
)
has_audio_input: bool = False # Accepts audio input (ASR/speech understanding)
gguf_file: Optional[str] = None # Full path to the .gguf file (local mode)
gguf_mmproj_file: Optional[str] = (
None # Full path to the mmproj .gguf file (vision projection)
)
gguf_hf_repo: Optional[str] = (
None # HF repo ID for -hf mode (e.g. "unsloth/gemma-3-4b-it-GGUF")
)
gguf_variant: Optional[str] = None # Quantization variant (e.g. "Q4_K_M")
base_model: Optional[str] = None # Base model (for LoRAs)
@classmethod
def from_lora_path(
cls, lora_path: str, hf_token: Optional[str] = None
) -> Optional["ModelConfig"]:
"""
Create ModelConfig from a local LoRA adapter path.
Automatically detects the base model from adapter config.
Args:
lora_path: Path to LoRA adapter (e.g., "./outputs/unsloth_Meta-Llama-3.1_.../")
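hf_token: Optional HF token for vision and audio detection on the base model
Returns:
ModelConfig for the LoRA adapter, or None if the path does not exist,
the base model cannot be determined, or an error occurs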
"""
try:
lora_path_obj = Path(lora_path)
if not lora_path_obj.exists():
logger.error(f"LoRA path does not exist: {lora_path}")
return None
# Get base model
base_model = get_base_model_from_lora(lora_path)
if not base_model:
logger.error(f"Could not determine base model for LoRA: {lora_path}")
return None
# Check if base model is vision
is_vision = is_vision_model(base_model, hf_token = hf_token)
# Check if base model is audio
audio_type = detect_audio_type(base_model, hf_token = hf_token)
display_name = lora_path_obj.name
identifier = lora_path # Use path as identifier for local LoRAs
return cls(
identifier = identifier,
display_name = display_name,
path = lora_path,
is_local = True,
is_cached = True, # Local LoRAs are always "cached"
is_vision = is_vision,
is_lora = True,
is_audio = audio_type is not None and audio_type != "audio_vlm",
audio_type = audio_type,
has_audio_input = is_audio_input_type(audio_type),
base_model = base_model,
)
except Exception as e:
logger.error(f"Error creating ModelConfig from LoRA path: {e}")
return None
@classmethod
def from_identifier(
cls,
model_id: str,
hf_token: Optional[str] = None,
is_lora: bool = False,
gguf_variant: Optional[str] = None,
) -> Optional["ModelConfig"]:
"""
Create ModelConfig from a clean model identifier.
For FastAPI routes where the frontend sends sanitized model paths.
No Gradio dropdown parsing - expects clean identifiers like:
- "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
- "./outputs/my_lora_adapter"
- "/absolute/path/to/model"
Args:
model_id: Clean model identifier (HF repo name or local path)
hf_token: Optional HF token for vision detection on gated models
is_lora: Whether this is a LoRA adapter
gguf_variant: Optional GGUF quantization variant (e.g. "Q4_K_M").
For remote GGUF repos, specifies which quant to load via -hf.
If None, auto-selects using _pick_best_gguf().
Returns:
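ModelConfig or None if configuration cannot be created
Raises:
RuntimeError: If a remote GGUF repo is detected but the llama-server
binary cannot be found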
"""
if not model_id or not model_id.strip():
return None
identifier = model_id.strip()
is_local = is_local_path(identifier)
path = normalize_path(identifier) if is_local else identifier
# Add unsloth/ prefix for shorthand HF models
if not is_local and "/" not in identifier:
identifier = f"unsloth/{identifier}"
path = identifier
# Preserve requested casing, but if a case-variant already exists in local HF cache,
# reuse that exact repo_id spelling to avoid one-time re-downloads after #2592.
if not is_local:
resolved_identifier = resolve_cached_repo_id_case(identifier)
if resolved_identifier != identifier:
logger.info(
"Using cached repo_id casing '%s' for requested '%s'",
resolved_identifier,
identifier,
)
identifier = resolved_identifier
path = resolved_identifier
# Auto-detect GGUF models (check before LoRA/vision detection)
if is_local:
if gguf_variant:
gguf_file = _find_local_gguf_by_variant(path, gguf_variant)
else:
gguf_file = detect_gguf_model(path)
if gguf_file:
display_name = Path(gguf_file).stem
logger.info(f"Detected local GGUF model: {gguf_file}")
# Detect vision: check if base model is vision, then look for mmproj
mmproj_file = None
gguf_is_vision = False
gguf_dir = Path(gguf_file).parent
# Determine if this is a vision model from export metadata
base_is_vision = False
meta_path = gguf_dir / "export_metadata.json"
if meta_path.exists():
try:
meta = json.loads(meta_path.read_text(encoding = "utf-8"))
base = meta.get("base_model")
if base and is_vision_model(base, hf_token = hf_token):
base_is_vision = True
logger.info(f"GGUF base model '{base}' is a vision model")
except Exception as e:
logger.debug(f"Could not read export metadata: {e}")
# Always look for an mmproj file; finding one marks the model as
# vision-capable. The recursive variant scan in
# ``_find_local_gguf_by_variant`` may have returned a
# weight file inside a quant-named subdir (e.g.
# ``.../BF16/foo.gguf``) while ``mmproj-*.gguf`` lives
# at the snapshot root. Pass ``search_root=path`` so
# ``detect_mmproj_file`` walks up to the snapshot root
# instead of seeing only the weight file's immediate
# parent.
mmproj_file = detect_mmproj_file(gguf_file, search_root = path)
if mmproj_file:
gguf_is_vision = True
logger.info(f"Detected mmproj for vision: {mmproj_file}")
elif base_is_vision:
logger.warning(
f"Base model is vision but no mmproj file found in {gguf_dir}"
)
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = True,
is_cached = True,
is_vision = gguf_is_vision,
is_lora = False,
is_gguf = True,
gguf_file = gguf_file,
gguf_mmproj_file = mmproj_file,
)
else:
# Check if the HF repo contains GGUF files
gguf_filename = detect_gguf_model_remote(identifier, hf_token = hf_token)
if gguf_filename:
# Preflight: verify llama-server binary exists BEFORE user waits
# for a multi-GB download that llama-server handles natively
from core.inference.llama_cpp import LlamaCppBackend
if not LlamaCppBackend._find_llama_server_binary():
raise RuntimeError(
"llama-server binary not found — cannot load GGUF models. "
"Run setup.sh to build it, or set LLAMA_SERVER_PATH."
)
# Use list_gguf_variants() to detect vision & resolve variant
variants, has_vision = list_gguf_variants(identifier, hf_token = hf_token)
variant = gguf_variant
if not variant:
# Auto-select best quantization
variant_filenames = [v.filename for v in variants]
best = _pick_best_gguf(variant_filenames)
if best:
variant = _extract_quant_label(best)
else:
variant = "Q4_K_M" # Fallback — llama-server's own default
display_name = f"{identifier.split('/')[-1]} ({variant})"
logger.info(
f"Detected remote GGUF repo '{identifier}', "
f"variant={variant}, vision={has_vision}"
)
return cls(
identifier = identifier,
display_name = display_name,
path = identifier,
is_local = False,
is_cached = False,
is_vision = has_vision,
is_lora = False,
is_gguf = True,
gguf_file = None,
gguf_hf_repo = identifier,
gguf_variant = variant,
)
# Auto-detect LoRA for local paths (check adapter_config.json on disk)
if not is_lora and is_local:
detected_base = (
get_base_model_from_lora(path)
if _looks_like_lora_adapter(Path(path))
else None
)
if detected_base:
is_lora = True
logger.info(
f"Auto-detected local LoRA adapter at '{path}' (base: {detected_base})"
)
# Auto-detect LoRA for remote HF models (check repo file listing)
if not is_lora and not is_local:
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(identifier, token = hf_token)
repo_files = [s.rfilename for s in info.siblings]
if "adapter_config.json" in repo_files:
is_lora = True
logger.info(f"Auto-detected remote LoRA adapter: '{identifier}'")
except Exception as e:
logger.debug(
f"Could not check remote LoRA status for '{identifier}': {e}"
)
# Handle LoRA adapters
base_model = None
if is_lora:
if is_local:
# Local LoRA: read adapter_config.json from disk
base_model = get_base_model_from_lora(path)
else:
# Remote LoRA: download adapter_config.json from HF
try:
from huggingface_hub import hf_hub_download
config_path = hf_hub_download(
identifier, "adapter_config.json", token = hf_token
)
with open(config_path, "r", encoding = "utf-8") as f:
adapter_config = json.load(f)
base_model = adapter_config.get("base_model_name_or_path")
if base_model:
logger.info(f"Resolved remote LoRA base model: '{base_model}'")
except Exception as e:
logger.warning(
f"Could not download adapter_config.json for '{identifier}': {e}"
)
if not base_model:
logger.warning(f"Could not determine base model for LoRA '{path}'")
return None
check_model = base_model
else:
check_model = identifier
vision = is_vision_model(check_model, hf_token = hf_token)
audio_type_val = detect_audio_type(check_model, hf_token = hf_token)
has_audio_in = is_audio_input_type(audio_type_val)
display_name = Path(path).name if is_local else identifier.split("/")[-1]
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = is_local,
is_cached = is_model_cached(identifier) if not is_local else True,
is_vision = vision,
is_lora = is_lora,
is_audio = audio_type_val is not None and audio_type_val != "audio_vlm",
audio_type = audio_type_val,
has_audio_input = has_audio_in,
base_model = base_model,
)
@classmethod
def from_ui_selection(
cls,
dropdown_value: Optional[str],
search_value: Optional[str],
local_models: Optional[list] = None,
hf_token: Optional[str] = None,
is_lora: bool = False,
) -> Optional["ModelConfig"]:
"""
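Create a universal ModelConfig from UI dropdown/search selections.
Handles base models and LoRA adapters.
Args:
dropdown_value: Value selected in the model dropdown (may carry UI status suffixes)
search_value: Free-text search value; takes precedence over dropdown_value when non-empty
local_models: Optional list of (display_name, path) pairs used to resolve local display names
hf_token: Optional HF token for vision detection
is_lora: Whether the selection is a LoRA adapter
Returns:
ModelConfig, or None if nothing is selected or a LoRA's base model cannot be determined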
"""
selected = None
if search_value and search_value.strip():
selected = search_value.strip()
elif dropdown_value:
selected = dropdown_value
if not selected:
return None
display_name = selected
# Map a status-suffixed display name back to its path via 'local_models'
if " (Active)" in selected or " (Ready)" in selected:
clean_display_name = selected.replace(" (Active)", "").replace(
" (Ready)", ""
)
if local_models:
for local_display, local_path in local_models:
if local_display == clean_display_name:
selected = local_path
break
# Clean all UI status indicators to get the final identifier
identifier = selected
for status in UI_STATUS_INDICATORS:
identifier = identifier.replace(status, "")
identifier = identifier.strip()
is_local = is_local_path(identifier)
path = normalize_path(identifier) if is_local else identifier
# Add unsloth/ prefix for shorthand HF models
if not is_local and "/" not in identifier:
identifier = f"unsloth/{identifier}"
path = identifier
if not is_local:
resolved_identifier = resolve_cached_repo_id_case(identifier)
if resolved_identifier != identifier:
identifier = resolved_identifier
path = resolved_identifier
# --- Logic for Base Model and Vision Detection ---
base_model = None
is_vision = False
if is_lora:
# For a LoRA, we MUST find its base model.
base_model = get_base_model_from_lora(path)
if not base_model:
logger.warning(
f"Could not determine base model for LoRA '{path}'. Cannot create config."
)
return None # Cannot proceed without a base model
# A LoRA's vision capability is determined by its base model.
is_vision = is_vision_model(base_model, hf_token = hf_token)
else:
# For a base model, just check its own vision status.
is_vision = is_vision_model(identifier, hf_token = hf_token)
from utils.paths import is_model_cached
is_cached = is_model_cached(identifier) if not is_local else True
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = is_local,
is_cached = is_cached,
is_vision = is_vision,
is_lora = is_lora,
base_model = base_model, # This will be None for base models, and populated for LoRAs
)