diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py index a6f67c5ba..966e045b1 100644 --- a/studio/backend/core/export/export.py +++ b/studio/backend/core/export/export.py @@ -569,6 +569,27 @@ class ExportBackend: shutil.rmtree(str(sub), ignore_errors = True) logger.info(f"Cleaned up subdirectory: {sub.name}") + # For non-PEFT models, save_pretrained_gguf redirects to the + # checkpoint path, leaving a *_gguf directory in outputs/. + # Relocate any GGUFs from there and clean it up. + if self.current_checkpoint: + ckpt = Path(self.current_checkpoint) + gguf_dir = ckpt.parent / f"{ckpt.name}_gguf" + if gguf_dir.is_dir(): + for src in gguf_dir.glob("*.gguf"): + dest = os.path.join(abs_save_dir, src.name) + shutil.move(str(src), dest) + logger.info(f"Relocated GGUF: {src.name} → {abs_save_dir}/") + # Also relocate Ollama Modelfile if present + modelfile = gguf_dir / "Modelfile" + if modelfile.is_file(): + shutil.move( + str(modelfile), os.path.join(abs_save_dir, "Modelfile") + ) + logger.info(f"Relocated Modelfile → {abs_save_dir}/") + shutil.rmtree(str(gguf_dir), ignore_errors = True) + logger.info(f"Cleaned up intermediate GGUF dir: {gguf_dir}") + # Write export metadata so the Chat page can identify the base model self._write_export_metadata(abs_save_dir) diff --git a/studio/backend/models/models.py b/studio/backend/models/models.py index f721bbbf7..daa8eec90 100644 --- a/studio/backend/models/models.py +++ b/studio/backend/models/models.py @@ -41,6 +41,10 @@ class ModelCheckpoints(BaseModel): None, description = "LoRA rank (r) if applicable", ) + is_quantized: bool = Field( + False, + description = "Whether the model uses BNB quantization (e.g. bnb-4bit)", + ) class CheckpointListResponse(BaseModel): diff --git a/studio/backend/routes/models.py b/studio/backend/routes/models.py index 60be0559d..e70576244 100644 --- a/studio/backend/routes/models.py +++ b/studio/backend/routes/models.py @@ -1092,6 +1092,7 @@ async def list_checkpoints( base_model = metadata.get("base_model"), peft_type = metadata.get("peft_type"), lora_rank = metadata.get("lora_rank"), + is_quantized = metadata.get("is_quantized", False), ) for model_name, checkpoints, metadata in raw_models ] diff --git a/studio/backend/tests/test_transformers_version.py b/studio/backend/tests/test_transformers_version.py new file mode 100644 index 000000000..f3dae537c --- /dev/null +++ b/studio/backend/tests/test_transformers_version.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +"""Tests for transformers version detection with local checkpoint fallbacks.""" + +import json +import pytest +from pathlib import Path +from unittest.mock import patch + + +# --------------------------------------------------------------------------- +# We need to be able to import the module under test. The studio backend +# uses relative-style imports (``from utils.…``), so we add the backend +# directory to *sys.path* if it is not already there. +# --------------------------------------------------------------------------- +import sys + +_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) +if _BACKEND_DIR not in sys.path: + sys.path.insert(0, _BACKEND_DIR) + +# Stub the custom logger before importing the module under test so it +# doesn't fail on the ``from loggers import get_logger`` line. +import types as _types + +_loggers_stub = _types.ModuleType("loggers") +_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) +sys.modules.setdefault("loggers", _loggers_stub) + +from utils.transformers_version import ( + _resolve_base_model, + _check_tokenizer_config_needs_v5, + _tokenizer_class_cache, + needs_transformers_5, +) + + +# --------------------------------------------------------------------------- +# _resolve_base_model — config.json fallback +# --------------------------------------------------------------------------- + + +class TestResolveBaseModel: + """Tests for _resolve_base_model() local config fallbacks.""" + + def test_adapter_config_takes_priority(self, tmp_path: Path): + """adapter_config.json should be preferred over config.json.""" + adapter_cfg = {"base_model_name_or_path": "meta-llama/Llama-3-8B"} + config_cfg = {"_name_or_path": "different/model"} + (tmp_path / "adapter_config.json").write_text(json.dumps(adapter_cfg)) + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + result = _resolve_base_model(str(tmp_path)) + assert result == "meta-llama/Llama-3-8B" + + def test_config_json_fallback_model_name(self, tmp_path: Path): + """config.json model_name should resolve when no adapter_config.""" + config_cfg = {"model_name": "Qwen/Qwen3.5-9B"} + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + result = _resolve_base_model(str(tmp_path)) + assert result == "Qwen/Qwen3.5-9B" + + def test_config_json_fallback_name_or_path(self, tmp_path: Path): + """config.json _name_or_path should resolve as secondary fallback.""" + config_cfg = {"_name_or_path": "Qwen/Qwen3.5-9B"} + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + result = _resolve_base_model(str(tmp_path)) + assert result == "Qwen/Qwen3.5-9B" + + def test_model_name_takes_priority_over_name_or_path(self, tmp_path: Path): + """model_name should be preferred over _name_or_path.""" + config_cfg = { + "model_name": "Qwen/Qwen3.5-9B", + "_name_or_path": "some/other-model", + } + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + result = _resolve_base_model(str(tmp_path)) + assert result == "Qwen/Qwen3.5-9B" + + def test_config_json_skips_self_referencing(self, tmp_path: Path): + """config.json should be ignored if model_name == the checkpoint path.""" + config_cfg = {"model_name": str(tmp_path)} + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + result = _resolve_base_model(str(tmp_path)) + # Should fall through, not return the self-referencing path + assert result == str(tmp_path) + + def test_no_config_files(self, tmp_path: Path): + """Returns original name when no config files are present.""" + result = _resolve_base_model(str(tmp_path)) + assert result == str(tmp_path) + + def test_plain_hf_id_passthrough(self): + """Plain HuggingFace model IDs pass through unchanged.""" + result = _resolve_base_model("meta-llama/Llama-3-8B") + assert result == "meta-llama/Llama-3-8B" + + +# --------------------------------------------------------------------------- +# _check_tokenizer_config_needs_v5 — local file check +# --------------------------------------------------------------------------- + + +class TestCheckTokenizerConfigNeedsV5: + """Tests for local tokenizer_config.json fallback.""" + + def setup_method(self): + _tokenizer_class_cache.clear() + + def test_local_tokenizer_config_v5(self, tmp_path: Path): + """Local tokenizer_config.json with v5 tokenizer should return True.""" + tc = {"tokenizer_class": "TokenizersBackend"} + (tmp_path / "tokenizer_config.json").write_text(json.dumps(tc)) + + result = _check_tokenizer_config_needs_v5(str(tmp_path)) + assert result is True + + def test_local_tokenizer_config_v4(self, tmp_path: Path): + """Local tokenizer_config.json with standard tokenizer should return False.""" + tc = {"tokenizer_class": "LlamaTokenizerFast"} + (tmp_path / "tokenizer_config.json").write_text(json.dumps(tc)) + + result = _check_tokenizer_config_needs_v5(str(tmp_path)) + assert result is False + + def test_local_file_skips_network(self, tmp_path: Path): + """When local file exists, no network request should be made.""" + tc = {"tokenizer_class": "LlamaTokenizerFast"} + (tmp_path / "tokenizer_config.json").write_text(json.dumps(tc)) + + with patch("urllib.request.urlopen") as mock_urlopen: + result = _check_tokenizer_config_needs_v5(str(tmp_path)) + mock_urlopen.assert_not_called() + assert result is False + + def test_result_is_cached(self, tmp_path: Path): + """Subsequent calls should use the cache.""" + tc = {"tokenizer_class": "TokenizersBackend"} + (tmp_path / "tokenizer_config.json").write_text(json.dumps(tc)) + + key = str(tmp_path) + _check_tokenizer_config_needs_v5(key) + assert key in _tokenizer_class_cache + assert _tokenizer_class_cache[key] is True + + +# --------------------------------------------------------------------------- +# needs_transformers_5 — integration-level +# --------------------------------------------------------------------------- + + +class TestNeedsTransformers5: + """Integration tests for the top-level needs_transformers_5() function.""" + + def setup_method(self): + _tokenizer_class_cache.clear() + + def test_qwen35_substring(self): + assert needs_transformers_5("Qwen/Qwen3.5-9B") is True + + def test_qwen3_30b_a3b_substring(self): + assert needs_transformers_5("Qwen/Qwen3-30B-A3B-Instruct-2507") is True + + def test_ministral_substring(self): + assert needs_transformers_5("mistralai/Ministral-3-8B-Instruct-2512") is True + + def test_llama_does_not_need_v5(self): + """Standard models should not trigger v5.""" + # Patch network call to avoid real fetch + with patch( + "utils.transformers_version._check_tokenizer_config_needs_v5", + return_value = False, + ): + assert needs_transformers_5("meta-llama/Llama-3-8B") is False + + def test_local_checkpoint_resolved_via_config(self, tmp_path: Path): + """A local checkpoint with config.json pointing to Qwen3.5 should need v5.""" + config_cfg = {"model_name": "Qwen/Qwen3.5-9B"} + (tmp_path / "config.json").write_text(json.dumps(config_cfg)) + + # _resolve_base_model is called by ensure_transformers_version, + # but needs_transformers_5 just does substring matching. + # We test the full resolution chain here: + resolved = _resolve_base_model(str(tmp_path)) + assert needs_transformers_5(resolved) is True diff --git a/studio/backend/utils/models/checkpoints.py b/studio/backend/utils/models/checkpoints.py index a7cb80f33..b6b2e11c2 100644 --- a/studio/backend/utils/models/checkpoints.py +++ b/studio/backend/utils/models/checkpoints.py @@ -76,6 +76,18 @@ def scan_checkpoints( elif config_file.exists(): cfg = json.loads(config_file.read_text()) metadata["base_model"] = cfg.get("_name_or_path") + + # Detect BNB quantization from config.json (present in both cases) + if config_file.exists(): + if "cfg" not in dir(): + cfg = json.loads(config_file.read_text()) + quant_cfg = cfg.get("quantization_config") + if ( + isinstance(quant_cfg, dict) + and quant_cfg.get("quant_method") == "bitsandbytes" + ): + metadata["is_quantized"] = True + logger.info("Detected BNB-quantized model: %s", item.name) except Exception: pass diff --git a/studio/backend/utils/transformers_version.py b/studio/backend/utils/transformers_version.py index 60b43500c..d8724de72 100644 --- a/studio/backend/utils/transformers_version.py +++ b/studio/backend/utils/transformers_version.py @@ -92,6 +92,24 @@ def _resolve_base_model(model_name: str) -> str: except Exception as exc: logger.debug("Could not read %s: %s", adapter_cfg_path, exc) + # --- config.json fallback (works for both LoRA and full fine-tune) ------ + config_json_path = local_path / "config.json" + if config_json_path.is_file(): + try: + with open(config_json_path) as f: + cfg = json.load(f) + # Unsloth writes "model_name"; HF writes "_name_or_path" + base = cfg.get("model_name") or cfg.get("_name_or_path") + if base and base != str(local_path): + logger.info( + "Resolved checkpoint '%s' → base model '%s' (via config.json)", + model_name, + base, + ) + return base + except Exception as exc: + logger.debug("Could not read %s: %s", config_json_path, exc) + # --- Only try the heavier fallback for local directories ---------------- if local_path.is_dir(): try: @@ -126,6 +144,27 @@ def _check_tokenizer_config_needs_v5(model_name: str) -> bool: if model_name in _tokenizer_class_cache: return _tokenizer_class_cache[model_name] + # --- Check local tokenizer_config.json first --------------------------- + local_path = Path(model_name) + local_tc = local_path / "tokenizer_config.json" + if local_tc.is_file(): + try: + with open(local_tc) as f: + data = json.load(f) + tokenizer_class = data.get("tokenizer_class", "") + result = tokenizer_class in _TRANSFORMERS_5_TOKENIZER_CLASSES + if result: + logger.info( + "Local check: %s uses tokenizer_class=%s (requires transformers 5.x)", + model_name, + tokenizer_class, + ) + _tokenizer_class_cache[model_name] = result + return result + except Exception as exc: + logger.debug("Could not read %s: %s", local_tc, exc) + + # --- Fall back to fetching from HuggingFace ---------------------------- import urllib.request url = f"https://huggingface.co/{model_name}/raw/main/tokenizer_config.json" diff --git a/studio/frontend/src/features/export/api/export-api.ts b/studio/frontend/src/features/export/api/export-api.ts index 5c01845cf..aff56c3e6 100644 --- a/studio/frontend/src/features/export/api/export-api.ts +++ b/studio/frontend/src/features/export/api/export-api.ts @@ -31,6 +31,7 @@ export interface ModelCheckpoints { base_model?: string | null; peft_type?: string | null; lora_rank?: number | null; + is_quantized?: boolean; } export interface CheckpointListResponse { diff --git a/studio/frontend/src/features/export/components/method-picker.tsx b/studio/frontend/src/features/export/components/method-picker.tsx index 9d55b9e4d..78cec2cd7 100644 --- a/studio/frontend/src/features/export/components/method-picker.tsx +++ b/studio/frontend/src/features/export/components/method-picker.tsx @@ -18,9 +18,13 @@ import { EXPORT_METHODS, type ExportMethod } from "../constants"; interface MethodPickerProps { value: ExportMethod | null; onChange: (v: ExportMethod) => void; + /** Methods that should be shown but disabled (greyed out, not clickable). */ + disabledMethods?: ExportMethod[]; + /** Optional reason shown in a tooltip on disabled methods. */ + disabledReason?: string; } -export function MethodPicker({ value, onChange }: MethodPickerProps) { +export function MethodPicker({ value, onChange, disabledMethods = [], disabledReason }: MethodPickerProps) { return (
@@ -50,16 +54,21 @@ export function MethodPicker({ value, onChange }: MethodPickerProps) {
{EXPORT_METHODS.map((m) => { const selected = value === m.value; - return ( + const isDisabled = disabledMethods.includes(m.value); + + const card = ( ); + + if (isDisabled && disabledReason) { + return ( + + {card} + {disabledReason} + + ); + } + + return card; })}
diff --git a/studio/frontend/src/features/export/export-page.tsx b/studio/frontend/src/features/export/export-page.tsx index ed3802295..edf5b666a 100644 --- a/studio/frontend/src/features/export/export-page.tsx +++ b/studio/frontend/src/features/export/export-page.tsx @@ -122,6 +122,7 @@ export function ExportPage() { // Derive training info from selected model's API metadata const baseModelName = selectedModelData?.base_model ?? "—"; const isAdapter = !!selectedModelData?.peft_type; + const isQuantized = !!selectedModelData?.is_quantized; const loraRank = selectedModelData?.lora_rank ?? null; const trainingMethodLabel = selectedModelData?.peft_type ? "LoRA / QLoRA" @@ -132,6 +133,17 @@ export function ExportPage() { setCheckpoint(null); }, [selectedModelIdx]); + // Auto-reset export method if incompatible with the selected model type + useEffect(() => { + if (!isAdapter && (exportMethod === "merged" || exportMethod === "lora")) { + setExportMethod(null); + } + // Quantized non-PEFT models can't export to any format + if (!isAdapter && isQuantized && exportMethod !== null) { + setExportMethod(null); + } + }, [isAdapter, isQuantized, exportMethod]); + const handleMethodChange = (method: ExportMethod) => { setExportMethod(method); if (method !== "gguf") { @@ -465,7 +477,24 @@ export function ExportPage() { - + {exportMethod === "gguf" && ( diff --git a/unsloth/save.py b/unsloth/save.py index 6e38d1e95..1759d86fb 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -1976,15 +1976,46 @@ def unsloth_save_pretrained_gguf( fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer) # Step 4: Save/merge model to 16-bit format - print( - f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...' + is_peft_model = isinstance(self, PeftModelForCausalLM) or isinstance( + self, PeftModel ) - try: - # Call unsloth_generic_save directly (it's in the same file) - unsloth_generic_save(**arguments) - except Exception as e: - raise RuntimeError(f"Failed to save/merge model: {e}") + if is_peft_model: + print( + f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...' + ) + try: + # Call unsloth_generic_save directly (it's in the same file) + unsloth_generic_save(**arguments) + + except Exception as e: + raise RuntimeError(f"Failed to save/merge model: {e}") + else: + # Non-PEFT model — checkpoint files already exist on disk. + # Point save_to_gguf at the original checkpoint path instead of + # re-saving to a temporary "model" subdirectory. + original_path = getattr(self.config, "_name_or_path", None) + if original_path and os.path.isdir(original_path): + print( + f"Unsloth: Model is not a PEFT model. Using existing checkpoint at {original_path}" + ) + save_directory = original_path + # Persist tokenizer fixes (e.g. BOS token stripping) to disk + # so the GGUF converter picks up the corrected chat template. + if tokenizer is not None: + tokenizer.save_pretrained(save_directory) + else: + # Fallback: save the in-memory model to save_directory + print( + "Unsloth: Model is not a PEFT model. Saving directly without LoRA merge..." + ) + os.makedirs(save_directory, exist_ok = True) + try: + self.save_pretrained(save_directory) + if tokenizer is not None: + tokenizer.save_pretrained(save_directory) + except Exception as e: + raise RuntimeError(f"Failed to save model: {e}") if is_processor: tokenizer = tokenizer.tokenizer