feat: support GGUF export for non-PEFT models + fix venv_t5 switching for local checkpoints (#4455)

* feat: support full model GGUF export, disable incompatible methods in UI

* fix: resolve base model from config.json for venv_t5 export switching

* feat: detect BNB-quantized models and disable all export methods for quantized non-PEFT checkpoints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: relocate Ollama Modelfile alongside GGUFs during non-PEFT export cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Roland Tannous 2026-03-20 12:13:18 +04:00 committed by GitHub
parent be901ecdea
commit ebe45981dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 362 additions and 14 deletions

View file

@ -569,6 +569,27 @@ class ExportBackend:
shutil.rmtree(str(sub), ignore_errors = True)
logger.info(f"Cleaned up subdirectory: {sub.name}")
# For non-PEFT models, save_pretrained_gguf redirects to the
# checkpoint path, leaving a *_gguf directory in outputs/.
# Relocate any GGUFs from there and clean it up.
if self.current_checkpoint:
ckpt = Path(self.current_checkpoint)
gguf_dir = ckpt.parent / f"{ckpt.name}_gguf"
if gguf_dir.is_dir():
for src in gguf_dir.glob("*.gguf"):
dest = os.path.join(abs_save_dir, src.name)
shutil.move(str(src), dest)
logger.info(f"Relocated GGUF: {src.name}{abs_save_dir}/")
# Also relocate Ollama Modelfile if present
modelfile = gguf_dir / "Modelfile"
if modelfile.is_file():
shutil.move(
str(modelfile), os.path.join(abs_save_dir, "Modelfile")
)
logger.info(f"Relocated Modelfile → {abs_save_dir}/")
shutil.rmtree(str(gguf_dir), ignore_errors = True)
logger.info(f"Cleaned up intermediate GGUF dir: {gguf_dir}")
# Write export metadata so the Chat page can identify the base model
self._write_export_metadata(abs_save_dir)

View file

@ -41,6 +41,10 @@ class ModelCheckpoints(BaseModel):
None,
description = "LoRA rank (r) if applicable",
)
is_quantized: bool = Field(
False,
description = "Whether the model uses BNB quantization (e.g. bnb-4bit)",
)
class CheckpointListResponse(BaseModel):

View file

@ -1092,6 +1092,7 @@ async def list_checkpoints(
base_model = metadata.get("base_model"),
peft_type = metadata.get("peft_type"),
lora_rank = metadata.get("lora_rank"),
is_quantized = metadata.get("is_quantized", False),
)
for model_name, checkpoints, metadata in raw_models
]

View file

@ -0,0 +1,190 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for transformers version detection with local checkpoint fallbacks."""
import json
import pytest
from pathlib import Path
from unittest.mock import patch
# ---------------------------------------------------------------------------
# We need to be able to import the module under test. The studio backend
# uses relative-style imports (``from utils.…``), so we add the backend
# directory to *sys.path* if it is not already there.
# ---------------------------------------------------------------------------
import sys
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# Stub the custom logger before importing the module under test so it
# doesn't fail on the ``from loggers import get_logger`` line.
import types as _types
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
from utils.transformers_version import (
_resolve_base_model,
_check_tokenizer_config_needs_v5,
_tokenizer_class_cache,
needs_transformers_5,
)
# ---------------------------------------------------------------------------
# _resolve_base_model — config.json fallback
# ---------------------------------------------------------------------------
class TestResolveBaseModel:
    """Exercise _resolve_base_model()'s local config-file fallbacks.

    Covers the priority order adapter_config.json > config.json
    (model_name > _name_or_path), the self-reference guard, and the
    pass-through behaviour for plain HF IDs / empty directories.
    """

    @staticmethod
    def _dump(directory: Path, filename: str, payload: dict) -> None:
        # Write *payload* as JSON where the resolver expects to find it.
        (directory / filename).write_text(json.dumps(payload))

    def test_adapter_config_takes_priority(self, tmp_path: Path):
        """adapter_config.json wins when config.json is also present."""
        self._dump(
            tmp_path,
            "adapter_config.json",
            {"base_model_name_or_path": "meta-llama/Llama-3-8B"},
        )
        self._dump(tmp_path, "config.json", {"_name_or_path": "different/model"})
        assert _resolve_base_model(str(tmp_path)) == "meta-llama/Llama-3-8B"

    def test_config_json_fallback_model_name(self, tmp_path: Path):
        """Without adapter_config.json, config.json's model_name resolves."""
        self._dump(tmp_path, "config.json", {"model_name": "Qwen/Qwen3.5-9B"})
        assert _resolve_base_model(str(tmp_path)) == "Qwen/Qwen3.5-9B"

    def test_config_json_fallback_name_or_path(self, tmp_path: Path):
        """_name_or_path acts as the secondary key inside config.json."""
        self._dump(tmp_path, "config.json", {"_name_or_path": "Qwen/Qwen3.5-9B"})
        assert _resolve_base_model(str(tmp_path)) == "Qwen/Qwen3.5-9B"

    def test_model_name_takes_priority_over_name_or_path(self, tmp_path: Path):
        """When both keys exist, model_name beats _name_or_path."""
        self._dump(
            tmp_path,
            "config.json",
            {"model_name": "Qwen/Qwen3.5-9B", "_name_or_path": "some/other-model"},
        )
        assert _resolve_base_model(str(tmp_path)) == "Qwen/Qwen3.5-9B"

    def test_config_json_skips_self_referencing(self, tmp_path: Path):
        """A config whose model_name equals the checkpoint path is ignored."""
        self._dump(tmp_path, "config.json", {"model_name": str(tmp_path)})
        # The resolver must fall through rather than echo the self-reference.
        assert _resolve_base_model(str(tmp_path)) == str(tmp_path)

    def test_no_config_files(self, tmp_path: Path):
        """An empty directory resolves to its own path unchanged."""
        assert _resolve_base_model(str(tmp_path)) == str(tmp_path)

    def test_plain_hf_id_passthrough(self):
        """Non-local HuggingFace IDs are returned verbatim."""
        assert _resolve_base_model("meta-llama/Llama-3-8B") == "meta-llama/Llama-3-8B"
# ---------------------------------------------------------------------------
# _check_tokenizer_config_needs_v5 — local file check
# ---------------------------------------------------------------------------
class TestCheckTokenizerConfigNeedsV5:
    """Verify the local tokenizer_config.json short-circuit of the v5 check."""

    def setup_method(self):
        # Each test starts with a cold cache so the file path is exercised.
        _tokenizer_class_cache.clear()

    @staticmethod
    def _write_tokenizer_config(directory: Path, tokenizer_class: str) -> None:
        # Drop a minimal tokenizer_config.json into *directory*.
        payload = json.dumps({"tokenizer_class": tokenizer_class})
        (directory / "tokenizer_config.json").write_text(payload)

    def test_local_tokenizer_config_v5(self, tmp_path: Path):
        """A v5-only tokenizer class in the local config yields True."""
        self._write_tokenizer_config(tmp_path, "TokenizersBackend")
        assert _check_tokenizer_config_needs_v5(str(tmp_path)) is True

    def test_local_tokenizer_config_v4(self, tmp_path: Path):
        """A standard tokenizer class in the local config yields False."""
        self._write_tokenizer_config(tmp_path, "LlamaTokenizerFast")
        assert _check_tokenizer_config_needs_v5(str(tmp_path)) is False

    def test_local_file_skips_network(self, tmp_path: Path):
        """A readable local config must prevent any HF network fetch."""
        self._write_tokenizer_config(tmp_path, "LlamaTokenizerFast")
        with patch("urllib.request.urlopen") as mock_urlopen:
            outcome = _check_tokenizer_config_needs_v5(str(tmp_path))
        mock_urlopen.assert_not_called()
        assert outcome is False

    def test_result_is_cached(self, tmp_path: Path):
        """The computed answer is memoized under the checkpoint path."""
        self._write_tokenizer_config(tmp_path, "TokenizersBackend")
        cache_key = str(tmp_path)
        _check_tokenizer_config_needs_v5(cache_key)
        assert cache_key in _tokenizer_class_cache
        assert _tokenizer_class_cache[cache_key] is True
# ---------------------------------------------------------------------------
# needs_transformers_5 — integration-level
# ---------------------------------------------------------------------------
class TestNeedsTransformers5:
    """End-to-end checks for needs_transformers_5() and its name matching."""

    def setup_method(self):
        # Reset memoization so substring tests never see stale entries.
        _tokenizer_class_cache.clear()

    def test_qwen35_substring(self):
        """Qwen3.5 family IDs trigger the v5 requirement by name alone."""
        assert needs_transformers_5("Qwen/Qwen3.5-9B") is True

    def test_qwen3_30b_a3b_substring(self):
        """Qwen3-30B-A3B variants also match the v5 name list."""
        assert needs_transformers_5("Qwen/Qwen3-30B-A3B-Instruct-2507") is True

    def test_ministral_substring(self):
        """Ministral-3 IDs match the v5 name list."""
        assert needs_transformers_5("mistralai/Ministral-3-8B-Instruct-2512") is True

    def test_llama_does_not_need_v5(self):
        """A standard Llama ID resolves to False (network check stubbed out)."""
        stub = patch(
            "utils.transformers_version._check_tokenizer_config_needs_v5",
            return_value = False,
        )
        with stub:
            assert needs_transformers_5("meta-llama/Llama-3-8B") is False

    def test_local_checkpoint_resolved_via_config(self, tmp_path: Path):
        """Resolving a local checkpoint to a Qwen3.5 base implies v5."""
        (tmp_path / "config.json").write_text(
            json.dumps({"model_name": "Qwen/Qwen3.5-9B"})
        )
        # ensure_transformers_version chains these two calls in production;
        # replicate that chain here: resolve first, then substring-match.
        base = _resolve_base_model(str(tmp_path))
        assert needs_transformers_5(base) is True

View file

@ -76,6 +76,18 @@ def scan_checkpoints(
elif config_file.exists():
cfg = json.loads(config_file.read_text())
metadata["base_model"] = cfg.get("_name_or_path")
# Detect BNB quantization from config.json (present in both cases)
if config_file.exists():
if "cfg" not in dir():
cfg = json.loads(config_file.read_text())
quant_cfg = cfg.get("quantization_config")
if (
isinstance(quant_cfg, dict)
and quant_cfg.get("quant_method") == "bitsandbytes"
):
metadata["is_quantized"] = True
logger.info("Detected BNB-quantized model: %s", item.name)
except Exception:
pass

View file

@ -92,6 +92,24 @@ def _resolve_base_model(model_name: str) -> str:
except Exception as exc:
logger.debug("Could not read %s: %s", adapter_cfg_path, exc)
# --- config.json fallback (works for both LoRA and full fine-tune) ------
config_json_path = local_path / "config.json"
if config_json_path.is_file():
try:
with open(config_json_path) as f:
cfg = json.load(f)
# Unsloth writes "model_name"; HF writes "_name_or_path"
base = cfg.get("model_name") or cfg.get("_name_or_path")
if base and base != str(local_path):
logger.info(
"Resolved checkpoint '%s' → base model '%s' (via config.json)",
model_name,
base,
)
return base
except Exception as exc:
logger.debug("Could not read %s: %s", config_json_path, exc)
# --- Only try the heavier fallback for local directories ----------------
if local_path.is_dir():
try:
@ -126,6 +144,27 @@ def _check_tokenizer_config_needs_v5(model_name: str) -> bool:
if model_name in _tokenizer_class_cache:
return _tokenizer_class_cache[model_name]
# --- Check local tokenizer_config.json first ---------------------------
local_path = Path(model_name)
local_tc = local_path / "tokenizer_config.json"
if local_tc.is_file():
try:
with open(local_tc) as f:
data = json.load(f)
tokenizer_class = data.get("tokenizer_class", "")
result = tokenizer_class in _TRANSFORMERS_5_TOKENIZER_CLASSES
if result:
logger.info(
"Local check: %s uses tokenizer_class=%s (requires transformers 5.x)",
model_name,
tokenizer_class,
)
_tokenizer_class_cache[model_name] = result
return result
except Exception as exc:
logger.debug("Could not read %s: %s", local_tc, exc)
# --- Fall back to fetching from HuggingFace ----------------------------
import urllib.request
url = f"https://huggingface.co/{model_name}/raw/main/tokenizer_config.json"

View file

@ -31,6 +31,7 @@ export interface ModelCheckpoints {
base_model?: string | null;
peft_type?: string | null;
lora_rank?: number | null;
is_quantized?: boolean;
}
export interface CheckpointListResponse {

View file

@ -18,9 +18,13 @@ import { EXPORT_METHODS, type ExportMethod } from "../constants";
interface MethodPickerProps {
value: ExportMethod | null;
onChange: (v: ExportMethod) => void;
/** Methods that should be shown but disabled (greyed out, not clickable). */
disabledMethods?: ExportMethod[];
/** Optional reason shown in a tooltip on disabled methods. */
disabledReason?: string;
}
export function MethodPicker({ value, onChange }: MethodPickerProps) {
export function MethodPicker({ value, onChange, disabledMethods = [], disabledReason }: MethodPickerProps) {
return (
<div data-tour="export-method" className="flex flex-col gap-3">
<span className="flex items-center gap-1.5 text-xs font-medium text-muted-foreground">
@ -50,16 +54,21 @@ export function MethodPicker({ value, onChange }: MethodPickerProps) {
<div className="grid grid-cols-3 gap-3">
{EXPORT_METHODS.map((m) => {
const selected = value === m.value;
return (
const isDisabled = disabledMethods.includes(m.value);
const card = (
<button
key={m.value}
type="button"
onClick={() => onChange(m.value)}
disabled={isDisabled}
onClick={() => !isDisabled && onChange(m.value)}
className={cn(
"flex items-start gap-3 rounded-xl p-4 text-left ring-1 transition-all",
selected
? "ring-2 ring-primary bg-primary/5"
: "ring-border hover:-translate-y-0.5 hover:shadow-sm",
isDisabled
? "ring-border opacity-40 cursor-not-allowed"
: selected
? "ring-2 ring-primary bg-primary/5"
: "ring-border hover:-translate-y-0.5 hover:shadow-sm",
)}
>
<div
@ -125,6 +134,17 @@ export function MethodPicker({ value, onChange }: MethodPickerProps) {
</div>
</button>
);
if (isDisabled && disabledReason) {
return (
<Tooltip key={m.value}>
<TooltipTrigger asChild={true}>{card}</TooltipTrigger>
<TooltipContent>{disabledReason}</TooltipContent>
</Tooltip>
);
}
return card;
})}
</div>
</div>

View file

@ -122,6 +122,7 @@ export function ExportPage() {
// Derive training info from selected model's API metadata
const baseModelName = selectedModelData?.base_model ?? "—";
const isAdapter = !!selectedModelData?.peft_type;
const isQuantized = !!selectedModelData?.is_quantized;
const loraRank = selectedModelData?.lora_rank ?? null;
const trainingMethodLabel = selectedModelData?.peft_type
? "LoRA / QLoRA"
@ -132,6 +133,17 @@ export function ExportPage() {
setCheckpoint(null);
}, [selectedModelIdx]);
// Auto-reset export method if incompatible with the selected model type
useEffect(() => {
if (!isAdapter && (exportMethod === "merged" || exportMethod === "lora")) {
setExportMethod(null);
}
// Quantized non-PEFT models can't export to any format
if (!isAdapter && isQuantized && exportMethod !== null) {
setExportMethod(null);
}
}, [isAdapter, isQuantized, exportMethod]);
const handleMethodChange = (method: ExportMethod) => {
setExportMethod(method);
if (method !== "gguf") {
@ -465,7 +477,24 @@ export function ExportPage() {
</div>
</div>
<MethodPicker value={exportMethod} onChange={handleMethodChange} />
<MethodPicker
value={exportMethod}
onChange={handleMethodChange}
disabledMethods={
!isAdapter && isQuantized
? ["merged", "lora", "gguf"]
: !isAdapter
? ["merged", "lora"]
: []
}
disabledReason={
!isAdapter && isQuantized
? "Pre-quantized (BNB 4-bit) models cannot be exported without LoRA adapters"
: !isAdapter
? "Not available for full fine-tune checkpoints (no LoRA adapters)"
: undefined
}
/>
<AnimatePresence>
{exportMethod === "gguf" && (

View file

@ -1976,15 +1976,46 @@ def unsloth_save_pretrained_gguf(
fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
# Step 4: Save/merge model to 16-bit format
print(
f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...'
is_peft_model = isinstance(self, PeftModelForCausalLM) or isinstance(
self, PeftModel
)
try:
# Call unsloth_generic_save directly (it's in the same file)
unsloth_generic_save(**arguments)
except Exception as e:
raise RuntimeError(f"Failed to save/merge model: {e}")
if is_peft_model:
print(
f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...'
)
try:
# Call unsloth_generic_save directly (it's in the same file)
unsloth_generic_save(**arguments)
except Exception as e:
raise RuntimeError(f"Failed to save/merge model: {e}")
else:
# Non-PEFT model — checkpoint files already exist on disk.
# Point save_to_gguf at the original checkpoint path instead of
# re-saving to a temporary "model" subdirectory.
original_path = getattr(self.config, "_name_or_path", None)
if original_path and os.path.isdir(original_path):
print(
f"Unsloth: Model is not a PEFT model. Using existing checkpoint at {original_path}"
)
save_directory = original_path
# Persist tokenizer fixes (e.g. BOS token stripping) to disk
# so the GGUF converter picks up the corrected chat template.
if tokenizer is not None:
tokenizer.save_pretrained(save_directory)
else:
# Fallback: save the in-memory model to save_directory
print(
"Unsloth: Model is not a PEFT model. Saving directly without LoRA merge..."
)
os.makedirs(save_directory, exist_ok = True)
try:
self.save_pretrained(save_directory)
if tokenizer is not None:
tokenizer.save_pretrained(save_directory)
except Exception as e:
raise RuntimeError(f"Failed to save model: {e}")
if is_processor:
tokenizer = tokenizer.tokenizer