Studio: live model-load progress + rate/ETA on download and load (#5017)

* Studio: live model-load progress + rate/ETA on download and load

Two UX fixes for the opaque multi-minute wait between clicking Load
and being able to chat, visible most clearly on large MoE GGUFs like
MiniMax-M2.7 (131 GB of weights on a 97 GB GPU):

1. **Model-load phase is now observable.** The existing chat flow
   transitions the toast to "Starting model..." as soon as the
   download hits 100%, then shows a spinner with no other feedback
   until llama-server reports healthy. For a 130 GB model that spinner
   freezes for five-plus minutes while the kernel pages shards into
   the page cache. A new `GET /api/inference/load-progress` endpoint
   samples `/proc/<pid>/status VmRSS` on the llama-server subprocess
   against the sum of shard file sizes on disk, so the UI can render
   a real bar plus rate / ETA during that window.

2. **Rate and ETA on downloads and loads.** Both the chat toast and
   the training-start overlay used to show a static pair of numbers
   (for example "15.4 of 140.8 GB"). A rolling 15-second window over
   the existing byte-series now surfaces "85.3 MB/s, 24m 23s left"
   beside that pair. The estimator is shared between the download
   and load phases so the numbers don't reset when the phase flips.

Also fixes a pre-existing assignment bug uncovered while wiring this
up: `load_model` was storing the caller's `gguf_path` kwarg into
`self._gguf_path`, which is `None` on the HF-download code path. The
resolved on-disk path (`model_path`) is what llama-server actually
mmaps; downstream consumers need that. No existing reader used
`_gguf_path`, so this is a correctness fix for the new endpoint.

- Backend: `LlamaCppBackend.load_progress()`, `GET /api/inference/load-progress`, `LoadProgressResponse` Pydantic model.
- Frontend: `useTransferStats` hook, `formatRate` / `formatEta` helpers, `getLoadProgress` client, rewired chat toast and `DownloadRow` in the training overlay.
- Tests: `studio/backend/tests/test_llama_cpp_load_progress.py` covers empty states, mmap phase, ready phase, sharded total aggregation, missing gguf_path, and unreadable /proc (7 cases). `tsc -b` and `vite build` on the frontend both clean.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-14 09:46:22 -07:00 committed by GitHub
parent 514bb3a20e
commit bb14ab144a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 805 additions and 76 deletions

View file

@ -185,6 +185,96 @@ class LlamaCppBackend:
"""Return the model's native context length from GGUF metadata."""
return self._context_length
def load_progress(self) -> Optional[dict]:
    """Return live model-load progress, or ``None`` if no load is observable.

    While llama-server is warming up, its process is typically in
    kernel state D (disk sleep) mmap'ing the weight shards into the
    page cache before pushing layers to VRAM. During that window
    ``/api/inference/status`` only reports ``loading``, which gives
    the UI nothing to display besides a spinner that looks stuck
    for minutes on large MoE models.

    This method samples ``VmRSS`` from ``/proc/<pid>/status`` against
    the sum of the GGUF shard sizes so the UI can render a real bar
    and compute rate / ETA.

    Returns:
        ``None`` when there is no subprocess, the subprocess has no
        pid, or ``/proc/<pid>/status`` cannot be read (e.g. non-Linux
        platforms). Otherwise::

            {
                "phase": "mmap" | "ready",
                "bytes_loaded": int,  # VmRSS, capped at bytes_total
                "bytes_total": int,   # sum of shard file sizes
                "fraction": float,    # 0..1, rounded to 4 places
            }
    """
    proc = self._process
    if proc is None:
        return None
    pid = proc.pid
    if pid is None:
        return None
    # Sum up shard sizes (primary + any extras sitting alongside).
    bytes_total = 0
    gguf_path = self._gguf_path
    if gguf_path:
        primary = Path(gguf_path)
        try:
            if primary.is_file():
                bytes_total += primary.stat().st_size
        except OSError:
            pass
        # Extra shards live alongside the primary with the same prefix
        # before the shard index (e.g. ``-00001-of-00004.gguf``).
        try:
            parent = primary.parent
            stem = primary.name
            m = _SHARD_RE.match(stem)
            prefix = m.group(1) if m else None
            if prefix and parent.is_dir():
                for sibling in parent.iterdir():
                    if (
                        sibling.is_file()
                        and sibling.name.startswith(prefix)
                        and sibling.name != stem
                        and sibling.suffix == ".gguf"
                    ):
                        try:
                            bytes_total += sibling.stat().st_size
                        except OSError:
                            pass
        except OSError:
            pass
    # Read VmRSS from /proc/<pid>/status. Kilobytes on Linux.
    bytes_loaded = 0
    try:
        with open(f"/proc/{pid}/status", "r", encoding = "utf-8") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    kb = int(line.split()[1])
                    bytes_loaded = kb * 1024
                    break
    except (FileNotFoundError, PermissionError, ValueError, OSError):
        return None
    # RSS includes non-weight memory too (CUDA host buffers, KV cache,
    # the server itself), so it can overshoot the shard total. Cap it
    # so the UI never renders e.g. "141.2 of 140.8 GB".
    if bytes_total > 0:
        bytes_loaded = min(bytes_loaded, bytes_total)
    phase = "ready" if self._healthy else "mmap"
    fraction = 0.0
    if bytes_total > 0:
        fraction = min(1.0, bytes_loaded / bytes_total)
    return {
        "phase": phase,
        "bytes_loaded": bytes_loaded,
        "bytes_total": bytes_total,
        "fraction": round(fraction, 4),
    }
@property
def chat_template(self) -> Optional[str]:
    # Chat template string recorded on this backend during model load,
    # or None if no template was captured.
    return self._chat_template
@ -1574,7 +1664,12 @@ class LlamaCppBackend:
)
self._stdout_thread.start()
self._gguf_path = gguf_path
# Store the resolved on-disk path, not the caller's kwarg. In
# HF mode the caller passes gguf_path=None and the real path
# (``model_path``) is what llama-server is actually mmap'ing.
# Downstream consumers (load_progress, log lines, etc.) need
# the path that exists on disk.
self._gguf_path = model_path
self._hf_repo = hf_repo
# For local GGUF files, extract variant from filename if not provided
if hf_variant:

View file

@ -188,6 +188,39 @@ class UnloadResponse(BaseModel):
model: str = Field(..., description = "Model identifier that was unloaded")
class LoadProgressResponse(BaseModel):
    """Progress of the active GGUF load, sampled on demand.

    Used by the UI to show a real progress bar during the
    post-download warmup window (mmap + CUDA upload), rather than a
    generic "Starting model..." spinner that freezes for minutes on
    large MoE models.

    Mirrors the dict produced by ``LlamaCppBackend.load_progress()``;
    the field defaults double as the "no load in flight" empty
    payload returned by ``GET /api/inference/load-progress``.
    """

    # None doubles as "no load in flight", so the endpoint can return
    # an empty payload instead of an error status.
    phase: Optional[str] = Field(
        None,
        description = (
            "Load phase: 'mmap' (weights paging into RAM via mmap), "
            "'ready' (llama-server reported healthy), or null when no "
            "load is in flight."
        ),
    )
    bytes_loaded: int = Field(
        0,
        description = (
            "Bytes of the model already resident in the llama-server "
            "process (VmRSS on Linux)."
        ),
    )
    bytes_total: int = Field(
        0,
        description = "Total bytes across all GGUF shards for the active model.",
    )
    fraction: float = Field(
        0.0, description = "bytes_loaded / bytes_total, clamped to 0..1."
    )
class InferenceStatusResponse(BaseModel):
"""Current inference backend status"""

View file

@ -72,6 +72,7 @@ from models.inference import (
UnloadRequest,
GenerateRequest,
LoadResponse,
LoadProgressResponse,
UnloadResponse,
InferenceStatusResponse,
ChatCompletionRequest,
@ -751,6 +752,34 @@ async def get_status(
raise HTTPException(status_code = 500, detail = f"Failed to get status: {str(e)}")
@router.get("/load-progress", response_model = LoadProgressResponse)
async def get_load_progress(
    current_subject: str = Depends(get_current_subject),
):
    """
    Return the active GGUF load's mmap/upload progress.

    During the warmup window after a GGUF download -- when llama-server
    is paging ~tens-to-hundreds of GB of shards into the page cache
    before pushing layers to VRAM -- ``/api/inference/status`` only
    shows a generic spinner. This endpoint exposes sampled progress so
    the UI can render a real bar plus rate/ETA during that window.

    Returns an empty payload (``phase=null, bytes=0``) when no load is
    in flight. The frontend should stop polling once ``phase`` becomes
    ``ready``.
    """
    try:
        llama_backend = get_llama_cpp_backend()
        progress = llama_backend.load_progress()
        if progress is None:
            # No subprocess / nothing loading: empty payload, not an error.
            return LoadProgressResponse()
        return LoadProgressResponse(**progress)
    except Exception as e:
        # Progress is best-effort telemetry: degrade to the empty payload
        # rather than failing the UI's poll loop. Lazy %-args skip string
        # formatting when WARNING is disabled.
        logger.warning("Error sampling load progress: %s", e)
        return LoadProgressResponse()
# =====================================================================
# Audio (TTS) Generation (/audio/generate)
# =====================================================================

View file

@ -0,0 +1,258 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for ``LlamaCppBackend.load_progress()``.
The chat settings flow and the training overlay both show a generic
"Starting model..." spinner during the window after a GGUF download
finishes and before llama-server reports healthy. For small models
that window is a second or two and nobody notices. For large MoE GGUFs
(MiniMax-M2.7, Qwen3.5-397B-A17B, etc.) the llama-server process spends
minutes in kernel state D, paging tens or hundreds of GB of shards
into the page cache. The UI has no way to show a real progress bar,
rate, or ETA during that window.
``load_progress()`` samples ``/proc/<pid>/status VmRSS`` (what the
kernel has actually paged in) against the total shard file size on
disk, so the frontend can render a real bar plus rate/ETA. This
module pins that contract:
* returns ``None`` when no load is in flight
* returns ``{"phase": "mmap", ...}`` while the subprocess is alive
but ``_healthy`` is False
* returns ``{"phase": "ready", ...}`` once ``_healthy`` flips
* ``bytes_total`` is derived from the resolved on-disk path
(which the paired fix assigns to ``self._gguf_path`` on both the
local-GGUF and HF-download code paths)
* ``bytes_loaded`` is VmRSS in bytes, capped by total, rounded
* ``fraction`` is clamped to 0..1 and rounded to 4 decimal places
Linux-only via ``/proc``. On platforms without ``/proc`` the method
returns ``None`` instead of raising.
Cross-platform test: skips cleanly on macOS / Windows if ``/proc`` is
not available.
"""
from __future__ import annotations

import os
import sys
import tempfile
import types as _types
from pathlib import Path
from unittest.mock import patch

import pytest

# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------

# Make the backend package importable regardless of pytest's cwd.
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
    sys.path.insert(0, _BACKEND_DIR)

# ``loggers`` stub: route get_logger() straight to stdlib logging.
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)

# ``structlog`` stub: a bare module object; nothing on it is touched
# by these tests.
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)

# ``httpx`` stub: provide the handful of names the backend presumably
# references at import time (exception types, Timeout, Client). None of
# them are exercised here -- confirm the set if the backend's httpx
# usage grows.
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
    "ConnectError",
    "TimeoutException",
    "ReadTimeout",
    "ReadError",
    "RemoteProtocolError",
    "CloseError",
):
    setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))


class _FakeTimeout:
    # Accepts any arguments and does nothing, mirroring httpx.Timeout's
    # constructor shape.
    def __init__(self, *a, **kw):
        pass


_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
    "Client",
    (),
    {
        "__init__": lambda self, **kw: None,
        "__enter__": lambda self: self,
        "__exit__": lambda self, *a: None,
    },
)
sys.modules.setdefault("httpx", _httpx_stub)

# Import only after every stub above is registered.
from core.inference.llama_cpp import LlamaCppBackend  # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_instance():
    """Bare backend carrying only the attributes load_progress() reads."""
    backend = LlamaCppBackend.__new__(LlamaCppBackend)
    backend._process = None
    backend._gguf_path = None
    backend._healthy = False
    return backend
class _FakeProc:
    """Minimal stand-in for subprocess.Popen that just carries a pid."""

    def __init__(self, pid: int):
        # load_progress() only ever reads .pid from the process object.
        self.pid = pid
def _write_sparse_file(path: Path, size_bytes: int) -> None:
    """Create a sparse file of ``size_bytes`` without allocating blocks."""
    with open(path, "wb") as handle:
        # truncate() extends the file to the target length; the hole is
        # never backed by real disk blocks, so "80 GB" fixtures are free.
        if size_bytes > 0:
            handle.truncate(size_bytes)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestLoadProgressEmptyStates:
    """Without a subprocess (or a pid) there is no load to observe."""

    def test_returns_none_when_no_process(self):
        backend = _make_instance()
        assert backend.load_progress() is None

    def test_returns_none_when_process_has_no_pid(self):
        backend = _make_instance()
        backend._process = _FakeProc(pid = None)  # type: ignore[arg-type]
        assert backend.load_progress() is None
class TestLoadProgressSingleShard:
    # Happy-path contract for a single-file GGUF: phase tracks
    # ``_healthy`` and fraction tracks VmRSS / file size.

    def test_mmap_phase_for_alive_but_unhealthy(self, tmp_path):
        """VmRSS below total -> phase='mmap', fraction reflects progress."""
        gguf = tmp_path / "model.gguf"
        _write_sparse_file(gguf, 40 * 1024**3)  # 40 GB
        inst = _make_instance()
        inst._process = _FakeProc(pid = os.getpid())  # use our own pid
        inst._gguf_path = str(gguf)
        inst._healthy = False

        # Patch /proc read to claim 10 GB RSS.
        def fake_open(path, *args, **kwargs):
            if str(path).startswith("/proc/"):
                import io

                return io.StringIO(f"Name:\ttest\nVmRSS:\t{10 * 1024 ** 2}\tkB\n")
            # NOTE(review): this fallthrough calls the *patched* builtins.open
            # and would recurse; it is never hit today because load_progress()
            # only open()s /proc paths. Worth revisiting if that changes.
            return open(path, *args, **kwargs)  # fall through

        with patch("builtins.open", side_effect = fake_open):
            out = inst.load_progress()
        assert out is not None
        assert out["phase"] == "mmap"
        assert out["bytes_total"] == 40 * 1024**3
        assert out["bytes_loaded"] == 10 * 1024**3
        assert 0.24 < out["fraction"] < 0.26  # ~25%

    def test_ready_phase_when_healthy(self, tmp_path):
        # Healthy flag flips the phase; 8 GB RSS == 8 GB total -> 1.0.
        gguf = tmp_path / "model.gguf"
        _write_sparse_file(gguf, 8 * 1024**3)
        inst = _make_instance()
        inst._process = _FakeProc(pid = os.getpid())
        inst._gguf_path = str(gguf)
        inst._healthy = True

        def fake_open(path, *args, **kwargs):
            if str(path).startswith("/proc/"):
                import io

                return io.StringIO(f"VmRSS:\t{8 * 1024 ** 2}\tkB\n")
            return open(path, *args, **kwargs)

        with patch("builtins.open", side_effect = fake_open):
            out = inst.load_progress()
        assert out is not None
        assert out["phase"] == "ready"
        assert out["bytes_total"] == 8 * 1024**3
        assert out["bytes_loaded"] == 8 * 1024**3
        assert out["fraction"] == 1.0
class TestLoadProgressMultiShard:
    """Shard-aware total: for ``*-00001-of-00004.gguf`` primaries the
    method sums sibling files with the same prefix."""

    def test_sharded_total_aggregates_siblings(self, tmp_path):
        # Four 20 GB shards with the canonical -NNNNN-of-NNNNN suffix.
        for i in range(1, 5):
            _write_sparse_file(
                tmp_path / f"model-{i:05d}-of-00004.gguf",
                size_bytes = 20 * 1024**3,
            )
        # Drop an unrelated .gguf in the same folder -- must not be counted.
        _write_sparse_file(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3)
        inst = _make_instance()
        inst._process = _FakeProc(pid = os.getpid())
        inst._gguf_path = str(tmp_path / "model-00001-of-00004.gguf")
        inst._healthy = False

        # RSS is irrelevant to this test; report zero so only the
        # total-size aggregation is under scrutiny.
        def fake_open(path, *args, **kwargs):
            if str(path).startswith("/proc/"):
                import io

                return io.StringIO("VmRSS:\t0\tkB\n")
            return open(path, *args, **kwargs)

        with patch("builtins.open", side_effect = fake_open):
            out = inst.load_progress()
        assert out is not None
        assert out["bytes_total"] == 80 * 1024**3  # 4 x 20 GB, no mmproj
class TestLoadProgressDegradation:
    """Broken / unusual inputs never raise; they produce best-effort output."""

    def test_missing_gguf_path_still_reports_rss(self, tmp_path):
        # No gguf_path: total and fraction stay 0, but RSS is still
        # reported so the UI has *something* to show.
        inst = _make_instance()
        inst._process = _FakeProc(pid = os.getpid())
        inst._gguf_path = None
        inst._healthy = False

        def fake_open(path, *args, **kwargs):
            if str(path).startswith("/proc/"):
                import io

                return io.StringIO("VmRSS:\t1024\tkB\n")
            return open(path, *args, **kwargs)

        with patch("builtins.open", side_effect = fake_open):
            out = inst.load_progress()
        assert out is not None
        assert out["phase"] == "mmap"
        assert out["bytes_total"] == 0
        assert out["bytes_loaded"] == 1024 * 1024
        assert out["fraction"] == 0.0

    def test_unreadable_proc_returns_none(self, tmp_path):
        inst = _make_instance()
        # Pid that doesn't exist -> /proc read fails.
        inst._process = _FakeProc(pid = 999_999_999)
        inst._gguf_path = str(tmp_path / "model.gguf")  # doesn't need to exist
        inst._healthy = False
        out = inst.load_progress()
        # FileNotFoundError on /proc path -> load_progress returns None.
        assert out is None

View file

@ -145,6 +145,31 @@ export async function getDatasetDownloadProgress(
return parseJsonOrThrow(response);
}
export type ModelLoadPhase = "mmap" | "ready" | null;

export interface LoadProgressResponse {
  /**
   * ``"mmap"`` while the llama-server subprocess is paging weight
   * shards into RAM, ``"ready"`` once it has reported healthy, and
   * ``null`` when no load is in flight.
   */
  phase: ModelLoadPhase;
  bytes_loaded: number;
  bytes_total: number;
  fraction: number;
}

/**
 * Fetch the active GGUF load's mmap/upload progress. Complements the
 * download-progress endpoints for the window between "download
 * complete" and "chat ready" -- several minutes of otherwise-opaque
 * spinning on large MoE models.
 */
export async function getLoadProgress(): Promise<LoadProgressResponse> {
  const res = await authFetch("/api/inference/load-progress");
  return parseJsonOrThrow(res);
}
export interface LocalModelInfo {
id: string;
display_name: string;

View file

@ -8,12 +8,14 @@ import {
getDownloadProgress,
getGgufDownloadProgress,
getInferenceStatus,
getLoadProgress,
listLoras,
listModels,
loadModel,
unloadModel,
validateModel,
} from "../api/chat-api";
import { formatEta, formatRate } from "../utils/format-transfer";
import { useChatRuntimeStore } from "../stores/chat-runtime-store";
import type { InferenceStatusResponse, LoadModelResponse } from "../types/api";
import type {
@ -565,77 +567,147 @@ export function useChatModelRuntime() {
);
loadToastIdRef.current = toastId;
// Poll download progress for non-cached models (GGUF and non-GGUF)
// Poll download progress for non-cached models (GGUF and non-GGUF).
// Then, once the download wraps (or for already-cached models),
// poll the llama-server mmap phase so "Starting model..." no
// longer looks frozen for several minutes on large MoE models.
let progressInterval: ReturnType<typeof setInterval> | null = null;
if (!isDownloaded && !isCachedLora) {
const expectedBytes =
typeof selection !== "string" ? selection.expectedBytes ?? 0 : 0;
let hasShownProgress = false;
const expectedBytes =
typeof selection !== "string" ? selection.expectedBytes ?? 0 : 0;
const pollProgress = async () => {
if (abortCtrl.signal.aborted || !loadingModelRef.current) {
if (progressInterval) clearInterval(progressInterval);
return;
}
try {
const prog = ggufVariant && expectedBytes > 0
// Rolling window of byte samples for rate / ETA estimation.
// Shared across download + mmap phases so the estimator doesn't
// reset when the phase flips.
type Sample = { t: number; b: number };
const MIN_SAMPLES = 3;
const MIN_WINDOW = 3_000; // ms
const MAX_WINDOW = 15_000; // ms
const dlSamples: Sample[] = [];
const mmapSamples: Sample[] = [];

function estimate(
  samples: Sample[],
  bytes: number,
  total: number,
): { rate: number; eta: number; stable: boolean } {
  const now = Date.now();
  // A backwards step means the counter reset (e.g. the phase flipped):
  // discard the stale window before recording the new sample.
  if (samples.length > 0 && bytes < samples[samples.length - 1].b) {
    samples.length = 0;
  }
  samples.push({ t: now, b: bytes });
  // Age out samples beyond the rolling window, always keeping two
  // points so a slope can still be computed.
  const horizon = now - MAX_WINDOW;
  while (samples.length > 2 && samples[0].t < horizon) {
    samples.shift();
  }
  if (samples.length < MIN_SAMPLES) {
    return { rate: 0, eta: 0, stable: false };
  }
  const head = samples[0];
  const tail = samples[samples.length - 1];
  const elapsedSec = (tail.t - head.t) / 1000;
  const deltaBytes = tail.b - head.b;
  if (elapsedSec * 1000 < MIN_WINDOW || deltaBytes <= 0) {
    return { rate: 0, eta: 0, stable: false };
  }
  const rate = deltaBytes / elapsedSec;
  const eta =
    total > 0 && bytes < total && rate > 0 ? (total - bytes) / rate : 0;
  return { rate, eta, stable: true };
}
// Build the toast label: "15.4 of 140.8 GB" when the total is known,
// "15.4 GB downloaded" otherwise, with rate/ETA appended once the
// estimator has stabilized.
function composeProgressLabel(
  dlGb: number,
  totalGb: number,
  bytes: number,
  total: number,
  samples: Sample[],
): string {
  const base =
    totalGb > 0
      ? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB`
      : `${dlGb.toFixed(1)} GB downloaded`;
  const est = estimate(samples, bytes, total);
  if (!est.stable) return base;
  const rateStr = formatRate(est.rate);
  const etaStr = total > 0 ? formatEta(est.eta) : "";
  // NOTE(review): base/rate/eta are concatenated with no separator in
  // these template strings; if a "\u2022" or ", " was intended it may have
  // been lost in transit -- confirm against the rendered toast.
  return etaStr && etaStr !== "--"
    ? `${base}${rateStr}${etaStr} left`
    : `${base}${rateStr}`;
}
let downloadComplete = isDownloaded || isCachedLora;
const pollDownload = async () => {
if (abortCtrl.signal.aborted || !loadingModelRef.current) {
if (progressInterval) clearInterval(progressInterval);
return;
}
try {
const prog =
ggufVariant && expectedBytes > 0
? await getGgufDownloadProgress(modelId, ggufVariant, expectedBytes)
: await getDownloadProgress(modelId);
if (!loadingModelRef.current) return;
if (!loadingModelRef.current) return;
if (prog.progress > 0 && prog.progress < 1) {
hasShownProgress = true;
const dlGb = prog.downloaded_bytes / (1024 ** 3);
const totalGb = prog.expected_bytes / (1024 ** 3);
const pct = Math.round(prog.progress * 100);
const progressLabel = totalGb > 0
? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB`
: `${dlGb.toFixed(1)} GB downloaded`;
setLoadProgress({
percent: pct,
label: progressLabel,
phase: "downloading",
});
if (loadToastDismissedRef.current) return;
toast(
null,
{
id: toastId,
description: renderLoadDescription(
"Downloading model…",
loadingDescription,
pct,
progressLabel,
cancelLoading,
),
duration: Infinity,
closeButton: false,
classNames: MODEL_LOAD_TOAST_CLASSNAMES,
onDismiss: (dismissedToast) => {
if (loadToastIdRef.current !== dismissedToast.id) return;
setLoadToastDismissedState(true);
},
},
);
} else if (prog.downloaded_bytes > 0 && prog.expected_bytes === 0 && prog.progress === 0) {
hasShownProgress = true;
const dlGb = prog.downloaded_bytes / (1024 ** 3);
setLoadProgress({
percent: null,
label: `${dlGb.toFixed(1)} GB downloaded`,
phase: "downloading",
});
} else if (prog.progress >= 1 && hasShownProgress) {
setLoadProgress({
percent: 100,
label: "Download complete",
phase: "starting",
});
if (loadToastDismissedRef.current) {
if (progressInterval) clearInterval(progressInterval);
return;
}
if (prog.progress > 0 && prog.progress < 1) {
hasShownProgress = true;
const dlGb = prog.downloaded_bytes / (1024 ** 3);
const totalGb = prog.expected_bytes / (1024 ** 3);
const pct = Math.round(prog.progress * 100);
const progressLabel = composeProgressLabel(
dlGb,
totalGb,
prog.downloaded_bytes,
prog.expected_bytes,
dlSamples,
);
setLoadProgress({
percent: pct,
label: progressLabel,
phase: "downloading",
});
if (loadToastDismissedRef.current) return;
toast(null, {
id: toastId,
description: renderLoadDescription(
"Downloading model…",
loadingDescription,
pct,
progressLabel,
cancelLoading,
),
duration: Infinity,
closeButton: false,
classNames: MODEL_LOAD_TOAST_CLASSNAMES,
onDismiss: (dismissedToast) => {
if (loadToastIdRef.current !== dismissedToast.id) return;
setLoadToastDismissedState(true);
},
});
} else if (
prog.downloaded_bytes > 0 &&
prog.expected_bytes === 0 &&
prog.progress === 0
) {
hasShownProgress = true;
const dlGb = prog.downloaded_bytes / (1024 ** 3);
const est = estimate(dlSamples, prog.downloaded_bytes, 0);
const rateSuffix =
est.stable ? `${formatRate(est.rate)}` : "";
setLoadProgress({
percent: null,
label: `${dlGb.toFixed(1)} GB downloaded${rateSuffix}`,
phase: "downloading",
});
} else if (prog.progress >= 1 && hasShownProgress) {
downloadComplete = true;
setLoadProgress({
percent: 100,
label: "Download complete",
phase: "starting",
});
if (!loadToastDismissedRef.current) {
toast(null, {
id: toastId,
description: renderLoadDescription(
@ -653,16 +725,79 @@ export function useChatModelRuntime() {
setLoadToastDismissedState(true);
},
});
if (progressInterval) clearInterval(progressInterval);
}
} catch {
// Ignore polling errors
// Keep polling: the mmap branch below takes over from here.
}
};
} catch {
// Ignore polling errors; keep polling.
}
};
setTimeout(pollProgress, 500);
progressInterval = setInterval(pollProgress, 2000);
}
// Poll the backend's load-progress endpoint while llama-server pages
// weights into memory, updating both the inline progress state and the
// load toast. Takes over from pollDownload once the download wraps.
const pollLoad = async () => {
  // Stop polling when the user cancelled or the load finished elsewhere.
  if (abortCtrl.signal.aborted || !loadingModelRef.current) {
    if (progressInterval) clearInterval(progressInterval);
    return;
  }
  try {
    const prog = await getLoadProgress();
    if (!loadingModelRef.current) return;
    if (!prog || prog.phase == null) return;
    if (prog.phase === "ready") {
      // Loaded. The chat flow will flip loadingModelRef shortly;
      // just stop polling.
      if (progressInterval) clearInterval(progressInterval);
      return;
    }
    if (prog.bytes_total <= 0) return; // nothing useful to render
    const loadedGb = prog.bytes_loaded / (1024 ** 3);
    const totalGb = prog.bytes_total / (1024 ** 3);
    // Cap at 99 so the bar never reads "done" before healthy.
    const pct = Math.min(99, Math.round(prog.fraction * 100));
    const est = estimate(mmapSamples, prog.bytes_loaded, prog.bytes_total);
    const base = `${loadedGb.toFixed(1)} of ${totalGb.toFixed(1)} GB in memory`;
    // NOTE(review): base/rate/eta are joined with no separator here; if a
    // "\u2022" or ", " was intended it may have been lost in transit.
    const label = est.stable
      ? `${base}${formatRate(est.rate)}${
          formatEta(est.eta) !== "--" ? `${formatEta(est.eta)} left` : ""
        }`
      : base;
    setLoadProgress({
      percent: pct,
      label,
      phase: "starting",
    });
    // Respect a user-dismissed toast; keep updating the inline state only.
    if (loadToastDismissedRef.current) return;
    toast(null, {
      id: toastId,
      description: renderLoadDescription(
        "Starting model…",
        "Paging weights into memory.",
        pct,
        label,
        cancelLoading,
      ),
      duration: Infinity,
      closeButton: false,
      classNames: MODEL_LOAD_TOAST_CLASSNAMES,
      onDismiss: (dismissedToast) => {
        if (loadToastIdRef.current !== dismissedToast.id) return;
        setLoadToastDismissedState(true);
      },
    });
  } catch {
    // Ignore polling errors.
  }
};
// Phase dispatcher: drive the download poller until the download
// wraps, then hand the same interval to the load poller.
const pollProgress = async () => {
  if (!downloadComplete) {
    await pollDownload();
  } else {
    await pollLoad();
  }
};
// Read by pollDownload; declared before the first tick can fire.
let hasShownProgress = false;
setTimeout(pollProgress, 500);
progressInterval = setInterval(pollProgress, 2000);
try {
await performLoad();

View file

@ -0,0 +1,98 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
/**
* Compute rate (bytes/sec) and ETA (seconds) from a time-series of
* cumulative ``bytes`` values, using a rolling window of recent samples.
*
* Shared between the chat-flow download toast, the training-start
* overlay, and the model-load phase UI. All three have the same shape:
* a counter that rises monotonically from 0 toward ``totalBytes``, polled
* on an interval. The derived stats are identical regardless of whether
* the bytes came from an HTTP download or an mmap page-in.
*
* Stability rule: ``stable`` stays ``false`` until we've observed at
* least 3 samples spanning 3 seconds. That keeps the UI from flashing
* wildly varying rates during the first tick or two when the denominator
* is effectively zero.
*/
import { useEffect, useRef, useState } from "react";
export type TransferStats = {
  /** Smoothed transfer rate over the rolling window, in bytes/sec. */
  rateBytesPerSecond: number;
  /** Seconds until the counter reaches totalBytes at the current rate; 0 when unknown. */
  etaSeconds: number;
  /**
   * False for the first few ticks (window not filled, or no forward
   * progress yet). Consumers should hide rate/ETA while unstable so the
   * UI doesn't flicker "123 GB/s" during the first tick.
   */
  stable: boolean;
};

// Stability thresholds: at least MIN_SAMPLES samples spanning
// MIN_WINDOW_SECONDS before rate/ETA are reported; samples older than
// MAX_WINDOW_SECONDS age out of the window.
const MIN_SAMPLES = 3;
const MIN_WINDOW_SECONDS = 3;
const MAX_WINDOW_SECONDS = 15;
export function useTransferStats(
bytes: number | null | undefined,
totalBytes: number | null | undefined,
): TransferStats {
const samplesRef = useRef<{ t: number; b: number }[]>([]);
const [state, setState] = useState<TransferStats>({
rateBytesPerSecond: 0,
etaSeconds: 0,
stable: false,
});
useEffect(() => {
const now = Date.now() / 1000;
const cur = typeof bytes === "number" && Number.isFinite(bytes) ? bytes : 0;
const total =
typeof totalBytes === "number" && Number.isFinite(totalBytes)
? totalBytes
: 0;
// If the counter resets (e.g. user unloaded and started a new
// download), drop the stale window.
const samples = samplesRef.current;
if (samples.length > 0 && cur < samples[samples.length - 1].b) {
samples.length = 0;
}
samples.push({ t: now, b: cur });
// Drop samples older than MAX_WINDOW_SECONDS; keep at least 2 so we
// can still compute a rate when the counter hasn't moved in a while.
const cutoff = now - MAX_WINDOW_SECONDS;
while (samples.length > 2 && samples[0].t < cutoff) {
samples.shift();
}
if (samples.length < MIN_SAMPLES) {
setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false });
return;
}
const first = samples[0];
const last = samples[samples.length - 1];
const dt = last.t - first.t;
const db = last.b - first.b;
if (dt < MIN_WINDOW_SECONDS || db <= 0) {
setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false });
return;
}
const rate = db / dt;
const eta =
total > 0 && cur < total && rate > 0 ? (total - cur) / rate : 0;
setState({
rateBytesPerSecond: rate,
etaSeconds: eta,
stable: true,
});
}, [bytes, totalBytes]);
return state;
}

View file

@ -0,0 +1,44 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
/**
* Format a byte-per-second rate as a human-readable string.
*
* 512 "512 B/s"
* 1_234_567 "1.2 MB/s"
* 1_234_567_890 "1.15 GB/s"
*
* Returns `"--"` for non-finite or non-positive inputs so callers can
* render the label safely before the first stable sample arrives.
*/
/**
 * Format a byte-per-second rate as a human-readable string.
 *
 *   512            -> "512 B/s"
 *   1_234_567      -> "1.2 MB/s"
 *   1_234_567_890  -> "1.15 GB/s"
 *
 * Returns `"--"` for non-finite or non-positive inputs so callers can
 * render the label safely before the first stable sample arrives.
 */
export function formatRate(bytesPerSecond: number): string {
  if (!Number.isFinite(bytesPerSecond) || bytesPerSecond <= 0) return "--";
  const KIB = 1024;
  if (bytesPerSecond < KIB) return `${bytesPerSecond.toFixed(0)} B/s`;
  if (bytesPerSecond < KIB * KIB)
    return `${(bytesPerSecond / KIB).toFixed(1)} KB/s`;
  if (bytesPerSecond < KIB * KIB * KIB)
    return `${(bytesPerSecond / (KIB * KIB)).toFixed(1)} MB/s`;
  return `${(bytesPerSecond / (KIB * KIB * KIB)).toFixed(2)} GB/s`;
}
/**
* Format an ETA (in seconds) as a short human-readable string.
*
* 47 "47s"
* 125 "2m 5s"
* 3725 "1h 2m"
*
* Returns `"--"` for non-finite or non-positive inputs.
*/
/**
 * Format an ETA (in seconds) as a short human-readable string.
 *
 *   47    -> "47s"
 *   125   -> "2m 5s"
 *   3725  -> "1h 2m"
 *
 * Returns `"--"` for non-finite or non-positive inputs.
 */
export function formatEta(seconds: number): string {
  if (!Number.isFinite(seconds) || seconds <= 0) return "--";
  const total = Math.round(seconds);
  const hours = Math.floor(total / 3600);
  const minutes = Math.floor((total % 3600) / 60);
  const secs = total % 60;
  // Two largest nonzero units only; trailing zero components are dropped.
  if (hours > 0) return minutes > 0 ? `${hours}h ${minutes}m` : `${hours}h`;
  if (minutes > 0) return secs > 0 ? `${minutes}m ${secs}s` : `${minutes}m`;
  return `${secs}s`;
}

View file

@ -23,6 +23,8 @@ import {
getDownloadProgress,
type DownloadProgressResponse,
} from "@/features/chat/api/chat-api";
import { useTransferStats } from "@/features/chat/hooks/use-transfer-stats";
import { formatEta, formatRate } from "@/features/chat/utils/format-transfer";
import {
useTrainingActions,
useTrainingConfigStore,
@ -151,6 +153,11 @@ type DownloadRowProps = {
};
function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
// Compute a rolling-window rate + ETA from the same cumulative-byte
// series the poll hook already produces, so we can show
// "5.2 / 20.7 GB • 85.3 MB/s • 3m 12s left" instead of just the pair.
const stats = useTransferStats(state.downloadedBytes, state.totalBytes);
if (state.downloadedBytes <= 0 && !state.cachePath) return null;
const isComplete = state.totalBytes > 0 && state.percent >= 100;
const statusLabel = isComplete
@ -160,11 +167,16 @@ function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
: state.downloadedBytes === 0
? "Preparing"
: null;
const showRate = stats.stable && !isComplete;
const rateSuffix = showRate ? `${formatRate(stats.rateBytesPerSecond)}` : "";
const etaStr =
showRate && state.totalBytes > 0 ? formatEta(stats.etaSeconds) : "--";
const etaSuffix = etaStr !== "--" ? `${etaStr} left` : "";
const sizeLabel =
state.totalBytes > 0
? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}`
? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}${rateSuffix}${etaSuffix}`
: state.downloadedBytes > 0
? `${formatBytes(state.downloadedBytes)} downloaded`
? `${formatBytes(state.downloadedBytes)} downloaded${rateSuffix}`
: null;
return (
<div className="flex flex-col gap-1.5 rounded-md border border-border/50 bg-muted/20 px-3 py-2">