unsloth/studio/backend/utils/models/model_config.py
Lee Jackson f9ef639dde
Studio: support GGUF variant selection for non-suffixed repos (#5023)
* fix: support GGUF variant selection for non-suffixed repos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: harden GGUF detection across cached models and picker flows

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore: use shared GGUF picker helper for search rows

* fix: avoid mixed cache duplication and preserve GGUF fallback detection

* fix: unify GGUF cache matching and merge picker hints

* fix: normalize local GGUF matching across picker and model config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: robust cached-gguf classification + hint-aware click routing

- _repo_gguf_size_bytes: treat size_on_disk=None as 0 and dedupe fallback
  by commit_hash so partial/interrupted downloads don't TypeError out of
  sum() and wipe the entire cached list.
- list_cached_gguf / list_cached_models: narrow per-repo try/except so
  one malformed repo no longer poisons the whole response.
- handleModelClick: route through isKnownGgufRepo instead of the
  suffix-only isGgufRepo, so non-suffixed GGUF repos still open the
  variant expander from every call site.
- Replace the modelIsGgufById/resultIsGgufById Maps with Sets of known
  GGUF ids to stop conflating "no hint" with "known not-GGUF".
- Make HfModelResult.isGguf required (it is always set in makeMapModel).
- Add regression tests for the None size case, mixed-repo inclusion in
  cached-gguf, and per-repo error isolation.

* fix: exclude mmproj from GGUF classification and case-normalize hint lookups

- _repo_gguf_size_bytes now filters mmproj vision-adapter files so
  safetensors+mmproj.gguf repos stay on the cached-models path and
  non-GGUF rows no longer show zero pickable variants. A vision-capable
  GGUF repo (main weight + mmproj adapter) still classifies as GGUF and
  reports the main weight size.
- modelGgufIds / resultGgufIds now key on lowercased ids and
  isKnownGgufRepo lowercases its lookup, so store and HF-search ids
  that differ only by casing still match the same GGUF hint.
- New regression tests: mmproj-only repo excluded from cached-gguf,
  same repo included in cached-models, vision-capable repo still
  classified as GGUF with correct size.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-15 15:32:01 +04:00

2168 lines
78 KiB
Python

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Model and LoRA configuration handling
"""
from dataclasses import dataclass
from typing import Optional, Dict, Any
from utils.paths import (
normalize_path,
is_local_path,
is_model_cached,
get_cache_path,
resolve_cached_repo_id_case,
outputs_root,
exports_root,
resolve_output_dir,
resolve_export_dir,
)
from utils.utils import without_hf_auth
import structlog
from loggers import get_logger
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple
import hashlib
import json
import threading
import yaml
logger = get_logger(__name__)
# ── Model size extraction ────────────────────────────────────
import re as _re
_MODEL_SIZE_RE = _re.compile(
r"(?:^|[-_/])(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
# MoE active-parameter pattern: matches "A3B", "A3.5B", etc.
_ACTIVE_SIZE_RE = _re.compile(
r"(?:^|[-_/])a(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
def extract_model_size_b(model_id: str) -> float | None:
"""Extract model size in billions from a model identifier.
Prefers MoE active-parameter notation (e.g. ``A3B`` in
``Qwen3.5-35B-A3B``) over the total parameter count.
Handles both ``B`` (billions) and ``M`` (millions) suffixes.
"""
mid = (model_id or "").lower()
active = _ACTIVE_SIZE_RE.search(mid)
if active:
val = float(active.group(1))
return val / 1000.0 if active.group(2).lower() == "m" else val
size = _MODEL_SIZE_RE.search(mid)
if not size:
return None
val = float(size.group(1))
return val / 1000.0 if size.group(2).lower() == "m" else val
# Model name mapping: maps all equivalent model names to their canonical YAML config file
# Format: "canonical_model_name.yaml": [list of all equivalent model names]
# Based on the model mapper provided - canonical filename is based on the first model name in the mapper
MODEL_NAME_MAPPING = {
# ── Embedding models ──
"unsloth_all-MiniLM-L6-v2.yaml": [
"unsloth/all-MiniLM-L6-v2",
"sentence-transformers/all-MiniLM-L6-v2",
],
"unsloth_bge-m3.yaml": [
"unsloth/bge-m3",
"BAAI/bge-m3",
],
"unsloth_embeddinggemma-300m.yaml": [
"unsloth/embeddinggemma-300m",
"google/embeddinggemma-300m",
],
"unsloth_gte-modernbert-base.yaml": [
"unsloth/gte-modernbert-base",
"Alibaba-NLP/gte-modernbert-base",
],
"unsloth_Qwen3-Embedding-0.6B.yaml": [
"unsloth/Qwen3-Embedding-0.6B",
"Qwen/Qwen3-Embedding-0.6B",
"unsloth/Qwen3-Embedding-4B",
"Qwen/Qwen3-Embedding-4B",
],
# ── Other models ──
"unsloth_answerdotai_ModernBERT-large.yaml": [
"answerdotai/ModernBERT-large",
],
"unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml": [
"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
"unsloth/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
],
"unsloth_codegemma-7b-bnb-4bit.yaml": [
"unsloth/codegemma-7b-bnb-4bit",
"unsloth/codegemma-7b",
"google/codegemma-7b",
],
"unsloth_ERNIE-4.5-21B-A3B-PT.yaml": [
"unsloth/ERNIE-4.5-21B-A3B-PT",
],
"unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml": [
"unsloth/ERNIE-4.5-VL-28B-A3B-PT",
],
"tiiuae_Falcon-H1-0.5B-Instruct.yaml": [
"tiiuae/Falcon-H1-0.5B-Instruct",
"unsloth/Falcon-H1-0.5B-Instruct",
],
"unsloth_functiongemma-270m-it.yaml": [
"unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
"google/functiongemma-270m-it",
"unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
],
"unsloth_gemma-2-2b.yaml": [
"unsloth/gemma-2-2b-bnb-4bit",
"google/gemma-2-2b",
],
"unsloth_gemma-2-27b-bnb-4bit.yaml": [
"unsloth/gemma-2-9b-bnb-4bit",
"unsloth/gemma-2-9b",
"google/gemma-2-9b",
"unsloth/gemma-2-27b",
"google/gemma-2-27b",
],
"unsloth_gemma-3-4b-pt.yaml": [
"unsloth/gemma-3-4b-pt-unsloth-bnb-4bit",
"google/gemma-3-4b-pt",
"unsloth/gemma-3-4b-pt-bnb-4bit",
],
"unsloth_gemma-3-4b-it.yaml": [
"unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
"google/gemma-3-4b-it",
"unsloth/gemma-3-4b-it-bnb-4bit",
],
"unsloth_gemma-3-27b-it.yaml": [
"unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
"google/gemma-3-27b-it",
"unsloth/gemma-3-27b-it-bnb-4bit",
],
"unsloth_gemma-3-270m-it.yaml": [
"unsloth/gemma-3-270m-it-unsloth-bnb-4bit",
"google/gemma-3-270m-it",
"unsloth/gemma-3-270m-it-bnb-4bit",
],
"unsloth_gemma-3n-E4B-it.yaml": [
"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
"google/gemma-3n-E4B-it",
"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
],
"unsloth_gemma-3n-E4B.yaml": [
"unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
"google/gemma-3n-E4B",
],
"unsloth_gemma-4-31B-it.yaml": [
"unsloth/gemma-4-31B-it",
"google/gemma-4-31B-it",
],
"unsloth_gemma-4-26B-A4B-it.yaml": [
"unsloth/gemma-4-26B-A4B-it",
"google/gemma-4-26B-A4B-it",
],
"unsloth_gemma-4-E2B-it.yaml": [
"unsloth/gemma-4-E2B-it",
"google/gemma-4-E2B-it",
],
"unsloth_gemma-4-E4B-it.yaml": [
"unsloth/gemma-4-E4B-it",
"google/gemma-4-E4B-it",
],
"unsloth_gemma-4-31B.yaml": [
"unsloth/gemma-4-31B",
"google/gemma-4-31B",
],
"unsloth_gemma-4-26B-A4B.yaml": [
"unsloth/gemma-4-26B-A4B",
"google/gemma-4-26B-A4B",
],
"unsloth_gemma-4-E2B.yaml": [
"unsloth/gemma-4-E2B",
"google/gemma-4-E2B",
],
"unsloth_gemma-4-E4B.yaml": [
"unsloth/gemma-4-E4B",
"google/gemma-4-E4B",
],
"unsloth_gpt-oss-20b.yaml": [
"openai/gpt-oss-20b",
"unsloth/gpt-oss-20b-unsloth-bnb-4bit",
"unsloth/gpt-oss-20b-BF16",
],
"unsloth_gpt-oss-120b.yaml": [
"openai/gpt-oss-120b",
"unsloth/gpt-oss-120b-unsloth-bnb-4bit",
],
"unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml": [
"unsloth/granite-4.0-350m",
"ibm-granite/granite-4.0-350m",
"unsloth/granite-4.0-350m-bnb-4bit",
],
"unsloth_granite-4.0-h-micro.yaml": [
"ibm-granite/granite-4.0-h-micro",
"unsloth/granite-4.0-h-micro-bnb-4bit",
"unsloth/granite-4.0-h-micro-unsloth-bnb-4bit",
],
"unsloth_LFM2-1.2B.yaml": [
"unsloth/LFM2-1.2B",
],
"unsloth_llama-3-8b-bnb-4bit.yaml": [
"unsloth/llama-3-8b",
"meta-llama/Meta-Llama-3-8B",
],
"unsloth_llama-3-8b-Instruct-bnb-4bit.yaml": [
"unsloth/llama-3-8b-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
],
"unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml": [
"unsloth/Meta-Llama-3.1-8B-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit",
"meta-llama/Meta-Llama-3.1-8B",
"unsloth/Meta-Llama-3.1-70B-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B",
"unsloth/Meta-Llama-3.1-70B",
"meta-llama/Meta-Llama-3.1-70B",
"unsloth/Meta-Llama-3.1-405B-bnb-4bit",
"meta-llama/Meta-Llama-3.1-405B",
],
"unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml": [
"unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"unsloth/Meta-Llama-3.1-8B-Instruct",
"RedHatAI/Llama-3.1-8B-Instruct-FP8",
"unsloth/Llama-3.1-8B-Instruct-FP8-Block",
"unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-3B-Instruct.yaml": [
"unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-3B-Instruct",
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.2-3B-Instruct-FP8",
"unsloth/Llama-3.2-3B-Instruct-FP8-Block",
"unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-1B-Instruct.yaml": [
"unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-1B-Instruct",
"unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.2-1B-Instruct-FP8",
"unsloth/Llama-3.2-1B-Instruct-FP8-Block",
"unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic",
],
"unsloth_Llama-3.2-11B-Vision-Instruct.yaml": [
"unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
],
"unsloth_Llama-3.3-70B-Instruct.yaml": [
"unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit",
"meta-llama/Llama-3.3-70B-Instruct",
"unsloth/Llama-3.3-70B-Instruct-bnb-4bit",
"RedHatAI/Llama-3.3-70B-Instruct-FP8",
"unsloth/Llama-3.3-70B-Instruct-FP8-Block",
"unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic",
],
"unsloth_Llasa-3B.yaml": [
"HKUSTAudio/Llasa-1B",
"unsloth/Llasa-3B",
],
"unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml": [
"unsloth/Magistral-Small-2509",
"mistralai/Magistral-Small-2509",
"unsloth/Magistral-Small-2509-bnb-4bit",
],
"unsloth_Ministral-3-3B-Instruct-2512.yaml": [
"unsloth/Ministral-3-3B-Instruct-2512",
],
"unsloth_mistral-7b-v0.3-bnb-4bit.yaml": [
"unsloth/mistral-7b-v0.3-bnb-4bit",
"unsloth/mistral-7b-v0.3",
"mistralai/Mistral-7B-v0.3",
],
"unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml": [
"unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
"unsloth/Mistral-Nemo-Base-2407",
"mistralai/Mistral-Nemo-Base-2407",
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
"unsloth/Mistral-Nemo-Instruct-2407",
"mistralai/Mistral-Nemo-Instruct-2407",
],
"unsloth_Mistral-Small-Instruct-2409.yaml": [
"unsloth/Mistral-Small-Instruct-2409-bnb-4bit",
"mistralai/Mistral-Small-Instruct-2409",
],
"unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml": [
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
"unsloth/mistral-7b-instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.3",
],
"unsloth_Qwen2.5-1.5B-Instruct.yaml": [
"unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit",
"Qwen/Qwen2.5-1.5B-Instruct",
"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
],
"unsloth_Nemotron-3-Nano-30B-A3B.yaml": [
"unsloth/Nemotron-3-Nano-30B-A3B",
],
"unsloth_orpheus-3b-0.1-ft.yaml": [
"unsloth/orpheus-3b-0.1-ft",
"unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",
"canopylabs/orpheus-3b-0.1-ft",
"unsloth/orpheus-3b-0.1-ft-bnb-4bit",
],
"OuteAI_Llama-OuteTTS-1.0-1B.yaml": [
"OuteAI/Llama-OuteTTS-1.0-1B",
"unsloth/Llama-OuteTTS-1.0-1B",
"unsloth/llama-outetts-1.0-1b",
"OuteAI/OuteTTS-1.0-0.6B",
"unsloth/OuteTTS-1.0-0.6B",
"unsloth/outetts-1.0-0.6b",
],
"unsloth_PaddleOCR-VL.yaml": [
"unsloth/PaddleOCR-VL",
],
"unsloth_Phi-3-medium-4k-instruct.yaml": [
"unsloth/Phi-3-medium-4k-instruct-bnb-4bit",
"microsoft/Phi-3-medium-4k-instruct",
],
"unsloth_Phi-3.5-mini-instruct.yaml": [
"unsloth/Phi-3.5-mini-instruct-bnb-4bit",
"microsoft/Phi-3.5-mini-instruct",
],
"unsloth_Phi-4.yaml": [
"unsloth/phi-4-unsloth-bnb-4bit",
"microsoft/phi-4",
"unsloth/phi-4-bnb-4bit",
],
"unsloth_Pixtral-12B-2409.yaml": [
"unsloth/Pixtral-12B-2409-unsloth-bnb-4bit",
"mistralai/Pixtral-12B-2409",
"unsloth/Pixtral-12B-2409-bnb-4bit",
],
"unsloth_Qwen2-7B.yaml": [
"unsloth/Qwen2-7B-bnb-4bit",
"Qwen/Qwen2-7B",
],
"unsloth_Qwen2-VL-7B-Instruct.yaml": [
"unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit",
"Qwen/Qwen2-VL-7B-Instruct",
"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
],
"unsloth_Qwen2.5-7B.yaml": [
"unsloth/Qwen2.5-7B-unsloth-bnb-4bit",
"Qwen/Qwen2.5-7B",
"unsloth/Qwen2.5-7B-bnb-4bit",
],
"unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml": [
"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit",
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
],
"unsloth_Qwen2.5-Coder-14B-Instruct.yaml": [
"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit",
"Qwen/Qwen2.5-Coder-14B-Instruct",
],
"unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml": [
"unsloth/Qwen2.5-VL-7B-Instruct",
"Qwen/Qwen2.5-VL-7B-Instruct",
"unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
],
"unsloth_Qwen3-0.6B.yaml": [
"unsloth/Qwen3-0.6B-unsloth-bnb-4bit",
"Qwen/Qwen3-0.6B",
"unsloth/Qwen3-0.6B-bnb-4bit",
"Qwen/Qwen3-0.6B-FP8",
"unsloth/Qwen3-0.6B-FP8",
],
"unsloth_Qwen3-4B-Instruct-2507.yaml": [
"unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
"Qwen/Qwen3-4B-Instruct-2507",
"unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
"Qwen/Qwen3-4B-Instruct-2507-FP8",
"unsloth/Qwen3-4B-Instruct-2507-FP8",
],
"unsloth_Qwen3-4B-Thinking-2507.yaml": [
"unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
"Qwen/Qwen3-4B-Thinking-2507",
"unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
"Qwen/Qwen3-4B-Thinking-2507-FP8",
"unsloth/Qwen3-4B-Thinking-2507-FP8",
],
"unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml": [
"unsloth/Qwen3-14B-Base",
"Qwen/Qwen3-14B-Base",
"unsloth/Qwen3-14B-Base-bnb-4bit",
],
"unsloth_Qwen3-14B.yaml": [
"unsloth/Qwen3-14B-unsloth-bnb-4bit",
"Qwen/Qwen3-14B",
"unsloth/Qwen3-14B-bnb-4bit",
"Qwen/Qwen3-14B-FP8",
"unsloth/Qwen3-14B-FP8",
],
"unsloth_Qwen3-32B.yaml": [
"unsloth/Qwen3-32B-unsloth-bnb-4bit",
"Qwen/Qwen3-32B",
"unsloth/Qwen3-32B-bnb-4bit",
"Qwen/Qwen3-32B-FP8",
"unsloth/Qwen3-32B-FP8",
],
"unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml": [
"Qwen/Qwen3-VL-8B-Instruct-FP8",
"unsloth/Qwen3-VL-8B-Instruct-FP8",
"unsloth/Qwen3-VL-8B-Instruct",
"Qwen/Qwen3-VL-8B-Instruct",
"unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",
],
"sesame_csm-1b.yaml": [
"sesame/csm-1b",
"unsloth/csm-1b",
],
"Spark-TTS-0.5B_LLM.yaml": [
"Spark-TTS-0.5B/LLM",
"unsloth/Spark-TTS-0.5B",
],
"unsloth_tinyllama-bnb-4bit.yaml": [
"unsloth/tinyllama",
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
],
"unsloth_whisper-large-v3.yaml": [
"unsloth/whisper-large-v3",
"openai/whisper-large-v3",
],
}
# Reverse mapping for quick lookup: model_name -> canonical_filename
_REVERSE_MODEL_MAPPING = {}
for canonical_file, model_names in MODEL_NAME_MAPPING.items():
for model_name in model_names:
_REVERSE_MODEL_MAPPING[model_name.lower()] = canonical_file
def load_model_config(
model_name: str,
use_auth: bool = False,
token: Optional[str] = None,
trust_remote_code: bool = True,
):
"""
Load model config with optional authentication control.
"""
from transformers import AutoConfig
if token:
# Explicit token provided - use it
return AutoConfig.from_pretrained(
model_name, trust_remote_code = trust_remote_code, token = token
)
if not use_auth:
# Load without any authentication (for public model checks)
with without_hf_auth():
return AutoConfig.from_pretrained(
model_name,
trust_remote_code = trust_remote_code,
token = None,
)
# Use default authentication (cached tokens)
return AutoConfig.from_pretrained(
model_name,
trust_remote_code = trust_remote_code,
)
# VLM architecture suffixes and known VLM model_type values.
_VLM_ARCH_SUFFIXES = ("ForConditionalGeneration", "ForVisionText2Text")
_VLM_MODEL_TYPES = {
"phi3_v",
"llava",
"llava_next",
"llava_onevision",
"internvl_chat",
"cogvlm2",
"minicpmv",
}
# Pre-computed .venv_t5 paths and backend dir for subprocess version switching.
# Vision check uses 5.5.0 (newest, recognizes all architectures).
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550")
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent)
# Inline script executed in a subprocess with transformers 5.x activated.
# Receives model_name and token via argv, prints JSON result to stdout.
_VISION_CHECK_SCRIPT = r"""
import sys, os, json
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Activate transformers 5.x
venv_t5 = sys.argv[1]
backend_dir = sys.argv[2]
model_name = sys.argv[3]
token = sys.argv[4] if len(sys.argv) > 4 and sys.argv[4] != "" else None
sys.path.insert(0, venv_t5)
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
try:
from transformers import AutoConfig
kwargs = {"trust_remote_code": True}
if token:
kwargs["token"] = token
config = AutoConfig.from_pretrained(model_name, **kwargs)
is_vlm = False
if hasattr(config, "architectures"):
is_vlm = any(
x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
for x in config.architectures
)
if not is_vlm and hasattr(config, "vision_config"):
is_vlm = True
if not is_vlm and hasattr(config, "img_processor"):
is_vlm = True
if not is_vlm and hasattr(config, "image_token_index"):
is_vlm = True
if not is_vlm and hasattr(config, "model_type"):
vlm_types = {"phi3_v","llava","llava_next","llava_onevision",
"internvl_chat","cogvlm2","minicpmv"}
if config.model_type in vlm_types:
is_vlm = True
model_type = getattr(config, "model_type", "unknown")
archs = getattr(config, "architectures", [])
print(json.dumps({"is_vision": is_vlm, "model_type": model_type,
"architectures": archs}))
except Exception as exc:
print(json.dumps({"error": str(exc)}))
sys.exit(1)
"""
def _is_vision_model_subprocess(
model_name: str, hf_token: Optional[str] = None
) -> Optional[bool]:
"""Run is_vision_model check in a subprocess with transformers 5.x.
Same pattern as training/inference workers: spawn a clean subprocess
with .venv_t5/ prepended to sys.path so AutoConfig recognizes newer
architectures (glm4_moe_lite, etc.).
Returns True/False for definitive results, or None for transient failures
(timeouts, subprocess errors) so callers can decide whether to cache
the result. Subprocess failures are treated as transient because they
can be caused by temporary HF/auth/network issues.
"""
token_arg = hf_token or ""
try:
result = subprocess.run(
[
sys.executable,
"-c",
_VISION_CHECK_SCRIPT,
_VENV_T5_DIR,
_BACKEND_DIR,
model_name,
token_arg,
],
capture_output = True,
text = True,
timeout = 60,
)
if result.returncode != 0:
stderr = result.stderr.strip()
logger.warning(
"Vision check subprocess failed for '%s': %s",
model_name,
stderr or result.stdout.strip(),
)
return None
data = json.loads(result.stdout.strip())
if "error" in data:
logger.warning(
"Vision check subprocess error for '%s': %s",
model_name,
data["error"],
)
return None
is_vlm = data["is_vision"]
logger.info(
"Vision check (subprocess, transformers 5.x) for '%s': "
"model_type=%s, architectures=%s, is_vision=%s",
model_name,
data.get("model_type"),
data.get("architectures"),
is_vlm,
)
return is_vlm
except subprocess.TimeoutExpired:
logger.warning("Vision check subprocess timed out for '%s'", model_name)
return None
except Exception as exc:
logger.warning("Vision check subprocess failed for '%s': %s", model_name, exc)
return None
def _token_fingerprint(token: Optional[str]) -> Optional[str]:
"""Return a SHA256 digest of the token for use as a cache key.
Avoids storing the raw bearer token in process memory as a dict key.
"""
if token is None:
return None
return hashlib.sha256(token.encode("utf-8")).hexdigest()
# Cache vision detection results per session to avoid repeated subprocess spawns.
# Keyed by (normalized_model_name, token_fingerprint) to handle gated models correctly.
# Only definitive results (True/False from successful detection) are cached;
# transient failures (network errors, timeouts) are NOT cached so they can be retried.
_vision_detection_cache: Dict[Tuple[str, Optional[str]], bool] = {}
_vision_cache_lock = threading.Lock()
def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:
"""
Detect vision-language models (VLMs) by checking architecture in config.
Works for fine-tuned models since they inherit the base architecture.
For models that require transformers 5.x (e.g. GLM-4.7-Flash), the check
runs in a subprocess with .venv_t5/ activated -- same pattern as the
training and inference workers.
Results are cached per (model_name, token_fingerprint) for the lifetime of
the process to avoid repeated subprocess spawns and HuggingFace API calls.
Transient failures are not cached so they can be retried on the next call.
Args:
model_name: Model identifier (HF repo or local path)
hf_token: Optional HF token for accessing gated/private models
"""
# Normalize model name for cache key to avoid duplicate entries for
# different casings of the same HF repo (e.g. "Org/Model" vs "org/model").
try:
if is_local_path(model_name):
resolved_name = normalize_path(model_name)
else:
resolved_name = resolve_cached_repo_id_case(model_name)
except Exception as exc:
logger.debug(
"Could not normalize model name '%s' for cache key: %s",
model_name,
exc,
)
resolved_name = model_name
cache_key = (resolved_name, _token_fingerprint(hf_token))
# Lock-free fast path for cache hits. Uses a sentinel to distinguish
# "key not found" from "value is False" in a single atomic dict.get() call.
_MISS = object()
cached = _vision_detection_cache.get(cache_key, _MISS)
if cached is not _MISS:
return cached
# Compute outside the lock to avoid serializing long-running detection
# (subprocess spawns with 60s timeout, HF API calls) across all models.
# The tradeoff: two concurrent calls for the same uncached model may
# both run detection, but they produce the same result and the second
# write is a benign no-op.
result = _is_vision_model_uncached(resolved_name, hf_token)
# Only cache definitive results; None means a transient failure occurred
# and we should retry on the next call instead of locking in a wrong answer.
if result is not None:
with _vision_cache_lock:
_vision_detection_cache[cache_key] = result
return result
return False
def _is_vision_model_uncached(
model_name: str, hf_token: Optional[str] = None
) -> Optional[bool]:
"""Uncached vision model detection -- called by is_vision_model().
Returns True/False for definitive results, or None when detection failed
due to a transient error (network, timeout, subprocess failure) so the
caller knows not to cache the result.
Do not call directly; use is_vision_model() instead.
"""
# Models that need transformers 5.x must be checked in a subprocess
# because AutoConfig in the main process (transformers 4.57.x) doesn't
# recognize their architectures.
from utils.transformers_version import needs_transformers_5
if needs_transformers_5(model_name):
logger.info(
"Model '%s' needs transformers 5.x -- checking vision via subprocess",
model_name,
)
return _is_vision_model_subprocess(model_name, hf_token = hf_token)
try:
config = load_model_config(model_name, use_auth = True, token = hf_token)
# Exclude audio-only models that share ForConditionalGeneration suffix
# (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration)
_audio_only_model_types = {"csm", "whisper"}
model_type = getattr(config, "model_type", None)
if model_type in _audio_only_model_types:
return False
# Check 1: Architecture class name patterns
if hasattr(config, "architectures"):
is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)
if is_vlm:
logger.info(
f"Model {model_name} detected as VLM: architecture {config.architectures}"
)
return True
# Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
if hasattr(config, "vision_config"):
logger.info(f"Model {model_name} detected as VLM: has vision_config")
return True
# Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
if hasattr(config, "img_processor"):
logger.info(f"Model {model_name} detected as VLM: has img_processor")
return True
# Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
if hasattr(config, "image_token_index"):
logger.info(f"Model {model_name} detected as VLM: has image_token_index")
return True
# Check 5: Known VLM model_type values that may not match above checks
if hasattr(config, "model_type"):
if config.model_type in _VLM_MODEL_TYPES:
logger.info(
f"Model {model_name} detected as VLM: model_type={config.model_type}"
)
return True
return False
except Exception as e:
logger.warning(f"Could not determine if {model_name} is vision model: {e}")
# Permanent failures (model not found, gated, bad config) should be
# cached as False. Transient failures (network, timeout) should not.
try:
from huggingface_hub.errors import RepositoryNotFoundError, GatedRepoError
except ImportError:
try:
from huggingface_hub.utils import (
RepositoryNotFoundError,
GatedRepoError,
)
except ImportError:
RepositoryNotFoundError = GatedRepoError = None
if RepositoryNotFoundError is not None and isinstance(
e, (RepositoryNotFoundError, GatedRepoError)
):
return False
if isinstance(e, (ValueError, json.JSONDecodeError)):
return False
return None
VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")
# Cache detection results per session to avoid repeated API calls
_audio_detection_cache: Dict[str, Optional[str]] = {}
# Tokenizer token patterns → audio_type (all 6 types detected from tokenizer_config.json)
_AUDIO_TOKEN_PATTERNS = {
"csm": lambda tokens: "<|AUDIO|>" in tokens and "<|audio_eos|>" in tokens,
"whisper": lambda tokens: "<|startoftranscript|>" in tokens,
"audio_vlm": lambda tokens: "<audio_soft_token>" in tokens,
"bicodec": lambda tokens: any(t.startswith("<|bicodec_") for t in tokens),
"dac": lambda tokens: "<|audio_start|>" in tokens
and "<|audio_end|>" in tokens
and "<|text_start|>" in tokens
and "<|text_end|>" in tokens,
"snac": lambda tokens: sum(1 for t in tokens if t.startswith("<custom_token_"))
> 10000,
}
def detect_audio_type(model_name: str, hf_token: Optional[str] = None) -> Optional[str]:
"""
Dynamically detect if a model is an audio model and return its type.
Fully dynamic — works for any model, not just known ones.
Uses tokenizer_config.json special tokens to detect all 6 audio types.
Returns: audio_type string ('snac', 'csm', 'bicodec', 'dac', 'whisper', 'audio_vlm') or None.
"""
if model_name in _audio_detection_cache:
return _audio_detection_cache[model_name]
result = _detect_audio_from_tokenizer(model_name, hf_token)
_audio_detection_cache[model_name] = result
if result:
logger.info(f"Model {model_name} detected as audio model: audio_type={result}")
return result
def _detect_audio_from_tokenizer(
model_name: str, hf_token: Optional[str] = None
) -> Optional[str]:
"""Detect audio type from tokenizer special tokens (for LLM-based audio models).
First checks local HF cache, then fetches tokenizer_config.json from HuggingFace.
Checks added_tokens_decoder for distinctive patterns.
"""
def _check_token_patterns(tok_config: dict) -> Optional[str]:
added = tok_config.get("added_tokens_decoder", {})
if not added:
return None
token_contents = [v.get("content", "") for v in added.values()]
for audio_type, check_fn in _AUDIO_TOKEN_PATTERNS.items():
if check_fn(token_contents):
return audio_type
return None
# 1) Check local HF cache first (works for gated/offline models)
try:
repo_dir = get_cache_path(model_name)
if repo_dir is not None and repo_dir.exists():
snapshots_dir = repo_dir / "snapshots"
if snapshots_dir.exists():
for snapshot in snapshots_dir.iterdir():
for tok_path in [
"tokenizer_config.json",
"LLM/tokenizer_config.json",
]:
tok_file = snapshot / tok_path
if tok_file.exists():
tok_config = json.loads(tok_file.read_text())
result = _check_token_patterns(tok_config)
if result:
return result
except Exception as e:
logger.debug(f"Could not check local cache for {model_name}: {e}")
# 2) Fall back to HuggingFace API
try:
import requests
import os
paths_to_try = ["tokenizer_config.json", "LLM/tokenizer_config.json"]
# Use provided token, or fall back to env
token = hf_token or os.environ.get("HF_TOKEN")
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
for tok_path in paths_to_try:
url = f"https://huggingface.co/{model_name}/resolve/main/{tok_path}"
resp = requests.get(url, headers = headers, timeout = 15)
if not resp.ok:
continue
tok_config = resp.json()
result = _check_token_patterns(tok_config)
if result:
return result
return None
except Exception as e:
logger.debug(
f"Could not detect audio type from tokenizer for {model_name}: {e}"
)
return None
def is_audio_input_type(audio_type: Optional[str]) -> bool:
"""Check if an audio_type accepts audio input (ASR/speech understanding).
Whisper (ASR) and audio_vlm (Gemma3n) accept audio input.
"""
return audio_type in ("whisper", "audio_vlm")
def _is_mmproj(filename: str) -> bool:
"""Check if a GGUF filename is a vision projection (mmproj) file."""
return "mmproj" in filename.lower()
def _is_gguf_filename(filename: str) -> bool:
return filename.lower().endswith(".gguf")
def _iter_gguf_files(directory: Path):
if not directory.is_dir():
return
for f in directory.iterdir():
if f.is_file() and _is_gguf_filename(f.name):
yield f
def detect_mmproj_file(path: str) -> Optional[str]:
"""
Find the mmproj (vision projection) GGUF file in a directory.
Args:
path: Directory to search — or a .gguf file (uses its parent dir).
Returns:
Full path to the mmproj .gguf file, or None if not found.
"""
p = Path(path)
search_dir = p.parent if p.is_file() else p
if not search_dir.is_dir():
return None
for f in _iter_gguf_files(search_dir):
if _is_mmproj(f.name):
return str(f.resolve())
return None
def detect_gguf_model(path: str) -> Optional[str]:
"""
Check if the given local path is or contains a GGUF model file.
Handles two cases:
1. path is a direct .gguf file path
2. path is a directory containing .gguf files
Skips mmproj (vision projection) files — those must be passed via
``--mmproj``, not ``-m``. Use :func:`detect_mmproj_file` instead.
Returns the full path to the .gguf file if found, None otherwise.
For HuggingFace repo detection, use detect_gguf_model_remote() instead.
"""
p = Path(path)
# Case 1: direct .gguf file
if p.suffix.lower() == ".gguf" and p.is_file():
if _is_mmproj(p.name):
return None
return str(p.resolve())
# Case 2: directory containing .gguf files (skip mmproj)
if p.is_dir():
gguf_files = sorted(
(f for f in _iter_gguf_files(p) if not _is_mmproj(f.name)),
key = lambda f: f.stat().st_size,
reverse = True,
)
if gguf_files:
return str(gguf_files[0].resolve())
return None
# Preferred GGUF quantization levels, in descending priority.
# Q4_K_M is a good default: small, fast, acceptable quality.
# UD (Unsloth Dynamic) variants are always preferred over standard quants
# because they provide better quality per bit. If the repo has no UD variants
# (e.g., bartowski repos), the standard quants are used as fallback.
# Ordered by best size/quality tradeoff, not raw quality.
_GGUF_QUANT_PREFERENCE = [
# UD variants (best quality per bit) -- Q4 is the sweet spot
"UD-Q4_K_XL",
"UD-Q4_K_L",
"UD-Q5_K_XL",
"UD-Q3_K_XL",
"UD-Q6_K_XL",
"UD-Q6_K_S",
"UD-Q8_K_XL",
"UD-Q2_K_XL",
"UD-IQ4_NL",
"UD-IQ4_XS",
"UD-IQ3_S",
"UD-IQ3_XXS",
"UD-IQ2_M",
"UD-IQ2_XXS",
"UD-IQ1_M",
"UD-IQ1_S",
# Standard quants (fallback for non-Unsloth repos)
"Q4_K_M",
"Q4_K_S",
"Q5_K_M",
"Q5_K_S",
"Q6_K",
"Q8_0",
"Q3_K_M",
"Q3_K_L",
"Q3_K_S",
"Q2_K",
"Q2_K_L",
"IQ4_NL",
"IQ4_XS",
"IQ3_M",
"IQ3_XXS",
"IQ2_M",
"IQ1_M",
"F16",
"BF16",
"F32",
]
def _pick_best_gguf(filenames: list[str]) -> Optional[str]:
"""
Pick the best GGUF file from a list of filenames.
Prefers quantization levels in _GGUF_QUANT_PREFERENCE order.
Falls back to the first .gguf file found.
"""
gguf_files = [f for f in filenames if f.lower().endswith(".gguf")]
if not gguf_files:
return None
# Try preferred quantization levels
for quant in _GGUF_QUANT_PREFERENCE:
for f in gguf_files:
if quant in f:
return f
# Fallback: first GGUF file
return gguf_files[0]
@dataclass
class GgufVariantInfo:
"""A single GGUF quantization variant from a HuggingFace repo."""
filename: str # e.g., "gemma-3-4b-it-Q4_K_M.gguf"
quant: str # e.g., "Q4_K_M" (extracted from filename)
size_bytes: int # file size
def _extract_quant_label(filename: str) -> str:
"""
Extract quantization label like Q4_K_M, IQ4_XS, BF16 from a GGUF filename.
Examples:
"gemma-3-4b-it-Q4_K_M.gguf""Q4_K_M"
"model-IQ4_NL.gguf""IQ4_NL"
"model-BF16.gguf""BF16"
"model-UD-IQ1_S.gguf""UD-IQ1_S"
"model-UD-TQ1_0.gguf""UD-TQ1_0"
"MXFP4_MOE/model-MXFP4_MOE-0001.gguf""MXFP4_MOE"
"""
import re
# Use only the basename (rfilename may include directory)
basename = filename.rsplit("/", 1)[-1]
# Strip .gguf and any shard suffix (-00001-of-00010)
stem = re.sub(r"-\d{3,}-of-\d{3,}", "", basename.rsplit(".", 1)[0])
# Match known quantization patterns
match = re.search(
r"(UD-)?" # Optional UD- prefix (Ultra Discrete)
r"(MXFP[0-9]+(?:_[A-Z0-9]+)*" # MXFP variants: MXFP4, MXFP4_MOE
r"|IQ[0-9]+_[A-Z]+(?:_[A-Z0-9]+)?" # IQ variants: IQ4_XS, IQ4_NL, IQ1_S
r"|TQ[0-9]+_[0-9]+" # Ternary quant: TQ1_0, TQ2_0
r"|Q[0-9]+_K_[A-Z]+" # K-quant: Q4_K_M, Q3_K_S
r"|Q[0-9]+_[0-9]+" # Standard: Q8_0, Q5_1
r"|Q[0-9]+_K" # Short K-quant: Q6_K
r"|BF16|F16|F32)", # Full precision
stem,
re.IGNORECASE,
)
if match:
prefix = match.group(1) or ""
return f"{prefix}{match.group(2)}"
# Fallback: last segment after hyphen
return stem.split("-")[-1]
def list_gguf_variants(
repo_id: str,
hf_token: Optional[str] = None,
) -> tuple[list[GgufVariantInfo], bool]:
"""
List all GGUF quantization variants in a HuggingFace repo.
Separates main model files from mmproj (vision projection) files.
The presence of mmproj files indicates a vision-capable model.
Returns:
(variants, has_vision): list of non-mmproj GGUF variants + vision flag.
"""
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(repo_id, token = hf_token, files_metadata = True)
variants: list[GgufVariantInfo] = []
has_vision = False
quant_totals: dict[str, int] = {} # quant -> total bytes
quant_first_file: dict[str, str] = {} # quant -> first filename (for display)
for sibling in info.siblings:
fname = sibling.rfilename
if not fname.lower().endswith(".gguf"):
continue
size = sibling.size or 0
# mmproj files are vision projection models, not main model files
if "mmproj" in fname.lower():
has_vision = True
continue
quant = _extract_quant_label(fname)
quant_totals[quant] = quant_totals.get(quant, 0) + size
if quant not in quant_first_file:
quant_first_file[quant] = fname
for quant, total_size in quant_totals.items():
variants.append(
GgufVariantInfo(
filename = quant_first_file[quant],
quant = quant,
size_bytes = total_size,
)
)
# Sort by size descending (largest = best quality first).
# Recommended pinning and OOM demotion are handled client-side
# where GPU VRAM info is available.
variants.sort(key = lambda v: -v.size_bytes)
return variants, has_vision
def _resolve_gguf_dir(p: Path) -> Optional[Path]:
"""Resolve a path to the directory containing GGUF variants.
If *p* is already a directory, returns it directly. If *p* is a ``.gguf``
file whose parent directory has model metadata (``config.json`` or
``adapter_config.json``), returns the parent -- all GGUFs in that
directory belong to the same model. Returns ``None`` for loose standalone
GGUFs (no config) to avoid cross-wiring unrelated models.
"""
if p.is_dir():
return p
if p.is_file() and p.suffix.lower() == ".gguf":
parent = p.parent
if (parent / "config.json").exists() or (
parent / "adapter_config.json"
).exists():
return parent
return None
def list_local_gguf_variants(
directory: str,
) -> tuple[list[GgufVariantInfo], bool]:
"""List GGUF quantization variants in a local directory.
Mirrors :func:`list_gguf_variants` but reads from the filesystem
instead of the HuggingFace API. Aggregates shard sizes by quant
label so that split GGUFs appear as a single variant.
Returns:
(variants, has_vision): list of non-mmproj GGUF variants + vision flag.
"""
p = _resolve_gguf_dir(Path(directory))
if p is None:
return [], False
quant_totals: dict[str, int] = {}
quant_first_file: dict[str, str] = {}
has_vision = False
for f in sorted(_iter_gguf_files(p)):
if _is_mmproj(f.name):
has_vision = True
continue
try:
size = f.stat().st_size
except OSError:
size = 0
quant = _extract_quant_label(f.name)
quant_totals[quant] = quant_totals.get(quant, 0) + size
if quant not in quant_first_file:
quant_first_file[quant] = f.name
variants = [
GgufVariantInfo(
filename = quant_first_file[q],
quant = q,
size_bytes = s,
)
for q, s in quant_totals.items()
]
variants.sort(key = lambda v: -v.size_bytes)
return variants, has_vision
def _find_local_gguf_by_variant(directory: str, variant: str) -> Optional[str]:
"""Find the GGUF file in *directory* matching a quantization *variant*.
For sharded GGUFs (multiple files with the same quant label), returns
the first shard (sorted by name) which is what ``llama-server -m`` expects.
Returns the resolved absolute path, or ``None`` if no match.
"""
p = _resolve_gguf_dir(Path(directory))
if p is None:
return None
matches = sorted(
f
for f in _iter_gguf_files(p)
if not _is_mmproj(f.name) and _extract_quant_label(f.name) == variant
)
if matches:
return str(matches[0].resolve())
return None
def detect_gguf_model_remote(
repo_id: str,
hf_token: Optional[str] = None,
) -> Optional[str]:
"""
Check if a HuggingFace repo contains GGUF files.
Returns the filename of the best GGUF file in the repo, or None.
"""
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(repo_id, token = hf_token)
repo_files = [s.rfilename for s in info.siblings]
return _pick_best_gguf(repo_files)
except Exception as e:
logger.debug(f"Could not check GGUF files for '{repo_id}': {e}")
return None
def download_gguf_file(
repo_id: str,
filename: str,
hf_token: Optional[str] = None,
) -> str:
"""
Download a specific GGUF file from a HuggingFace repo.
Returns the local path to the downloaded file.
"""
from huggingface_hub import hf_hub_download
local_path = hf_hub_download(
repo_id = repo_id,
filename = filename,
token = hf_token,
)
return local_path
# Cache embedding detection results per session to avoid repeated HF API calls
_embedding_detection_cache: Dict[tuple, bool] = {}
def is_embedding_model(model_name: str, hf_token: Optional[str] = None) -> bool:
"""
Detect embedding/sentence-transformer models using HuggingFace model metadata.
Uses a belt-and-suspenders approach combining three signals:
1. "sentence-transformers" in model tags
2. "feature-extraction" in model tags
3. pipeline_tag is "sentence-similarity" or "feature-extraction"
This catches all known embedding models including those like gte-modernbert
whose library_name is "transformers" rather than "sentence-transformers".
Args:
model_name: Model identifier (HF repo or local path)
hf_token: Optional HF token for accessing gated/private models
Returns:
True if the model is an embedding model, False otherwise.
Defaults to False for local paths or on errors.
"""
cache_key = (model_name, hf_token)
if cache_key in _embedding_detection_cache:
return _embedding_detection_cache[cache_key]
# Local paths: check for sentence-transformer marker file (modules.json)
if is_local_path(model_name):
local_dir = normalize_path(model_name)
is_emb = os.path.isfile(os.path.join(local_dir, "modules.json"))
_embedding_detection_cache[cache_key] = is_emb
return is_emb
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(model_name, token = hf_token)
tags = set(info.tags or [])
pipeline_tag = info.pipeline_tag or ""
is_emb = (
"sentence-transformers" in tags
or "feature-extraction" in tags
or pipeline_tag in ("sentence-similarity", "feature-extraction")
)
_embedding_detection_cache[cache_key] = is_emb
if is_emb:
logger.info(
f"Model {model_name} detected as embedding model: "
f"pipeline_tag={pipeline_tag}, "
f"sentence-transformers in tags={('sentence-transformers' in tags)}, "
f"feature-extraction in tags={('feature-extraction' in tags)}"
)
return is_emb
except Exception as e:
logger.warning(f"Could not determine if {model_name} is embedding model: {e}")
_embedding_detection_cache[cache_key] = False
return False
def _has_model_weight_files(model_dir: Path) -> bool:
"""Return True when a directory contains loadable model weights."""
for item in model_dir.iterdir():
if not item.is_file():
continue
suffix = item.suffix.lower()
if suffix == ".safetensors":
return True
if suffix == ".gguf":
return "mmproj" not in item.name.lower()
if suffix == ".bin":
name = item.name.lower()
if (
name.startswith("pytorch_model")
or name.startswith("model")
or name.startswith("adapter_model")
or name.startswith("consolidated")
):
return True
return False
def _detect_training_output_type(model_dir: Path) -> Optional[str]:
"""Classify a Studio training output as LoRA or full finetune."""
adapter_config = model_dir / "adapter_config.json"
adapter_model = model_dir / "adapter_model.safetensors"
if adapter_config.exists() or adapter_model.exists():
return "lora"
config_file = model_dir / "config.json"
if config_file.exists() and _has_model_weight_files(model_dir):
return "merged"
return None
def _looks_like_lora_adapter(model_dir: Path) -> bool:
return model_dir.is_dir() and (
(model_dir / "adapter_config.json").exists()
or any(model_dir.glob("adapter_model*.safetensors"))
or any(model_dir.glob("adapter_model*.bin"))
)
def scan_trained_models(
outputs_dir: str = str(outputs_root()),
) -> List[Tuple[str, str, str]]:
"""
Scan outputs folder for trained Studio models.
Returns:
List of tuples: [(display_name, model_path, model_type), ...]
model_type is "lora" for adapter runs and "merged" for full finetunes.
"""
trained_models = []
outputs_path = resolve_output_dir(outputs_dir)
if not outputs_path.exists():
logger.warning(f"Outputs directory not found: {outputs_dir}")
return trained_models
try:
for item in outputs_path.iterdir():
if item.is_dir():
model_type = _detect_training_output_type(item)
if model_type is None:
continue
display_name = item.name
model_path = str(item)
trained_models.append((display_name, model_path, model_type))
logger.debug("Found trained model: %s (%s)", display_name, model_type)
# Sort by modification time (newest first)
trained_models.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)
logger.info(
"Found %s trained models in %s",
len(trained_models),
outputs_dir,
)
return trained_models
except Exception as e:
logger.error(f"Error scanning outputs folder: {e}")
return []
def scan_exported_models(
exports_dir: str = str(exports_root()),
) -> List[Tuple[str, str, str, Optional[str]]]:
"""
Scan exports folder for exported models (merged, LoRA, GGUF).
Supports two directory layouts:
- Two-level: {run}/{checkpoint}/ (merged & LoRA exports)
- Flat: {name}-finetune-gguf/ (GGUF exports)
Returns:
List of tuples: [(display_name, model_path, export_type, base_model), ...]
export_type: "lora" | "merged" | "gguf"
"""
results = []
exports_path = resolve_export_dir(exports_dir)
if not exports_path.exists():
return results
try:
for run_dir in exports_path.iterdir():
if not run_dir.is_dir():
continue
# Check for flat GGUF export (e.g. exports/gemma-3-4b-it-finetune-gguf/)
# Filter out mmproj (vision projection) files — they aren't loadable as main models
gguf_files = [
f for f in _iter_gguf_files(run_dir) if not _is_mmproj(f.name)
]
if gguf_files:
base_model = None
export_meta = run_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
except Exception:
pass
display_name = run_dir.name
model_path = str(gguf_files[0]) # path to the .gguf file
results.append((display_name, model_path, "gguf", base_model))
logger.debug(f"Found GGUF export: {display_name}")
continue
# Two-level: {run}/{checkpoint}/
for checkpoint_dir in run_dir.iterdir():
if not checkpoint_dir.is_dir():
continue
adapter_config = checkpoint_dir / "adapter_config.json"
config_file = checkpoint_dir / "config.json"
has_weights = any(checkpoint_dir.glob("*.safetensors")) or any(
checkpoint_dir.glob("*.bin")
)
has_gguf = any(_iter_gguf_files(checkpoint_dir))
base_model = None
export_type = None
if adapter_config.exists():
export_type = "lora"
try:
cfg = json.loads(adapter_config.read_text())
base_model = cfg.get("base_model_name_or_path")
except Exception:
pass
elif config_file.exists() and has_weights:
export_type = "merged"
export_meta = checkpoint_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
except Exception:
pass
elif has_gguf:
export_type = "gguf"
gguf_list = list(_iter_gguf_files(checkpoint_dir))
# Check checkpoint_dir first, then fall back to parent run_dir
# (export.py writes metadata to the top-level export directory)
for meta_dir in (checkpoint_dir, run_dir):
export_meta = meta_dir / "export_metadata.json"
try:
if export_meta.exists():
meta = json.loads(export_meta.read_text())
base_model = meta.get("base_model")
if base_model:
break
except Exception:
pass
display_name = f"{run_dir.name} / {checkpoint_dir.name}"
model_path = str(gguf_list[0]) if gguf_list else str(checkpoint_dir)
results.append((display_name, model_path, export_type, base_model))
logger.debug(f"Found GGUF export: {display_name}")
continue
else:
continue
# Fallback: read base model from the original training run's
# adapter_config.json in ./outputs/{run_name}/
if not base_model:
outputs_adapter_cfg = (
resolve_output_dir(run_dir.name) / "adapter_config.json"
)
try:
if outputs_adapter_cfg.exists():
cfg = json.loads(outputs_adapter_cfg.read_text())
base_model = cfg.get("base_model_name_or_path")
except Exception:
pass
display_name = f"{run_dir.name} / {checkpoint_dir.name}"
model_path = str(checkpoint_dir)
results.append((display_name, model_path, export_type, base_model))
logger.debug(f"Found exported model: {display_name} ({export_type})")
results.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True)
logger.info(f"Found {len(results)} exported models in {exports_dir}")
return results
except Exception as e:
logger.error(f"Error scanning exports folder: {e}")
return []
def get_base_model_from_checkpoint(checkpoint_path: str) -> Optional[str]:
"""Read the base model name from a local training or checkpoint directory."""
try:
checkpoint_path_obj = Path(checkpoint_path)
adapter_config_path = checkpoint_path_obj / "adapter_config.json"
if adapter_config_path.exists():
with open(adapter_config_path, "r") as f:
config = json.load(f)
base_model = config.get("base_model_name_or_path")
if base_model:
logger.info(
"Detected base model from adapter_config.json: %s", base_model
)
return base_model
config_path = checkpoint_path_obj / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
for key in ("model_name", "_name_or_path"):
base_model = config.get(key)
if base_model and str(base_model) != str(checkpoint_path_obj):
logger.info(
"Detected base model from config.json (%s): %s",
key,
base_model,
)
return base_model
training_args_path = checkpoint_path_obj / "training_args.bin"
if training_args_path.exists():
try:
import torch
training_args = torch.load(training_args_path)
if hasattr(training_args, "model_name_or_path"):
base_model = training_args.model_name_or_path
logger.info(
"Detected base model from training_args.bin: %s", base_model
)
return base_model
except Exception as e:
logger.warning(f"Could not load training_args.bin: {e}")
dir_name = checkpoint_path_obj.name
if dir_name.startswith("unsloth_"):
parts = dir_name.split("_")
if len(parts) >= 2:
model_parts = parts[1:-1]
base_model = "unsloth/" + "_".join(model_parts)
logger.info("Detected base model from directory name: %s", base_model)
return base_model
logger.warning(f"Could not detect base model for checkpoint: {checkpoint_path}")
return None
except Exception as e:
logger.error(f"Error reading base model from checkpoint config: {e}")
return None
def get_base_model_from_lora(lora_path: str) -> Optional[str]:
"""
Read the base model name from a LoRA adapter's config.
Args:
lora_path: Path to the LoRA adapter directory
Returns:
Base model identifier or None if not found
"""
try:
lora_path_obj = Path(lora_path)
if not _looks_like_lora_adapter(lora_path_obj):
return None
# Try adapter_config.json first
adapter_config_path = lora_path_obj / "adapter_config.json"
if adapter_config_path.exists():
with open(adapter_config_path, "r") as f:
config = json.load(f)
base_model = config.get("base_model_name_or_path")
if base_model:
logger.info(
f"Detected base model from adapter_config.json: {base_model}"
)
return base_model
# Fallback: try training_args.bin (requires torch)
training_args_path = lora_path_obj / "training_args.bin"
if training_args_path.exists():
try:
import torch
training_args = torch.load(training_args_path)
if hasattr(training_args, "model_name_or_path"):
base_model = training_args.model_name_or_path
logger.info(
f"Detected base model from training_args.bin: {base_model}"
)
return base_model
except Exception as e:
logger.warning(f"Could not load training_args.bin: {e}")
# Last resort: parse from directory name
# Format: unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit_timestamp
dir_name = lora_path_obj.name
if dir_name.startswith("unsloth_"):
# Remove timestamp suffix (usually _1234567890)
parts = dir_name.split("_")
# Reconstruct model name
if len(parts) >= 2:
model_parts = parts[1:-1] # Skip "unsloth" and timestamp
base_model = "unsloth/" + "_".join(model_parts)
logger.info(f"Detected base model from directory name: {base_model}")
return base_model
logger.warning(f"Could not detect base model for LoRA: {lora_path}")
return None
except Exception as e:
logger.error(f"Error reading base model from LoRA config: {e}")
return None
# Status indicators that appear in UI dropdowns
UI_STATUS_INDICATORS = [" (Ready)", " (Loading...)", " (Active)", ""]
def load_model_defaults(model_name: str) -> Dict[str, Any]:
"""
Load default training parameters for a model from YAML file.
Args:
model_name: Model identifier (e.g., "unsloth/Meta-Llama-3.1-8B-bnb-4bit")
Returns:
Dictionary with default parameters from YAML file, or empty dict if not found
The function looks for a YAML file in configs/model_defaults/ (including subfolders)
based on the model name or its aliases from MODEL_NAME_MAPPING.
If no specific file exists, it falls back to default.yaml.
"""
try:
# Get the script directory to locate configs
script_dir = Path(__file__).parent.parent.parent
defaults_dir = script_dir / "assets" / "configs" / "model_defaults"
# First, check if model is in the mapping
if model_name.lower() in _REVERSE_MODEL_MAPPING:
canonical_file = _REVERSE_MODEL_MAPPING[model_name.lower()]
# Search in subfolders and root
for config_path in defaults_dir.rglob(canonical_file):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(
f"Loaded model defaults from {config_path} (via mapping)"
)
return config
# If model_name is a local path (e.g. /home/.../Spark-TTS-0.5B/LLM from
# adapter_config.json, or C:\Users\...\model on Windows), try matching
# the last 1-2 path components against the registry
# (e.g. "Spark-TTS-0.5B/LLM").
_is_local_path = is_local_path(model_name)
# Normalize Windows backslash paths so Path().parts splits correctly
# on POSIX/WSL hosts (pathlib treats backslashes as literals on Linux).
_normalized = normalize_path(model_name) if _is_local_path else model_name
if model_name.lower() not in _REVERSE_MODEL_MAPPING and _is_local_path:
parts = Path(_normalized).parts
for depth in [2, 1]:
if len(parts) >= depth:
suffix = "/".join(parts[-depth:])
if suffix.lower() in _REVERSE_MODEL_MAPPING:
canonical_file = _REVERSE_MODEL_MAPPING[suffix.lower()]
for config_path in defaults_dir.rglob(canonical_file):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(
f"Loaded model defaults from {config_path} (via path suffix '{suffix}')"
)
return config
# Try exact model name match (for backward compatibility).
# For local filesystem paths, use only the directory basename to
# avoid passing absolute paths (e.g. C:\...) into rglob which
# raises "Non-relative patterns are unsupported" on Windows.
_lookup_name = Path(_normalized).name if _is_local_path else model_name
model_filename = _lookup_name.replace("/", "_") + ".yaml"
# Search in subfolders and root
for config_path in defaults_dir.rglob(model_filename):
if config_path.is_file():
with open(config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(f"Loaded model defaults from {config_path}")
return config
# Fall back to default.yaml
default_config_path = defaults_dir / "default.yaml"
if default_config_path.exists():
with open(default_config_path, "r", encoding = "utf-8") as f:
config = yaml.safe_load(f) or {}
logger.info(f"Loaded default model defaults from {default_config_path}")
return config
logger.warning(f"No default config found for model {model_name}")
return {}
except Exception as e:
logger.error(f"Error loading model defaults for {model_name}: {e}")
return {}
@dataclass
class ModelConfig:
"""Configuration for a model to load"""
identifier: str # Clean model identifier (org/name or path)
display_name: str # Original UI display name
path: str # Normalized filesystem path
is_local: bool # Is this a local file vs HF model?
is_cached: bool # Is this already in HF cache?
is_vision: bool # Is this a vision model?
is_lora: bool # Is this a lora adapter?
is_gguf: bool = False # Is this a GGUF model?
is_audio: bool = False # Is this a TTS audio model?
audio_type: Optional[str] = (
None # Audio codec type: 'snac', 'csm', 'bicodec', 'dac'
)
has_audio_input: bool = False # Accepts audio input (ASR/speech understanding)
gguf_file: Optional[str] = None # Full path to the .gguf file (local mode)
gguf_mmproj_file: Optional[str] = (
None # Full path to the mmproj .gguf file (vision projection)
)
gguf_hf_repo: Optional[str] = (
None # HF repo ID for -hf mode (e.g. "unsloth/gemma-3-4b-it-GGUF")
)
gguf_variant: Optional[str] = None # Quantization variant (e.g. "Q4_K_M")
base_model: Optional[str] = None # Base model (for LoRAs)
@classmethod
def from_lora_path(
cls, lora_path: str, hf_token: Optional[str] = None
) -> Optional["ModelConfig"]:
"""
Create ModelConfig from a local LoRA adapter path.
Automatically detects the base model from adapter config.
Args:
lora_path: Path to LoRA adapter (e.g., "./outputs/unsloth_Meta-Llama-3.1_.../")
hf_token: HF token for vision detection
Returns:
ModelConfig for the LoRA adapter
"""
try:
lora_path_obj = Path(lora_path)
if not lora_path_obj.exists():
logger.error(f"LoRA path does not exist: {lora_path}")
return None
# Get base model
base_model = get_base_model_from_lora(lora_path)
if not base_model:
logger.error(f"Could not determine base model for LoRA: {lora_path}")
return None
# Check if base model is vision
is_vision = is_vision_model(base_model, hf_token = hf_token)
# Check if base model is audio
audio_type = detect_audio_type(base_model, hf_token = hf_token)
display_name = lora_path_obj.name
identifier = lora_path # Use path as identifier for local LoRAs
return cls(
identifier = identifier,
display_name = display_name,
path = lora_path,
is_local = True,
is_cached = True, # Local LoRAs are always "cached"
is_vision = is_vision,
is_lora = True,
is_audio = audio_type is not None and audio_type != "audio_vlm",
audio_type = audio_type,
has_audio_input = is_audio_input_type(audio_type),
base_model = base_model,
)
except Exception as e:
logger.error(f"Error creating ModelConfig from LoRA path: {e}")
return None
@classmethod
def from_identifier(
cls,
model_id: str,
hf_token: Optional[str] = None,
is_lora: bool = False,
gguf_variant: Optional[str] = None,
) -> Optional["ModelConfig"]:
"""
Create ModelConfig from a clean model identifier.
For FastAPI routes where the frontend sends sanitized model paths.
No Gradio dropdown parsing - expects clean identifiers like:
- "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
- "./outputs/my_lora_adapter"
- "/absolute/path/to/model"
Args:
model_id: Clean model identifier (HF repo name or local path)
hf_token: Optional HF token for vision detection on gated models
is_lora: Whether this is a LoRA adapter
gguf_variant: Optional GGUF quantization variant (e.g. "Q4_K_M").
For remote GGUF repos, specifies which quant to load via -hf.
If None, auto-selects using _pick_best_gguf().
Returns:
ModelConfig or None if configuration cannot be created
"""
if not model_id or not model_id.strip():
return None
identifier = model_id.strip()
is_local = is_local_path(identifier)
path = normalize_path(identifier) if is_local else identifier
# Add unsloth/ prefix for shorthand HF models
if not is_local and "/" not in identifier:
identifier = f"unsloth/{identifier}"
path = identifier
# Preserve requested casing, but if a case-variant already exists in local HF cache,
# reuse that exact repo_id spelling to avoid one-time re-downloads after #2592.
if not is_local:
resolved_identifier = resolve_cached_repo_id_case(identifier)
if resolved_identifier != identifier:
logger.info(
"Using cached repo_id casing '%s' for requested '%s'",
resolved_identifier,
identifier,
)
identifier = resolved_identifier
path = resolved_identifier
# Auto-detect GGUF models (check before LoRA/vision detection)
if is_local:
if gguf_variant:
gguf_file = _find_local_gguf_by_variant(path, gguf_variant)
else:
gguf_file = detect_gguf_model(path)
if gguf_file:
display_name = Path(gguf_file).stem
logger.info(f"Detected local GGUF model: {gguf_file}")
# Detect vision: check if base model is vision, then look for mmproj
mmproj_file = None
gguf_is_vision = False
gguf_dir = Path(gguf_file).parent
# Determine if this is a vision model from export metadata
base_is_vision = False
meta_path = gguf_dir / "export_metadata.json"
if meta_path.exists():
try:
meta = json.loads(meta_path.read_text())
base = meta.get("base_model")
if base and is_vision_model(base, hf_token = hf_token):
base_is_vision = True
logger.info(f"GGUF base model '{base}' is a vision model")
except Exception as e:
logger.debug(f"Could not read export metadata: {e}")
# If vision (or mmproj happens to exist), find the mmproj file
mmproj_file = detect_mmproj_file(gguf_file)
if mmproj_file:
gguf_is_vision = True
logger.info(f"Detected mmproj for vision: {mmproj_file}")
elif base_is_vision:
logger.warning(
f"Base model is vision but no mmproj file found in {gguf_dir}"
)
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = True,
is_cached = True,
is_vision = gguf_is_vision,
is_lora = False,
is_gguf = True,
gguf_file = gguf_file,
gguf_mmproj_file = mmproj_file,
)
else:
# Check if the HF repo contains GGUF files
gguf_filename = detect_gguf_model_remote(identifier, hf_token = hf_token)
if gguf_filename:
# Preflight: verify llama-server binary exists BEFORE user waits
# for a multi-GB download that llama-server handles natively
from core.inference.llama_cpp import LlamaCppBackend
if not LlamaCppBackend._find_llama_server_binary():
raise RuntimeError(
"llama-server binary not found — cannot load GGUF models. "
"Run setup.sh to build it, or set LLAMA_SERVER_PATH."
)
# Use list_gguf_variants() to detect vision & resolve variant
variants, has_vision = list_gguf_variants(identifier, hf_token = hf_token)
variant = gguf_variant
if not variant:
# Auto-select best quantization
variant_filenames = [v.filename for v in variants]
best = _pick_best_gguf(variant_filenames)
if best:
variant = _extract_quant_label(best)
else:
variant = "Q4_K_M" # Fallback — llama-server's own default
display_name = f"{identifier.split('/')[-1]} ({variant})"
logger.info(
f"Detected remote GGUF repo '{identifier}', "
f"variant={variant}, vision={has_vision}"
)
return cls(
identifier = identifier,
display_name = display_name,
path = identifier,
is_local = False,
is_cached = False,
is_vision = has_vision,
is_lora = False,
is_gguf = True,
gguf_file = None,
gguf_hf_repo = identifier,
gguf_variant = variant,
)
# Auto-detect LoRA for local paths (check adapter_config.json on disk)
if not is_lora and is_local:
detected_base = (
get_base_model_from_lora(path)
if _looks_like_lora_adapter(Path(path))
else None
)
if detected_base:
is_lora = True
logger.info(
f"Auto-detected local LoRA adapter at '{path}' (base: {detected_base})"
)
# Auto-detect LoRA for remote HF models (check repo file listing)
if not is_lora and not is_local:
try:
from huggingface_hub import model_info as hf_model_info
info = hf_model_info(identifier, token = hf_token)
repo_files = [s.rfilename for s in info.siblings]
if "adapter_config.json" in repo_files:
is_lora = True
logger.info(f"Auto-detected remote LoRA adapter: '{identifier}'")
except Exception as e:
logger.debug(
f"Could not check remote LoRA status for '{identifier}': {e}"
)
# Handle LoRA adapters
base_model = None
if is_lora:
if is_local:
# Local LoRA: read adapter_config.json from disk
base_model = get_base_model_from_lora(path)
else:
# Remote LoRA: download adapter_config.json from HF
try:
from huggingface_hub import hf_hub_download
config_path = hf_hub_download(
identifier, "adapter_config.json", token = hf_token
)
with open(config_path, "r") as f:
adapter_config = json.load(f)
base_model = adapter_config.get("base_model_name_or_path")
if base_model:
logger.info(f"Resolved remote LoRA base model: '{base_model}'")
except Exception as e:
logger.warning(
f"Could not download adapter_config.json for '{identifier}': {e}"
)
if not base_model:
logger.warning(f"Could not determine base model for LoRA '{path}'")
return None
check_model = base_model
else:
check_model = identifier
vision = is_vision_model(check_model, hf_token = hf_token)
audio_type_val = detect_audio_type(check_model, hf_token = hf_token)
has_audio_in = is_audio_input_type(audio_type_val)
display_name = Path(path).name if is_local else identifier.split("/")[-1]
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = is_local,
is_cached = is_model_cached(identifier) if not is_local else True,
is_vision = vision,
is_lora = is_lora,
is_audio = audio_type_val is not None and audio_type_val != "audio_vlm",
audio_type = audio_type_val,
has_audio_input = has_audio_in,
base_model = base_model,
)
@classmethod
def from_ui_selection(
cls,
dropdown_value: Optional[str],
search_value: Optional[str],
local_models: list = None,
hf_token: Optional[str] = None,
is_lora: bool = False,
) -> Optional["ModelConfig"]:
"""
Create a universal ModelConfig from UI dropdown/search selections.
Handles base models and LoRA adapters.
"""
selected = None
if search_value and search_value.strip():
selected = search_value.strip()
elif dropdown_value:
selected = dropdown_value
if not selected:
return None
display_name = selected
# Use the correct 'local_models' parameter to resolve display names
if " (Active)" in selected or " (Ready)" in selected:
clean_display_name = selected.replace(" (Active)", "").replace(
" (Ready)", ""
)
if local_models:
for local_display, local_path in local_models:
if local_display == clean_display_name:
selected = local_path
break
# Clean all UI status indicators to get the final identifier
identifier = selected
for status in UI_STATUS_INDICATORS:
identifier = identifier.replace(status, "")
identifier = identifier.strip()
is_local = is_local_path(identifier)
path = normalize_path(identifier) if is_local else identifier
# Add unsloth/ prefix for shorthand HF models
if not is_local and "/" not in identifier:
identifier = f"unsloth/{identifier}"
path = identifier
if not is_local:
resolved_identifier = resolve_cached_repo_id_case(identifier)
if resolved_identifier != identifier:
identifier = resolved_identifier
path = resolved_identifier
# --- Logic for Base Model and Vision Detection ---
base_model = None
is_vision = False
if is_lora:
# For a LoRA, we MUST find its base model.
base_model = get_base_model_from_lora(path)
if not base_model:
logger.warning(
f"Could not determine base model for LoRA '{path}'. Cannot create config."
)
return None # Cannot proceed without a base model
# A LoRA's vision capability is determined by its base model.
is_vision = is_vision_model(base_model, hf_token = hf_token)
else:
# For a base model, just check its own vision status.
is_vision = is_vision_model(identifier, hf_token = hf_token)
from utils.paths import is_model_cached
is_cached = is_model_cached(identifier) if not is_local else True
return cls(
identifier = identifier,
display_name = display_name,
path = path,
is_local = is_local,
is_cached = is_cached,
is_vision = is_vision,
is_lora = is_lora,
base_model = base_model, # This will be None for base models, and populated for LoRAs
)