From bb14ab144af6d7423239dbfa8456a969eeb52f8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 14 Apr 2026 09:46:22 -0700 Subject: [PATCH] Studio: live model-load progress + rate/ETA on download and load (#5017) * Studio: live model-load progress + rate/ETA on download and load Two UX fixes for the opaque multi-minute wait between clicking Load and being able to chat, visible most clearly on large MoE GGUFs like MiniMax-M2.7 (131 GB of weights on a 97 GB GPU): 1. **Model-load phase is now observable.** The existing chat flow transitions the toast to "Starting model..." as soon as the download hits 100%, then shows a spinner with no other feedback until llama-server reports healthy. For a 130 GB model that spinner freezes for five-plus minutes while the kernel pages shards into the page cache. A new `GET /api/inference/load-progress` endpoint samples `/proc/<pid>/status VmRSS` on the llama-server subprocess against the sum of shard file sizes on disk, so the UI can render a real bar plus rate / ETA during that window. 2. **Rate and ETA on downloads and loads.** Both the chat toast and the training-start overlay used to show a static pair of numbers (for example "15.4 of 140.8 GB"). A rolling 15-second window over the existing byte-series now surfaces "85.3 MB/s, 24m 23s left" beside that pair. The estimator is shared between the download and load phases so the numbers don't reset when the phase flips. Also fixes a pre-existing assignment bug uncovered while wiring this up: `load_model` was storing the caller's `gguf_path` kwarg into `self._gguf_path`, which is `None` on the HF-download code path. The resolved on-disk path (`model_path`) is what llama-server actually mmaps; downstream consumers need that. No existing reader used `_gguf_path`, so this is a correctness fix for the new endpoint. - Backend: `LlamaCppBackend.load_progress()`, `GET /api/inference/load-progress`, `LoadProgressResponse` Pydantic model. 
- Frontend: `useTransferStats` hook, `formatRate` / `formatEta` helpers, `getLoadProgress` client, rewired chat toast and `DownloadRow` in the training overlay. - Tests: `studio/backend/tests/test_llama_cpp_load_progress.py` covers empty states, mmap phase, ready phase, sharded total aggregation, missing gguf_path, and unreadable /proc (7 cases). `tsc -b` and `vite build` on the frontend both clean. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- studio/backend/core/inference/llama_cpp.py | 97 +++++- studio/backend/models/inference.py | 33 ++ studio/backend/routes/inference.py | 29 ++ .../tests/test_llama_cpp_load_progress.py | 258 ++++++++++++++++ .../src/features/chat/api/chat-api.ts | 25 ++ .../chat/hooks/use-chat-model-runtime.ts | 281 +++++++++++++----- .../features/chat/hooks/use-transfer-stats.ts | 98 ++++++ .../features/chat/utils/format-transfer.ts | 44 +++ .../studio/training-start-overlay.tsx | 16 +- 9 files changed, 805 insertions(+), 76 deletions(-) create mode 100644 studio/backend/tests/test_llama_cpp_load_progress.py create mode 100644 studio/frontend/src/features/chat/hooks/use-transfer-stats.ts create mode 100644 studio/frontend/src/features/chat/utils/format-transfer.ts diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index eafe83bee..71dec4c38 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -185,6 +185,96 @@ class LlamaCppBackend: """Return the model's native context length from GGUF metadata.""" return self._context_length + def load_progress(self) -> Optional[dict]: + """Return live model-load progress, or None if not loading. 
+ + While llama-server is warming up, its process is typically in + kernel state D (disk sleep) mmap'ing the weight shards into + page cache before pushing layers to VRAM. During that window + ``/api/inference/status`` only reports ``loading``, which gives + the UI nothing to display besides a spinner that looks stuck + for minutes on large MoE models. + + This method samples ``/proc/<pid>/status VmRSS`` against the + sum of the GGUF shard sizes so the UI can render a real bar + and compute rate / ETA. Returns ``None`` when no load is in + flight (no process, or process already healthy). + + Shape:: + + { + "phase": "mmap" | "ready", + "bytes_loaded": int, # VmRSS of the llama-server + "bytes_total": int, # sum of shard file sizes + "fraction": float, # bytes_loaded / bytes_total, 0..1 + } + + Linux-only in the current implementation. On macOS/Windows the + equivalent would be a different API; this returns ``None`` on + platforms where ``/proc/<pid>/status`` is unavailable. + """ + proc = self._process + if proc is None: + return None + pid = proc.pid + if pid is None: + return None + + # Sum up shard sizes (primary + any extras sitting alongside). + bytes_total = 0 + gguf_path = self._gguf_path + if gguf_path: + primary = Path(gguf_path) + try: + if primary.is_file(): + bytes_total += primary.stat().st_size + except OSError: + pass + # Extra shards live alongside the primary with the same prefix + # before the shard index (e.g. ``-00001-of-00004.gguf``). + try: + parent = primary.parent + stem = primary.name + m = _SHARD_RE.match(stem) + prefix = m.group(1) if m else None + if prefix and parent.is_dir(): + for sibling in parent.iterdir(): + if ( + sibling.is_file() + and sibling.name.startswith(prefix) + and sibling.name != stem + and sibling.suffix == ".gguf" + ): + try: + bytes_total += sibling.stat().st_size + except OSError: + pass + except OSError: + pass + + # Read VmRSS from /proc/<pid>/status. Kilobytes on Linux. 
+ bytes_loaded = 0 + try: + with open(f"/proc/{pid}/status", "r", encoding = "utf-8") as f: + for line in f: + if line.startswith("VmRSS:"): + kb = int(line.split()[1]) + bytes_loaded = kb * 1024 + break + except (FileNotFoundError, PermissionError, ValueError, OSError): + return None + + phase = "ready" if self._healthy else "mmap" + fraction = 0.0 + if bytes_total > 0: + fraction = min(1.0, bytes_loaded / bytes_total) + return { + "phase": phase, + "bytes_loaded": bytes_loaded, + "bytes_total": bytes_total, + "fraction": round(fraction, 4), + } + @property def chat_template(self) -> Optional[str]: return self._chat_template @@ -1574,7 +1664,12 @@ class LlamaCppBackend: ) self._stdout_thread.start() - self._gguf_path = gguf_path + # Store the resolved on-disk path, not the caller's kwarg. In + # HF mode the caller passes gguf_path=None and the real path + # (``model_path``) is what llama-server is actually mmap'ing. + # Downstream consumers (load_progress, log lines, etc.) need + # the path that exists on disk. + self._gguf_path = model_path self._hf_repo = hf_repo # For local GGUF files, extract variant from filename if not provided if hf_variant: diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 8d9dc9830..4917a1457 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -188,6 +188,39 @@ class UnloadResponse(BaseModel): model: str = Field(..., description = "Model identifier that was unloaded") +class LoadProgressResponse(BaseModel): + """Progress of the active GGUF load, sampled on demand. + + Used by the UI to show a real progress bar during the + post-download warmup window (mmap + CUDA upload), rather than a + generic "Starting model..." spinner that freezes for minutes on + large MoE models. 
+ """ + + phase: Optional[str] = Field( + None, + description = ( + "Load phase: 'mmap' (weights paging into RAM via mmap), " + "'ready' (llama-server reported healthy), or null when no " + "load is in flight." + ), + ) + bytes_loaded: int = Field( + 0, + description = ( + "Bytes of the model already resident in the llama-server " + "process (VmRSS on Linux)." + ), + ) + bytes_total: int = Field( + 0, + description = "Total bytes across all GGUF shards for the active model.", + ) + fraction: float = Field( + 0.0, description = "bytes_loaded / bytes_total, clamped to 0..1." + ) + + class InferenceStatusResponse(BaseModel): """Current inference backend status""" diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 0717b3bc9..4246f0056 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -72,6 +72,7 @@ from models.inference import ( UnloadRequest, GenerateRequest, LoadResponse, + LoadProgressResponse, UnloadResponse, InferenceStatusResponse, ChatCompletionRequest, @@ -751,6 +752,34 @@ async def get_status( raise HTTPException(status_code = 500, detail = f"Failed to get status: {str(e)}") +@router.get("/load-progress", response_model = LoadProgressResponse) +async def get_load_progress( + current_subject: str = Depends(get_current_subject), +): + """ + Return the active GGUF load's mmap/upload progress. + + During the warmup window after a GGUF download -- when llama-server + is paging ~tens-to-hundreds of GB of shards into the page cache + before pushing layers to VRAM -- ``/api/inference/status`` only + shows a generic spinner. This endpoint exposes sampled progress so + the UI can render a real bar plus rate/ETA during that window. + + Returns an empty payload (``phase=null, bytes=0``) when no load is + in flight. The frontend should stop polling once ``phase`` becomes + ``ready``. 
+ """ + try: + llama_backend = get_llama_cpp_backend() + progress = llama_backend.load_progress() + if progress is None: + return LoadProgressResponse() + return LoadProgressResponse(**progress) + except Exception as e: + logger.warning(f"Error sampling load progress: {e}") + return LoadProgressResponse() + + # ===================================================================== # Audio (TTS) Generation (/audio/generate) # ===================================================================== diff --git a/studio/backend/tests/test_llama_cpp_load_progress.py b/studio/backend/tests/test_llama_cpp_load_progress.py new file mode 100644 index 000000000..f46751b79 --- /dev/null +++ b/studio/backend/tests/test_llama_cpp_load_progress.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +"""Tests for ``LlamaCppBackend.load_progress()``. + +The chat settings flow and the training overlay both show a generic +"Starting model..." spinner during the window after a GGUF download +finishes and before llama-server reports healthy. For small models +that window is a second or two and nobody notices. For large MoE GGUFs +(MiniMax-M2.7, Qwen3.5-397B-A17B, etc.) the llama-server process spends +minutes in kernel state D, paging tens or hundreds of GB of shards +into the page cache. The UI has no way to show a real progress bar, +rate, or ETA during that window. + +``load_progress()`` samples ``/proc//status VmRSS`` (what the +kernel has actually paged in) against the total shard file size on +disk, so the frontend can render a real bar plus rate/ETA. 
This +module pins that contract: + + * returns ``None`` when no load is in flight + * returns ``{"phase": "mmap", ...}`` while the subprocess is alive + but ``_healthy`` is False + * returns ``{"phase": "ready", ...}`` once ``_healthy`` flips + * ``bytes_total`` is derived from the resolved on-disk path + (which the paired fix assigns to ``self._gguf_path`` on both the + local-GGUF and HF-download code paths) + * ``bytes_loaded`` is VmRSS in bytes, capped by total, rounded + * ``fraction`` is clamped to 0..1 and rounded to 4 decimal places + +Linux-only via ``/proc``. On platforms without ``/proc`` the method +returns ``None`` instead of raising. +Cross-platform test: skips cleanly on macOS / Windows if ``/proc`` is +not available. +""" + +from __future__ import annotations + +import os +import sys +import tempfile +import types as _types +from pathlib import Path +from unittest.mock import patch + +import pytest + +# --------------------------------------------------------------------------- +# Stub heavy / unavailable external dependencies before importing the +# module under test. Same pattern as test_kv_cache_estimation.py. 
+# --------------------------------------------------------------------------- + +_BACKEND_DIR = str(Path(__file__).resolve().parent.parent) +if _BACKEND_DIR not in sys.path: + sys.path.insert(0, _BACKEND_DIR) + +_loggers_stub = _types.ModuleType("loggers") +_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) +sys.modules.setdefault("loggers", _loggers_stub) + +_structlog_stub = _types.ModuleType("structlog") +sys.modules.setdefault("structlog", _structlog_stub) + +_httpx_stub = _types.ModuleType("httpx") +for _exc_name in ( + "ConnectError", + "TimeoutException", + "ReadTimeout", + "ReadError", + "RemoteProtocolError", + "CloseError", +): + setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {})) + + +class _FakeTimeout: + def __init__(self, *a, **kw): + pass + + +_httpx_stub.Timeout = _FakeTimeout +_httpx_stub.Client = type( + "Client", + (), + { + "__init__": lambda self, **kw: None, + "__enter__": lambda self: self, + "__exit__": lambda self, *a: None, + }, +) +sys.modules.setdefault("httpx", _httpx_stub) + +from core.inference.llama_cpp import LlamaCppBackend + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_instance(): + inst = LlamaCppBackend.__new__(LlamaCppBackend) + inst._process = None + inst._gguf_path = None + inst._healthy = False + return inst + + +class _FakeProc: + """Minimal stand-in for subprocess.Popen that just carries a pid.""" + + def __init__(self, pid: int): + self.pid = pid + + +def _write_sparse_file(path: Path, size_bytes: int) -> None: + """Create a sparse file of the given size without allocating blocks.""" + with open(path, "wb") as fh: + if size_bytes > 0: + fh.truncate(size_bytes) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class 
TestLoadProgressEmptyStates: + def test_returns_none_when_no_process(self): + inst = _make_instance() + assert inst.load_progress() is None + + def test_returns_none_when_process_has_no_pid(self): + inst = _make_instance() + inst._process = _FakeProc(pid = None) # type: ignore[arg-type] + assert inst.load_progress() is None + + +class TestLoadProgressSingleShard: + def test_mmap_phase_for_alive_but_unhealthy(self, tmp_path): + """VmRSS below total -> phase='mmap', fraction reflects progress.""" + gguf = tmp_path / "model.gguf" + _write_sparse_file(gguf, 40 * 1024**3) # 40 GB + + inst = _make_instance() + inst._process = _FakeProc(pid = os.getpid()) # use our own pid + inst._gguf_path = str(gguf) + inst._healthy = False + + # Patch /proc read to claim 10 GB RSS. + def fake_open(path, *args, **kwargs): + if str(path).startswith("/proc/"): + import io + + return io.StringIO(f"Name:\ttest\nVmRSS:\t{10 * 1024 ** 2}\tkB\n") + return open(path, *args, **kwargs) # fall through + + with patch("builtins.open", side_effect = fake_open): + out = inst.load_progress() + + assert out is not None + assert out["phase"] == "mmap" + assert out["bytes_total"] == 40 * 1024**3 + assert out["bytes_loaded"] == 10 * 1024**3 + assert 0.24 < out["fraction"] < 0.26 # ~25% + + def test_ready_phase_when_healthy(self, tmp_path): + gguf = tmp_path / "model.gguf" + _write_sparse_file(gguf, 8 * 1024**3) + + inst = _make_instance() + inst._process = _FakeProc(pid = os.getpid()) + inst._gguf_path = str(gguf) + inst._healthy = True + + def fake_open(path, *args, **kwargs): + if str(path).startswith("/proc/"): + import io + + return io.StringIO(f"VmRSS:\t{8 * 1024 ** 2}\tkB\n") + return open(path, *args, **kwargs) + + with patch("builtins.open", side_effect = fake_open): + out = inst.load_progress() + + assert out is not None + assert out["phase"] == "ready" + assert out["bytes_total"] == 8 * 1024**3 + assert out["bytes_loaded"] == 8 * 1024**3 + assert out["fraction"] == 1.0 + + +class 
TestLoadProgressMultiShard: + """Shard-aware total: for ``*-00001-of-00004.gguf`` primaries the + method sums sibling files with the same prefix.""" + + def test_sharded_total_aggregates_siblings(self, tmp_path): + for i in range(1, 5): + _write_sparse_file( + tmp_path / f"model-{i:05d}-of-00004.gguf", + size_bytes = 20 * 1024**3, + ) + # Drop an unrelated .gguf in the same folder -- must not be counted. + _write_sparse_file(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3) + + inst = _make_instance() + inst._process = _FakeProc(pid = os.getpid()) + inst._gguf_path = str(tmp_path / "model-00001-of-00004.gguf") + inst._healthy = False + + def fake_open(path, *args, **kwargs): + if str(path).startswith("/proc/"): + import io + + return io.StringIO("VmRSS:\t0\tkB\n") + return open(path, *args, **kwargs) + + with patch("builtins.open", side_effect = fake_open): + out = inst.load_progress() + + assert out is not None + assert out["bytes_total"] == 80 * 1024**3 # 4 x 20 GB, no mmproj + + +class TestLoadProgressDegradation: + """Broken / unusual inputs never raise; they produce best-effort output.""" + + def test_missing_gguf_path_still_reports_rss(self, tmp_path): + inst = _make_instance() + inst._process = _FakeProc(pid = os.getpid()) + inst._gguf_path = None + inst._healthy = False + + def fake_open(path, *args, **kwargs): + if str(path).startswith("/proc/"): + import io + + return io.StringIO("VmRSS:\t1024\tkB\n") + return open(path, *args, **kwargs) + + with patch("builtins.open", side_effect = fake_open): + out = inst.load_progress() + + assert out is not None + assert out["phase"] == "mmap" + assert out["bytes_total"] == 0 + assert out["bytes_loaded"] == 1024 * 1024 + assert out["fraction"] == 0.0 + + def test_unreadable_proc_returns_none(self, tmp_path): + inst = _make_instance() + # Pid that doesn't exist -> /proc read fails. 
+ inst._process = _FakeProc(pid = 999_999_999) + inst._gguf_path = str(tmp_path / "model.gguf") # doesn't need to exist + inst._healthy = False + + out = inst.load_progress() + # FileNotFoundError on /proc path -> load_progress returns None. + assert out is None diff --git a/studio/frontend/src/features/chat/api/chat-api.ts b/studio/frontend/src/features/chat/api/chat-api.ts index 33e6e1eba..ddc0e9d39 100644 --- a/studio/frontend/src/features/chat/api/chat-api.ts +++ b/studio/frontend/src/features/chat/api/chat-api.ts @@ -145,6 +145,31 @@ export async function getDatasetDownloadProgress( return parseJsonOrThrow(response); } +export type ModelLoadPhase = "mmap" | "ready" | null; + +export interface LoadProgressResponse { + /** + * Load phase: ``"mmap"`` while the llama-server subprocess is paging + * weight shards into RAM, ``"ready"`` once it has reported healthy, + * or ``null`` when no load is in flight. + */ + phase: ModelLoadPhase; + bytes_loaded: number; + bytes_total: number; + fraction: number; +} + +/** + * Fetch the active GGUF load's mmap/upload progress. Complements + * ``getDownloadProgress`` / ``getGgufDownloadProgress`` for the window + * between "download complete" and "chat ready", which for large MoE + * models can be several minutes of otherwise-opaque spinning. 
+ */ +export async function getLoadProgress(): Promise<LoadProgressResponse> { + const response = await authFetch(`/api/inference/load-progress`); + return parseJsonOrThrow(response); +} + export interface LocalModelInfo { id: string; display_name: string; diff --git a/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts b/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts index 65b02be96..53f0d1d35 100644 --- a/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts +++ b/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts @@ -8,12 +8,14 @@ import { getDownloadProgress, getGgufDownloadProgress, getInferenceStatus, + getLoadProgress, listLoras, listModels, loadModel, unloadModel, validateModel, } from "../api/chat-api"; +import { formatEta, formatRate } from "../utils/format-transfer"; import { useChatRuntimeStore } from "../stores/chat-runtime-store"; import type { InferenceStatusResponse, LoadModelResponse } from "../types/api"; import type { @@ -565,77 +567,147 @@ export function useChatModelRuntime() { ); loadToastIdRef.current = toastId; - // Poll download progress for non-cached models (GGUF and non-GGUF) + // Poll download progress for non-cached models (GGUF and non-GGUF). + // Then, once the download wraps (or for already-cached models), + // poll the llama-server mmap phase so "Starting model..." no + // longer looks frozen for several minutes on large MoE models. let progressInterval: ReturnType<typeof setInterval> | null = null; - if (!isDownloaded && !isCachedLora) { - const expectedBytes = - typeof selection !== "string" ? selection.expectedBytes ?? 0 : 0; - let hasShownProgress = false; + const expectedBytes = + typeof selection !== "string" ? selection.expectedBytes ?? 
0 : 0; - const pollProgress = async () => { - if (abortCtrl.signal.aborted || !loadingModelRef.current) { - if (progressInterval) clearInterval(progressInterval); - return; - } - try { - const prog = ggufVariant && expectedBytes > 0 + // Rolling window of byte samples for rate / ETA estimation. + // Shared across download + mmap phases so the estimator doesn't + // reset when the phase flips. + type Sample = { t: number; b: number }; + const MIN_SAMPLES = 3; + const MIN_WINDOW = 3_000; // ms + const MAX_WINDOW = 15_000; // ms + const dlSamples: Sample[] = []; + const mmapSamples: Sample[] = []; + + function estimate( + samples: Sample[], + bytes: number, + total: number, + ): { rate: number; eta: number; stable: boolean } { + const now = Date.now(); + // Drop samples if the counter reset (e.g. phase flipped). + if (samples.length > 0 && bytes < samples[samples.length - 1].b) { + samples.length = 0; + } + samples.push({ t: now, b: bytes }); + const cutoff = now - MAX_WINDOW; + while (samples.length > 2 && samples[0].t < cutoff) { + samples.shift(); + } + if (samples.length < MIN_SAMPLES) { + return { rate: 0, eta: 0, stable: false }; + } + const first = samples[0]; + const last = samples[samples.length - 1]; + const dt = (last.t - first.t) / 1000; + const db = last.b - first.b; + if (dt * 1000 < MIN_WINDOW || db <= 0) { + return { rate: 0, eta: 0, stable: false }; + } + const rate = db / dt; + const eta = + total > 0 && bytes < total && rate > 0 ? (total - bytes) / rate : 0; + return { rate, eta, stable: true }; + } + + function composeProgressLabel( + dlGb: number, + totalGb: number, + bytes: number, + total: number, + samples: Sample[], + ): string { + const base = + totalGb > 0 + ? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB` + : `${dlGb.toFixed(1)} GB downloaded`; + const est = estimate(samples, bytes, total); + if (!est.stable) return base; + const rateStr = formatRate(est.rate); + const etaStr = total > 0 ? 
formatEta(est.eta) : ""; + return etaStr && etaStr !== "--" + ? `${base} • ${rateStr} • ${etaStr} left` + : `${base} • ${rateStr}`; + } + + let downloadComplete = isDownloaded || isCachedLora; + + const pollDownload = async () => { + if (abortCtrl.signal.aborted || !loadingModelRef.current) { + if (progressInterval) clearInterval(progressInterval); + return; + } + try { + const prog = + ggufVariant && expectedBytes > 0 ? await getGgufDownloadProgress(modelId, ggufVariant, expectedBytes) : await getDownloadProgress(modelId); + if (!loadingModelRef.current) return; - if (!loadingModelRef.current) return; - - if (prog.progress > 0 && prog.progress < 1) { - hasShownProgress = true; - const dlGb = prog.downloaded_bytes / (1024 ** 3); - const totalGb = prog.expected_bytes / (1024 ** 3); - const pct = Math.round(prog.progress * 100); - const progressLabel = totalGb > 0 - ? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB` - : `${dlGb.toFixed(1)} GB downloaded`; - setLoadProgress({ - percent: pct, - label: progressLabel, - phase: "downloading", - }); - if (loadToastDismissedRef.current) return; - toast( - null, - { - id: toastId, - description: renderLoadDescription( - "Downloading model…", - loadingDescription, - pct, - progressLabel, - cancelLoading, - ), - duration: Infinity, - closeButton: false, - classNames: MODEL_LOAD_TOAST_CLASSNAMES, - onDismiss: (dismissedToast) => { - if (loadToastIdRef.current !== dismissedToast.id) return; - setLoadToastDismissedState(true); - }, - }, - ); - } else if (prog.downloaded_bytes > 0 && prog.expected_bytes === 0 && prog.progress === 0) { - hasShownProgress = true; - const dlGb = prog.downloaded_bytes / (1024 ** 3); - setLoadProgress({ - percent: null, - label: `${dlGb.toFixed(1)} GB downloaded`, - phase: "downloading", - }); - } else if (prog.progress >= 1 && hasShownProgress) { - setLoadProgress({ - percent: 100, - label: "Download complete", - phase: "starting", - }); - if (loadToastDismissedRef.current) { - if (progressInterval) 
clearInterval(progressInterval); - return; - } + if (prog.progress > 0 && prog.progress < 1) { + hasShownProgress = true; + const dlGb = prog.downloaded_bytes / (1024 ** 3); + const totalGb = prog.expected_bytes / (1024 ** 3); + const pct = Math.round(prog.progress * 100); + const progressLabel = composeProgressLabel( + dlGb, + totalGb, + prog.downloaded_bytes, + prog.expected_bytes, + dlSamples, + ); + setLoadProgress({ + percent: pct, + label: progressLabel, + phase: "downloading", + }); + if (loadToastDismissedRef.current) return; + toast(null, { + id: toastId, + description: renderLoadDescription( + "Downloading model…", + loadingDescription, + pct, + progressLabel, + cancelLoading, + ), + duration: Infinity, + closeButton: false, + classNames: MODEL_LOAD_TOAST_CLASSNAMES, + onDismiss: (dismissedToast) => { + if (loadToastIdRef.current !== dismissedToast.id) return; + setLoadToastDismissedState(true); + }, + }); + } else if ( + prog.downloaded_bytes > 0 && + prog.expected_bytes === 0 && + prog.progress === 0 + ) { + hasShownProgress = true; + const dlGb = prog.downloaded_bytes / (1024 ** 3); + const est = estimate(dlSamples, prog.downloaded_bytes, 0); + const rateSuffix = + est.stable ? ` • ${formatRate(est.rate)}` : ""; + setLoadProgress({ + percent: null, + label: `${dlGb.toFixed(1)} GB downloaded${rateSuffix}`, + phase: "downloading", + }); + } else if (prog.progress >= 1 && hasShownProgress) { + downloadComplete = true; + setLoadProgress({ + percent: 100, + label: "Download complete", + phase: "starting", + }); + if (!loadToastDismissedRef.current) { toast(null, { id: toastId, description: renderLoadDescription( @@ -653,16 +725,79 @@ export function useChatModelRuntime() { setLoadToastDismissedState(true); }, }); - if (progressInterval) clearInterval(progressInterval); } - } catch { - // Ignore polling errors + // Keep polling: the mmap branch below takes over from here. } - }; + } catch { + // Ignore polling errors; keep polling. 
+ } + }; - setTimeout(pollProgress, 500); - progressInterval = setInterval(pollProgress, 2000); - } + const pollLoad = async () => { + if (abortCtrl.signal.aborted || !loadingModelRef.current) { + if (progressInterval) clearInterval(progressInterval); + return; + } + try { + const prog = await getLoadProgress(); + if (!loadingModelRef.current) return; + if (!prog || prog.phase == null) return; + if (prog.phase === "ready") { + // Loaded. The chat flow will flip loadingModelRef shortly; + // just stop polling. + if (progressInterval) clearInterval(progressInterval); + return; + } + if (prog.bytes_total <= 0) return; // nothing useful to render + const loadedGb = prog.bytes_loaded / (1024 ** 3); + const totalGb = prog.bytes_total / (1024 ** 3); + const pct = Math.min(99, Math.round(prog.fraction * 100)); + const est = estimate(mmapSamples, prog.bytes_loaded, prog.bytes_total); + const base = `${loadedGb.toFixed(1)} of ${totalGb.toFixed(1)} GB in memory`; + const label = est.stable + ? `${base} • ${formatRate(est.rate)}${ + formatEta(est.eta) !== "--" ? ` • ${formatEta(est.eta)} left` : "" + }` + : base; + setLoadProgress({ + percent: pct, + label, + phase: "starting", + }); + if (loadToastDismissedRef.current) return; + toast(null, { + id: toastId, + description: renderLoadDescription( + "Starting model…", + "Paging weights into memory.", + pct, + label, + cancelLoading, + ), + duration: Infinity, + closeButton: false, + classNames: MODEL_LOAD_TOAST_CLASSNAMES, + onDismiss: (dismissedToast) => { + if (loadToastIdRef.current !== dismissedToast.id) return; + setLoadToastDismissedState(true); + }, + }); + } catch { + // Ignore polling errors. 
+ } + }; + + const pollProgress = async () => { + if (!downloadComplete) { + await pollDownload(); + } else { + await pollLoad(); + } + }; + + let hasShownProgress = false; + setTimeout(pollProgress, 500); + progressInterval = setInterval(pollProgress, 2000); try { await performLoad(); diff --git a/studio/frontend/src/features/chat/hooks/use-transfer-stats.ts b/studio/frontend/src/features/chat/hooks/use-transfer-stats.ts new file mode 100644 index 000000000..9ea7d1e62 --- /dev/null +++ b/studio/frontend/src/features/chat/hooks/use-transfer-stats.ts @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +/** + * Compute rate (bytes/sec) and ETA (seconds) from a time-series of + * cumulative ``bytes`` values, using a rolling window of recent samples. + * + * Shared between the chat-flow download toast, the training-start + * overlay, and the model-load phase UI. All three have the same shape: + * a counter that rises monotonically from 0 toward ``totalBytes``, polled + * on an interval. The derived stats are identical regardless of whether + * the bytes came from an HTTP download or an mmap page-in. + * + * Stability rule: ``stable`` stays ``false`` until we've observed at + * least 3 samples spanning ≥3 seconds. That keeps the UI from flashing + * wildly varying rates during the first tick or two when the denominator + * is effectively zero. + */ + +import { useEffect, useRef, useState } from "react"; + +export type TransferStats = { + rateBytesPerSecond: number; + etaSeconds: number; + /** + * False for the first few ticks (window not filled, or no forward + * progress yet). Consumers should hide rate/ETA while unstable so the + * UI doesn't flicker "123 GB/s" during the first tick. 
+ */ + stable: boolean; +}; + +const MIN_SAMPLES = 3; +const MIN_WINDOW_SECONDS = 3; +const MAX_WINDOW_SECONDS = 15; + +export function useTransferStats( + bytes: number | null | undefined, + totalBytes: number | null | undefined, +): TransferStats { + const samplesRef = useRef<{ t: number; b: number }[]>([]); + const [state, setState] = useState({ + rateBytesPerSecond: 0, + etaSeconds: 0, + stable: false, + }); + + useEffect(() => { + const now = Date.now() / 1000; + const cur = typeof bytes === "number" && Number.isFinite(bytes) ? bytes : 0; + const total = + typeof totalBytes === "number" && Number.isFinite(totalBytes) + ? totalBytes + : 0; + + // If the counter resets (e.g. user unloaded and started a new + // download), drop the stale window. + const samples = samplesRef.current; + if (samples.length > 0 && cur < samples[samples.length - 1].b) { + samples.length = 0; + } + + samples.push({ t: now, b: cur }); + + // Drop samples older than MAX_WINDOW_SECONDS; keep at least 2 so we + // can still compute a rate when the counter hasn't moved in a while. + const cutoff = now - MAX_WINDOW_SECONDS; + while (samples.length > 2 && samples[0].t < cutoff) { + samples.shift(); + } + + if (samples.length < MIN_SAMPLES) { + setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false }); + return; + } + + const first = samples[0]; + const last = samples[samples.length - 1]; + const dt = last.t - first.t; + const db = last.b - first.b; + if (dt < MIN_WINDOW_SECONDS || db <= 0) { + setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false }); + return; + } + + const rate = db / dt; + const eta = + total > 0 && cur < total && rate > 0 ? 
(total - cur) / rate : 0; + + setState({ + rateBytesPerSecond: rate, + etaSeconds: eta, + stable: true, + }); + }, [bytes, totalBytes]); + + return state; +} diff --git a/studio/frontend/src/features/chat/utils/format-transfer.ts b/studio/frontend/src/features/chat/utils/format-transfer.ts new file mode 100644 index 000000000..737f26c8c --- /dev/null +++ b/studio/frontend/src/features/chat/utils/format-transfer.ts @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +/** + * Format a byte-per-second rate as a human-readable string. + * + * 512 → "512 B/s" + * 1_234_567 → "1.2 MB/s" + * 1_234_567_890 → "1.15 GB/s" + * + * Returns `"--"` for non-finite or non-positive inputs so callers can + * render the label safely before the first stable sample arrives. + */ +export function formatRate(bytesPerSecond: number): string { + if (!Number.isFinite(bytesPerSecond) || bytesPerSecond <= 0) return "--"; + const bps = bytesPerSecond; + if (bps < 1024) return `${bps.toFixed(0)} B/s`; + if (bps < 1024 ** 2) return `${(bps / 1024).toFixed(1)} KB/s`; + if (bps < 1024 ** 3) return `${(bps / 1024 ** 2).toFixed(1)} MB/s`; + return `${(bps / 1024 ** 3).toFixed(2)} GB/s`; +} + +/** + * Format an ETA (in seconds) as a short human-readable string. + * + * 47 → "47s" + * 125 → "2m 5s" + * 3725 → "1h 2m" + * + * Returns `"--"` for non-finite or non-positive inputs. + */ +export function formatEta(seconds: number): string { + if (!Number.isFinite(seconds) || seconds <= 0) return "--"; + const s = Math.round(seconds); + if (s < 60) return `${s}s`; + if (s < 3600) { + const m = Math.floor(s / 60); + const rem = s % 60; + return rem > 0 ? `${m}m ${rem}s` : `${m}m`; + } + const h = Math.floor(s / 3600); + const m = Math.floor((s % 3600) / 60); + return m > 0 ? 
`${h}h ${m}m` : `${h}h`; +} diff --git a/studio/frontend/src/features/studio/training-start-overlay.tsx b/studio/frontend/src/features/studio/training-start-overlay.tsx index f73a221d7..b6042d4ae 100644 --- a/studio/frontend/src/features/studio/training-start-overlay.tsx +++ b/studio/frontend/src/features/studio/training-start-overlay.tsx @@ -23,6 +23,8 @@ import { getDownloadProgress, type DownloadProgressResponse, } from "@/features/chat/api/chat-api"; +import { useTransferStats } from "@/features/chat/hooks/use-transfer-stats"; +import { formatEta, formatRate } from "@/features/chat/utils/format-transfer"; import { useTrainingActions, useTrainingConfigStore, @@ -151,6 +153,11 @@ type DownloadRowProps = { }; function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null { + // Compute a rolling-window rate + ETA from the same cumulative-byte + // series the poll hook already produces, so we can show + // "5.2 / 20.7 GB • 85.3 MB/s • 3m 12s left" instead of just the pair. + const stats = useTransferStats(state.downloadedBytes, state.totalBytes); + if (state.downloadedBytes <= 0 && !state.cachePath) return null; const isComplete = state.totalBytes > 0 && state.percent >= 100; const statusLabel = isComplete @@ -160,11 +167,16 @@ function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null { : state.downloadedBytes === 0 ? "Preparing" : null; + const showRate = stats.stable && !isComplete; + const rateSuffix = showRate ? ` • ${formatRate(stats.rateBytesPerSecond)}` : ""; + const etaStr = + showRate && state.totalBytes > 0 ? formatEta(stats.etaSeconds) : "--"; + const etaSuffix = etaStr !== "--" ? ` • ${etaStr} left` : ""; const sizeLabel = state.totalBytes > 0 - ? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}` + ? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}${rateSuffix}${etaSuffix}` : state.downloadedBytes > 0 - ? `${formatBytes(state.downloadedBytes)} downloaded` + ? 
`${formatBytes(state.downloadedBytes)} downloaded${rateSuffix}` : null; return (