mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Studio: live model-load progress + rate/ETA on download and load (#5017)
* Studio: live model-load progress + rate/ETA on download and load Two UX fixes for the opaque multi-minute wait between clicking Load and being able to chat, visible most clearly on large MoE GGUFs like MiniMax-M2.7 (131 GB of weights on a 97 GB GPU): 1. **Model-load phase is now observable.** The existing chat flow transitions the toast to "Starting model..." as soon as the download hits 100%, then shows a spinner with no other feedback until llama-server reports healthy. For a 130 GB model that spinner freezes for five-plus minutes while the kernel pages shards into the page cache. A new `GET /api/inference/load-progress` endpoint samples `/proc/<pid>/status VmRSS` on the llama-server subprocess against the sum of shard file sizes on disk, so the UI can render a real bar plus rate / ETA during that window. 2. **Rate and ETA on downloads and loads.** Both the chat toast and the training-start overlay used to show a static pair of numbers (for example "15.4 of 140.8 GB"). A rolling 15-second window over the existing byte-series now surfaces "85.3 MB/s, 24m 23s left" beside that pair. The estimator is shared between the download and load phases so the numbers don't reset when the phase flips. Also fixes a pre-existing assignment bug uncovered while wiring this up: `load_model` was storing the caller's `gguf_path` kwarg into `self._gguf_path`, which is `None` on the HF-download code path. The resolved on-disk path (`model_path`) is what llama-server actually mmaps; downstream consumers need that. No existing reader used `_gguf_path`, so this is a correctness fix for the new endpoint. - Backend: `LlamaCppBackend.load_progress()`, `GET /api/inference/load-progress`, `LoadProgressResponse` Pydantic model. - Frontend: `useTransferStats` hook, `formatRate` / `formatEta` helpers, `getLoadProgress` client, rewired chat toast and `DownloadRow` in the training overlay. 
- Tests: `studio/backend/tests/test_llama_cpp_load_progress.py` covers empty states, mmap phase, ready phase, sharded total aggregation, missing gguf_path, and unreadable /proc (7 cases). `tsc -b` and `vite build` on the frontend both clean. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
514bb3a20e
commit
bb14ab144a
9 changed files with 805 additions and 76 deletions
|
|
@ -185,6 +185,96 @@ class LlamaCppBackend:
|
|||
"""Return the model's native context length from GGUF metadata."""
|
||||
return self._context_length
|
||||
|
||||
def load_progress(self) -> Optional[dict]:
    """Return live model-load progress, or ``None`` when no load is in flight.

    While llama-server is warming up, its process is typically in kernel
    state D (disk sleep), mmap'ing the weight shards into the page cache
    before pushing layers to VRAM. During that window
    ``/api/inference/status`` only reports ``loading``, which gives the
    UI nothing to display besides a spinner that looks stuck for minutes
    on large MoE models.

    This method samples ``VmRSS`` from ``/proc/<pid>/status`` (what the
    kernel has actually paged in) against the sum of the GGUF shard
    sizes on disk so the UI can render a real bar and compute rate/ETA.

    Returns:
        ``None`` when there is no subprocess, or when
        ``/proc/<pid>/status`` cannot be read (process exited, or a
        platform without procfs such as macOS/Windows). Otherwise::

            {
                "phase": "mmap" | "ready",  # "ready" once healthy
                "bytes_loaded": int,        # VmRSS, capped at bytes_total
                "bytes_total": int,         # sum of shard file sizes
                "fraction": float,          # 0..1, rounded to 4 places
            }
    """
    proc = self._process
    if proc is None:
        return None
    pid = proc.pid
    if pid is None:  # defensive: real Popen objects always carry a pid
        return None

    # Sum up shard sizes (primary + any extras sitting alongside).
    bytes_total = 0
    gguf_path = self._gguf_path
    if gguf_path:
        primary = Path(gguf_path)
        try:
            if primary.is_file():
                bytes_total += primary.stat().st_size
        except OSError:
            pass
        # Extra shards live alongside the primary with the same prefix
        # before the shard index (e.g. ``-00001-of-00004.gguf``).
        try:
            match = _SHARD_RE.match(primary.name)
            prefix = match.group(1) if match else None
            if prefix and primary.parent.is_dir():
                for sibling in primary.parent.iterdir():
                    if (
                        sibling.name != primary.name
                        and sibling.suffix == ".gguf"
                        and sibling.name.startswith(prefix)
                        and sibling.is_file()
                    ):
                        try:
                            bytes_total += sibling.stat().st_size
                        except OSError:
                            pass
        except OSError:
            pass

    # Read VmRSS from /proc/<pid>/status; the kernel reports kilobytes.
    bytes_loaded = 0
    try:
        with open(f"/proc/{pid}/status", "r", encoding = "utf-8") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    bytes_loaded = int(line.split()[1]) * 1024
                    break
    except (ValueError, OSError):
        # FileNotFoundError / PermissionError are OSError subclasses:
        # dead pid, missing procfs (non-Linux), or a malformed VmRSS
        # line all mean we cannot report anything meaningful.
        return None

    # VmRSS includes llama-server's own (non-weight) allocations, so it
    # can overshoot the shard total; cap it so the UI never renders
    # e.g. "82.3 of 80.0 GB".
    if bytes_total > 0:
        bytes_loaded = min(bytes_loaded, bytes_total)

    phase = "ready" if self._healthy else "mmap"
    fraction = min(1.0, bytes_loaded / bytes_total) if bytes_total > 0 else 0.0
    return {
        "phase": phase,
        "bytes_loaded": bytes_loaded,
        "bytes_total": bytes_total,
        "fraction": round(fraction, 4),
    }
|
||||
|
||||
@property
|
||||
def chat_template(self) -> Optional[str]:
|
||||
return self._chat_template
|
||||
|
|
@ -1574,7 +1664,12 @@ class LlamaCppBackend:
|
|||
)
|
||||
self._stdout_thread.start()
|
||||
|
||||
self._gguf_path = gguf_path
|
||||
# Store the resolved on-disk path, not the caller's kwarg. In
|
||||
# HF mode the caller passes gguf_path=None and the real path
|
||||
# (``model_path``) is what llama-server is actually mmap'ing.
|
||||
# Downstream consumers (load_progress, log lines, etc.) need
|
||||
# the path that exists on disk.
|
||||
self._gguf_path = model_path
|
||||
self._hf_repo = hf_repo
|
||||
# For local GGUF files, extract variant from filename if not provided
|
||||
if hf_variant:
|
||||
|
|
|
|||
|
|
@ -188,6 +188,39 @@ class UnloadResponse(BaseModel):
|
|||
model: str = Field(..., description = "Model identifier that was unloaded")
|
||||
|
||||
|
||||
class LoadProgressResponse(BaseModel):
    """Snapshot of the active GGUF load, taken when the endpoint is hit.

    Lets the UI draw a real progress bar during the post-download warmup
    window (mmap + CUDA upload) instead of the generic "Starting
    model..." spinner that looks frozen for minutes on large MoE models.
    """

    # Every field defaults to the "no load in flight" empty payload.
    phase: Optional[str] = Field(
        default = None,
        description = (
            "Load phase: 'mmap' (weights paging into RAM via mmap), "
            "'ready' (llama-server reported healthy), or null when no "
            "load is in flight."
        ),
    )
    bytes_loaded: int = Field(
        default = 0,
        description = (
            "Bytes of the model already resident in the llama-server "
            "process (VmRSS on Linux)."
        ),
    )
    bytes_total: int = Field(
        default = 0,
        description = "Total bytes across all GGUF shards for the active model.",
    )
    fraction: float = Field(
        default = 0.0,
        description = "bytes_loaded / bytes_total, clamped to 0..1.",
    )
|
||||
|
||||
|
||||
class InferenceStatusResponse(BaseModel):
|
||||
"""Current inference backend status"""
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ from models.inference import (
|
|||
UnloadRequest,
|
||||
GenerateRequest,
|
||||
LoadResponse,
|
||||
LoadProgressResponse,
|
||||
UnloadResponse,
|
||||
InferenceStatusResponse,
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -751,6 +752,34 @@ async def get_status(
|
|||
raise HTTPException(status_code = 500, detail = f"Failed to get status: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/load-progress", response_model = LoadProgressResponse)
async def get_load_progress(
    current_subject: str = Depends(get_current_subject),
):
    """
    Report how far along the active GGUF load is (mmap / upload phase).

    After a GGUF download completes, llama-server may spend minutes
    paging tens-to-hundreds of GB of shards into the page cache before
    pushing layers to VRAM; ``/api/inference/status`` only shows a
    generic spinner for that window. This endpoint exposes sampled
    progress so the UI can render a real bar plus rate/ETA.

    When no load is in flight the payload is empty (``phase=null``,
    byte counters at 0). Clients should stop polling once ``phase``
    reaches ``ready``. Any sampling error degrades to the same empty
    payload rather than surfacing a 5xx.
    """
    try:
        snapshot = get_llama_cpp_backend().load_progress()
        if snapshot is None:
            return LoadProgressResponse()
        return LoadProgressResponse(**snapshot)
    except Exception as e:
        logger.warning(f"Error sampling load progress: {e}")
        return LoadProgressResponse()
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Audio (TTS) Generation (/audio/generate)
|
||||
# =====================================================================
|
||||
|
|
|
|||
258
studio/backend/tests/test_llama_cpp_load_progress.py
Normal file
258
studio/backend/tests/test_llama_cpp_load_progress.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
"""Tests for ``LlamaCppBackend.load_progress()``.
|
||||
|
||||
The chat settings flow and the training overlay both show a generic
|
||||
"Starting model..." spinner during the window after a GGUF download
|
||||
finishes and before llama-server reports healthy. For small models
|
||||
that window is a second or two and nobody notices. For large MoE GGUFs
|
||||
(MiniMax-M2.7, Qwen3.5-397B-A17B, etc.) the llama-server process spends
|
||||
minutes in kernel state D, paging tens or hundreds of GB of shards
|
||||
into the page cache. The UI has no way to show a real progress bar,
|
||||
rate, or ETA during that window.
|
||||
|
||||
``load_progress()`` samples ``/proc/<pid>/status VmRSS`` (what the
|
||||
kernel has actually paged in) against the total shard file size on
|
||||
disk, so the frontend can render a real bar plus rate/ETA. This
|
||||
module pins that contract:
|
||||
|
||||
* returns ``None`` when no load is in flight
|
||||
* returns ``{"phase": "mmap", ...}`` while the subprocess is alive
|
||||
but ``_healthy`` is False
|
||||
* returns ``{"phase": "ready", ...}`` once ``_healthy`` flips
|
||||
* ``bytes_total`` is derived from the resolved on-disk path
|
||||
(which the paired fix assigns to ``self._gguf_path`` on both the
|
||||
local-GGUF and HF-download code paths)
|
||||
* ``bytes_loaded`` is VmRSS in bytes, capped by total, rounded
|
||||
* ``fraction`` is clamped to 0..1 and rounded to 4 decimal places
|
||||
|
||||
Linux-only via ``/proc``. On platforms without ``/proc`` the method
|
||||
returns ``None`` instead of raising.
|
||||
Cross-platform test: skips cleanly on macOS / Windows if ``/proc`` is
|
||||
not available.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types as _types
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stub heavy / unavailable external dependencies before importing the
|
||||
# module under test. Same pattern as test_kv_cache_estimation.py.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
|
||||
if _BACKEND_DIR not in sys.path:
|
||||
sys.path.insert(0, _BACKEND_DIR)
|
||||
|
||||
_loggers_stub = _types.ModuleType("loggers")
|
||||
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
|
||||
sys.modules.setdefault("loggers", _loggers_stub)
|
||||
|
||||
_structlog_stub = _types.ModuleType("structlog")
|
||||
sys.modules.setdefault("structlog", _structlog_stub)
|
||||
|
||||
_httpx_stub = _types.ModuleType("httpx")
|
||||
for _exc_name in (
|
||||
"ConnectError",
|
||||
"TimeoutException",
|
||||
"ReadTimeout",
|
||||
"ReadError",
|
||||
"RemoteProtocolError",
|
||||
"CloseError",
|
||||
):
|
||||
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
|
||||
|
||||
|
||||
class _FakeTimeout:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
|
||||
_httpx_stub.Timeout = _FakeTimeout
|
||||
_httpx_stub.Client = type(
|
||||
"Client",
|
||||
(),
|
||||
{
|
||||
"__init__": lambda self, **kw: None,
|
||||
"__enter__": lambda self: self,
|
||||
"__exit__": lambda self, *a: None,
|
||||
},
|
||||
)
|
||||
sys.modules.setdefault("httpx", _httpx_stub)
|
||||
|
||||
from core.inference.llama_cpp import LlamaCppBackend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_instance():
|
||||
inst = LlamaCppBackend.__new__(LlamaCppBackend)
|
||||
inst._process = None
|
||||
inst._gguf_path = None
|
||||
inst._healthy = False
|
||||
return inst
|
||||
|
||||
|
||||
class _FakeProc:
|
||||
"""Minimal stand-in for subprocess.Popen that just carries a pid."""
|
||||
|
||||
def __init__(self, pid: int):
|
||||
self.pid = pid
|
||||
|
||||
|
||||
def _write_sparse_file(path: Path, size_bytes: int) -> None:
|
||||
"""Create a sparse file of the given size without allocating blocks."""
|
||||
with open(path, "wb") as fh:
|
||||
if size_bytes > 0:
|
||||
fh.truncate(size_bytes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadProgressEmptyStates:
    """No subprocess (or a pid-less one) means "no load in flight"."""

    def test_returns_none_when_no_process(self):
        backend = _make_instance()
        assert backend.load_progress() is None

    def test_returns_none_when_process_has_no_pid(self):
        backend = _make_instance()
        backend._process = _FakeProc(pid = None)  # type: ignore[arg-type]
        assert backend.load_progress() is None
|
||||
|
||||
|
||||
class TestLoadProgressSingleShard:
    """Single-file models: total == primary size, phase tracks _healthy."""

    def test_mmap_phase_for_alive_but_unhealthy(self, tmp_path):
        """VmRSS below total -> phase='mmap', fraction reflects progress."""
        gguf = tmp_path / "model.gguf"
        _write_sparse_file(gguf, 40 * 1024**3)  # 40 GB

        backend = _make_instance()
        backend._process = _FakeProc(pid = os.getpid())  # our own pid
        backend._gguf_path = str(gguf)
        backend._healthy = False

        # Fake the /proc read to claim 10 GB RSS (VmRSS is in kB).
        def fake_open(path, *args, **kwargs):
            if not str(path).startswith("/proc/"):
                return open(path, *args, **kwargs)  # fall through
            import io

            return io.StringIO(f"Name:\ttest\nVmRSS:\t{10 * 1024 ** 2}\tkB\n")

        with patch("builtins.open", side_effect = fake_open):
            result = backend.load_progress()

        assert result is not None
        assert result["phase"] == "mmap"
        assert result["bytes_total"] == 40 * 1024**3
        assert result["bytes_loaded"] == 10 * 1024**3
        assert 0.24 < result["fraction"] < 0.26  # ~25%

    def test_ready_phase_when_healthy(self, tmp_path):
        gguf = tmp_path / "model.gguf"
        _write_sparse_file(gguf, 8 * 1024**3)

        backend = _make_instance()
        backend._process = _FakeProc(pid = os.getpid())
        backend._gguf_path = str(gguf)
        backend._healthy = True

        # RSS equal to the shard total -> fully paged in.
        def fake_open(path, *args, **kwargs):
            if not str(path).startswith("/proc/"):
                return open(path, *args, **kwargs)
            import io

            return io.StringIO(f"VmRSS:\t{8 * 1024 ** 2}\tkB\n")

        with patch("builtins.open", side_effect = fake_open):
            result = backend.load_progress()

        assert result is not None
        assert result["phase"] == "ready"
        assert result["bytes_total"] == 8 * 1024**3
        assert result["bytes_loaded"] == 8 * 1024**3
        assert result["fraction"] == 1.0
|
||||
|
||||
|
||||
class TestLoadProgressMultiShard:
    """Shard-aware total: for ``*-00001-of-00004.gguf`` primaries the
    method sums sibling files with the same prefix."""

    def test_sharded_total_aggregates_siblings(self, tmp_path):
        shard_size = 20 * 1024**3
        for idx in range(1, 5):
            _write_sparse_file(
                tmp_path / f"model-{idx:05d}-of-00004.gguf",
                size_bytes = shard_size,
            )
        # An unrelated .gguf in the same folder must not be counted.
        _write_sparse_file(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3)

        backend = _make_instance()
        backend._process = _FakeProc(pid = os.getpid())
        backend._gguf_path = str(tmp_path / "model-00001-of-00004.gguf")
        backend._healthy = False

        # RSS reading is irrelevant here; report zero.
        def fake_open(path, *args, **kwargs):
            if not str(path).startswith("/proc/"):
                return open(path, *args, **kwargs)
            import io

            return io.StringIO("VmRSS:\t0\tkB\n")

        with patch("builtins.open", side_effect = fake_open):
            result = backend.load_progress()

        assert result is not None
        assert result["bytes_total"] == 4 * shard_size  # 4 x 20 GB, no mmproj
|
||||
|
||||
|
||||
class TestLoadProgressDegradation:
    """Broken / unusual inputs never raise; they produce best-effort output."""

    def test_missing_gguf_path_still_reports_rss(self, tmp_path):
        backend = _make_instance()
        backend._process = _FakeProc(pid = os.getpid())
        backend._gguf_path = None  # nothing to size on disk
        backend._healthy = False

        def fake_open(path, *args, **kwargs):
            if not str(path).startswith("/proc/"):
                return open(path, *args, **kwargs)
            import io

            return io.StringIO("VmRSS:\t1024\tkB\n")

        with patch("builtins.open", side_effect = fake_open):
            result = backend.load_progress()

        assert result is not None
        assert result["phase"] == "mmap"
        assert result["bytes_total"] == 0
        assert result["bytes_loaded"] == 1024 * 1024
        assert result["fraction"] == 0.0

    def test_unreadable_proc_returns_none(self, tmp_path):
        backend = _make_instance()
        # Pid that doesn't exist -> /proc read fails.
        backend._process = _FakeProc(pid = 999_999_999)
        backend._gguf_path = str(tmp_path / "model.gguf")  # need not exist
        backend._healthy = False

        # FileNotFoundError on the /proc path -> load_progress returns None.
        assert backend.load_progress() is None
|
||||
|
|
@ -145,6 +145,31 @@ export async function getDatasetDownloadProgress(
|
|||
return parseJsonOrThrow(response);
|
||||
}
|
||||
|
||||
export type ModelLoadPhase = "mmap" | "ready" | null;
|
||||
|
||||
export interface LoadProgressResponse {
|
||||
/**
|
||||
* Load phase: ``"mmap"`` while the llama-server subprocess is paging
|
||||
* weight shards into RAM, ``"ready"`` once it has reported healthy,
|
||||
* or ``null`` when no load is in flight.
|
||||
*/
|
||||
phase: ModelLoadPhase;
|
||||
bytes_loaded: number;
|
||||
bytes_total: number;
|
||||
fraction: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the active GGUF load's mmap/upload progress. Complements
|
||||
* ``getDownloadProgress`` / ``getGgufDownloadProgress`` for the window
|
||||
* between "download complete" and "chat ready", which for large MoE
|
||||
* models can be several minutes of otherwise-opaque spinning.
|
||||
*/
|
||||
export async function getLoadProgress(): Promise<LoadProgressResponse> {
|
||||
const response = await authFetch(`/api/inference/load-progress`);
|
||||
return parseJsonOrThrow(response);
|
||||
}
|
||||
|
||||
export interface LocalModelInfo {
|
||||
id: string;
|
||||
display_name: string;
|
||||
|
|
|
|||
|
|
@ -8,12 +8,14 @@ import {
|
|||
getDownloadProgress,
|
||||
getGgufDownloadProgress,
|
||||
getInferenceStatus,
|
||||
getLoadProgress,
|
||||
listLoras,
|
||||
listModels,
|
||||
loadModel,
|
||||
unloadModel,
|
||||
validateModel,
|
||||
} from "../api/chat-api";
|
||||
import { formatEta, formatRate } from "../utils/format-transfer";
|
||||
import { useChatRuntimeStore } from "../stores/chat-runtime-store";
|
||||
import type { InferenceStatusResponse, LoadModelResponse } from "../types/api";
|
||||
import type {
|
||||
|
|
@ -565,77 +567,147 @@ export function useChatModelRuntime() {
|
|||
);
|
||||
loadToastIdRef.current = toastId;
|
||||
|
||||
// Poll download progress for non-cached models (GGUF and non-GGUF)
|
||||
// Poll download progress for non-cached models (GGUF and non-GGUF).
|
||||
// Then, once the download wraps (or for already-cached models),
|
||||
// poll the llama-server mmap phase so "Starting model..." no
|
||||
// longer looks frozen for several minutes on large MoE models.
|
||||
let progressInterval: ReturnType<typeof setInterval> | null = null;
|
||||
if (!isDownloaded && !isCachedLora) {
|
||||
const expectedBytes =
|
||||
typeof selection !== "string" ? selection.expectedBytes ?? 0 : 0;
|
||||
let hasShownProgress = false;
|
||||
const expectedBytes =
|
||||
typeof selection !== "string" ? selection.expectedBytes ?? 0 : 0;
|
||||
|
||||
const pollProgress = async () => {
|
||||
if (abortCtrl.signal.aborted || !loadingModelRef.current) {
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const prog = ggufVariant && expectedBytes > 0
|
||||
// Rolling window of byte samples for rate / ETA estimation.
|
||||
// Shared across download + mmap phases so the estimator doesn't
|
||||
// reset when the phase flips.
|
||||
type Sample = { t: number; b: number };
|
||||
const MIN_SAMPLES = 3;
|
||||
const MIN_WINDOW = 3_000; // ms
|
||||
const MAX_WINDOW = 15_000; // ms
|
||||
const dlSamples: Sample[] = [];
|
||||
const mmapSamples: Sample[] = [];
|
||||
|
||||
function estimate(
|
||||
samples: Sample[],
|
||||
bytes: number,
|
||||
total: number,
|
||||
): { rate: number; eta: number; stable: boolean } {
|
||||
const now = Date.now();
|
||||
// Drop samples if the counter reset (e.g. phase flipped).
|
||||
if (samples.length > 0 && bytes < samples[samples.length - 1].b) {
|
||||
samples.length = 0;
|
||||
}
|
||||
samples.push({ t: now, b: bytes });
|
||||
const cutoff = now - MAX_WINDOW;
|
||||
while (samples.length > 2 && samples[0].t < cutoff) {
|
||||
samples.shift();
|
||||
}
|
||||
if (samples.length < MIN_SAMPLES) {
|
||||
return { rate: 0, eta: 0, stable: false };
|
||||
}
|
||||
const first = samples[0];
|
||||
const last = samples[samples.length - 1];
|
||||
const dt = (last.t - first.t) / 1000;
|
||||
const db = last.b - first.b;
|
||||
if (dt * 1000 < MIN_WINDOW || db <= 0) {
|
||||
return { rate: 0, eta: 0, stable: false };
|
||||
}
|
||||
const rate = db / dt;
|
||||
const eta =
|
||||
total > 0 && bytes < total && rate > 0 ? (total - bytes) / rate : 0;
|
||||
return { rate, eta, stable: true };
|
||||
}
|
||||
|
||||
function composeProgressLabel(
|
||||
dlGb: number,
|
||||
totalGb: number,
|
||||
bytes: number,
|
||||
total: number,
|
||||
samples: Sample[],
|
||||
): string {
|
||||
const base =
|
||||
totalGb > 0
|
||||
? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB`
|
||||
: `${dlGb.toFixed(1)} GB downloaded`;
|
||||
const est = estimate(samples, bytes, total);
|
||||
if (!est.stable) return base;
|
||||
const rateStr = formatRate(est.rate);
|
||||
const etaStr = total > 0 ? formatEta(est.eta) : "";
|
||||
return etaStr && etaStr !== "--"
|
||||
? `${base} • ${rateStr} • ${etaStr} left`
|
||||
: `${base} • ${rateStr}`;
|
||||
}
|
||||
|
||||
let downloadComplete = isDownloaded || isCachedLora;
|
||||
|
||||
const pollDownload = async () => {
|
||||
if (abortCtrl.signal.aborted || !loadingModelRef.current) {
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const prog =
|
||||
ggufVariant && expectedBytes > 0
|
||||
? await getGgufDownloadProgress(modelId, ggufVariant, expectedBytes)
|
||||
: await getDownloadProgress(modelId);
|
||||
if (!loadingModelRef.current) return;
|
||||
|
||||
if (!loadingModelRef.current) return;
|
||||
|
||||
if (prog.progress > 0 && prog.progress < 1) {
|
||||
hasShownProgress = true;
|
||||
const dlGb = prog.downloaded_bytes / (1024 ** 3);
|
||||
const totalGb = prog.expected_bytes / (1024 ** 3);
|
||||
const pct = Math.round(prog.progress * 100);
|
||||
const progressLabel = totalGb > 0
|
||||
? `${dlGb.toFixed(1)} of ${totalGb.toFixed(1)} GB`
|
||||
: `${dlGb.toFixed(1)} GB downloaded`;
|
||||
setLoadProgress({
|
||||
percent: pct,
|
||||
label: progressLabel,
|
||||
phase: "downloading",
|
||||
});
|
||||
if (loadToastDismissedRef.current) return;
|
||||
toast(
|
||||
null,
|
||||
{
|
||||
id: toastId,
|
||||
description: renderLoadDescription(
|
||||
"Downloading model…",
|
||||
loadingDescription,
|
||||
pct,
|
||||
progressLabel,
|
||||
cancelLoading,
|
||||
),
|
||||
duration: Infinity,
|
||||
closeButton: false,
|
||||
classNames: MODEL_LOAD_TOAST_CLASSNAMES,
|
||||
onDismiss: (dismissedToast) => {
|
||||
if (loadToastIdRef.current !== dismissedToast.id) return;
|
||||
setLoadToastDismissedState(true);
|
||||
},
|
||||
},
|
||||
);
|
||||
} else if (prog.downloaded_bytes > 0 && prog.expected_bytes === 0 && prog.progress === 0) {
|
||||
hasShownProgress = true;
|
||||
const dlGb = prog.downloaded_bytes / (1024 ** 3);
|
||||
setLoadProgress({
|
||||
percent: null,
|
||||
label: `${dlGb.toFixed(1)} GB downloaded`,
|
||||
phase: "downloading",
|
||||
});
|
||||
} else if (prog.progress >= 1 && hasShownProgress) {
|
||||
setLoadProgress({
|
||||
percent: 100,
|
||||
label: "Download complete",
|
||||
phase: "starting",
|
||||
});
|
||||
if (loadToastDismissedRef.current) {
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
return;
|
||||
}
|
||||
if (prog.progress > 0 && prog.progress < 1) {
|
||||
hasShownProgress = true;
|
||||
const dlGb = prog.downloaded_bytes / (1024 ** 3);
|
||||
const totalGb = prog.expected_bytes / (1024 ** 3);
|
||||
const pct = Math.round(prog.progress * 100);
|
||||
const progressLabel = composeProgressLabel(
|
||||
dlGb,
|
||||
totalGb,
|
||||
prog.downloaded_bytes,
|
||||
prog.expected_bytes,
|
||||
dlSamples,
|
||||
);
|
||||
setLoadProgress({
|
||||
percent: pct,
|
||||
label: progressLabel,
|
||||
phase: "downloading",
|
||||
});
|
||||
if (loadToastDismissedRef.current) return;
|
||||
toast(null, {
|
||||
id: toastId,
|
||||
description: renderLoadDescription(
|
||||
"Downloading model…",
|
||||
loadingDescription,
|
||||
pct,
|
||||
progressLabel,
|
||||
cancelLoading,
|
||||
),
|
||||
duration: Infinity,
|
||||
closeButton: false,
|
||||
classNames: MODEL_LOAD_TOAST_CLASSNAMES,
|
||||
onDismiss: (dismissedToast) => {
|
||||
if (loadToastIdRef.current !== dismissedToast.id) return;
|
||||
setLoadToastDismissedState(true);
|
||||
},
|
||||
});
|
||||
} else if (
|
||||
prog.downloaded_bytes > 0 &&
|
||||
prog.expected_bytes === 0 &&
|
||||
prog.progress === 0
|
||||
) {
|
||||
hasShownProgress = true;
|
||||
const dlGb = prog.downloaded_bytes / (1024 ** 3);
|
||||
const est = estimate(dlSamples, prog.downloaded_bytes, 0);
|
||||
const rateSuffix =
|
||||
est.stable ? ` • ${formatRate(est.rate)}` : "";
|
||||
setLoadProgress({
|
||||
percent: null,
|
||||
label: `${dlGb.toFixed(1)} GB downloaded${rateSuffix}`,
|
||||
phase: "downloading",
|
||||
});
|
||||
} else if (prog.progress >= 1 && hasShownProgress) {
|
||||
downloadComplete = true;
|
||||
setLoadProgress({
|
||||
percent: 100,
|
||||
label: "Download complete",
|
||||
phase: "starting",
|
||||
});
|
||||
if (!loadToastDismissedRef.current) {
|
||||
toast(null, {
|
||||
id: toastId,
|
||||
description: renderLoadDescription(
|
||||
|
|
@ -653,16 +725,79 @@ export function useChatModelRuntime() {
|
|||
setLoadToastDismissedState(true);
|
||||
},
|
||||
});
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
}
|
||||
} catch {
|
||||
// Ignore polling errors
|
||||
// Keep polling: the mmap branch below takes over from here.
|
||||
}
|
||||
};
|
||||
} catch {
|
||||
// Ignore polling errors; keep polling.
|
||||
}
|
||||
};
|
||||
|
||||
setTimeout(pollProgress, 500);
|
||||
progressInterval = setInterval(pollProgress, 2000);
|
||||
}
|
||||
const pollLoad = async () => {
|
||||
if (abortCtrl.signal.aborted || !loadingModelRef.current) {
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const prog = await getLoadProgress();
|
||||
if (!loadingModelRef.current) return;
|
||||
if (!prog || prog.phase == null) return;
|
||||
if (prog.phase === "ready") {
|
||||
// Loaded. The chat flow will flip loadingModelRef shortly;
|
||||
// just stop polling.
|
||||
if (progressInterval) clearInterval(progressInterval);
|
||||
return;
|
||||
}
|
||||
if (prog.bytes_total <= 0) return; // nothing useful to render
|
||||
const loadedGb = prog.bytes_loaded / (1024 ** 3);
|
||||
const totalGb = prog.bytes_total / (1024 ** 3);
|
||||
const pct = Math.min(99, Math.round(prog.fraction * 100));
|
||||
const est = estimate(mmapSamples, prog.bytes_loaded, prog.bytes_total);
|
||||
const base = `${loadedGb.toFixed(1)} of ${totalGb.toFixed(1)} GB in memory`;
|
||||
const label = est.stable
|
||||
? `${base} • ${formatRate(est.rate)}${
|
||||
formatEta(est.eta) !== "--" ? ` • ${formatEta(est.eta)} left` : ""
|
||||
}`
|
||||
: base;
|
||||
setLoadProgress({
|
||||
percent: pct,
|
||||
label,
|
||||
phase: "starting",
|
||||
});
|
||||
if (loadToastDismissedRef.current) return;
|
||||
toast(null, {
|
||||
id: toastId,
|
||||
description: renderLoadDescription(
|
||||
"Starting model…",
|
||||
"Paging weights into memory.",
|
||||
pct,
|
||||
label,
|
||||
cancelLoading,
|
||||
),
|
||||
duration: Infinity,
|
||||
closeButton: false,
|
||||
classNames: MODEL_LOAD_TOAST_CLASSNAMES,
|
||||
onDismiss: (dismissedToast) => {
|
||||
if (loadToastIdRef.current !== dismissedToast.id) return;
|
||||
setLoadToastDismissedState(true);
|
||||
},
|
||||
});
|
||||
} catch {
|
||||
// Ignore polling errors.
|
||||
}
|
||||
};
|
||||
|
||||
const pollProgress = async () => {
|
||||
if (!downloadComplete) {
|
||||
await pollDownload();
|
||||
} else {
|
||||
await pollLoad();
|
||||
}
|
||||
};
|
||||
|
||||
let hasShownProgress = false;
|
||||
setTimeout(pollProgress, 500);
|
||||
progressInterval = setInterval(pollProgress, 2000);
|
||||
|
||||
try {
|
||||
await performLoad();
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
/**
|
||||
* Compute rate (bytes/sec) and ETA (seconds) from a time-series of
|
||||
* cumulative ``bytes`` values, using a rolling window of recent samples.
|
||||
*
|
||||
* Shared between the chat-flow download toast, the training-start
|
||||
* overlay, and the model-load phase UI. All three have the same shape:
|
||||
* a counter that rises monotonically from 0 toward ``totalBytes``, polled
|
||||
* on an interval. The derived stats are identical regardless of whether
|
||||
* the bytes came from an HTTP download or an mmap page-in.
|
||||
*
|
||||
* Stability rule: ``stable`` stays ``false`` until we've observed at
|
||||
* least 3 samples spanning ≥3 seconds. That keeps the UI from flashing
|
||||
* wildly varying rates during the first tick or two when the denominator
|
||||
* is effectively zero.
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
export type TransferStats = {
|
||||
rateBytesPerSecond: number;
|
||||
etaSeconds: number;
|
||||
/**
|
||||
* False for the first few ticks (window not filled, or no forward
|
||||
* progress yet). Consumers should hide rate/ETA while unstable so the
|
||||
* UI doesn't flicker "123 GB/s" during the first tick.
|
||||
*/
|
||||
stable: boolean;
|
||||
};
|
||||
|
||||
const MIN_SAMPLES = 3;
|
||||
const MIN_WINDOW_SECONDS = 3;
|
||||
const MAX_WINDOW_SECONDS = 15;
|
||||
|
||||
export function useTransferStats(
|
||||
bytes: number | null | undefined,
|
||||
totalBytes: number | null | undefined,
|
||||
): TransferStats {
|
||||
const samplesRef = useRef<{ t: number; b: number }[]>([]);
|
||||
const [state, setState] = useState<TransferStats>({
|
||||
rateBytesPerSecond: 0,
|
||||
etaSeconds: 0,
|
||||
stable: false,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const now = Date.now() / 1000;
|
||||
const cur = typeof bytes === "number" && Number.isFinite(bytes) ? bytes : 0;
|
||||
const total =
|
||||
typeof totalBytes === "number" && Number.isFinite(totalBytes)
|
||||
? totalBytes
|
||||
: 0;
|
||||
|
||||
// If the counter resets (e.g. user unloaded and started a new
|
||||
// download), drop the stale window.
|
||||
const samples = samplesRef.current;
|
||||
if (samples.length > 0 && cur < samples[samples.length - 1].b) {
|
||||
samples.length = 0;
|
||||
}
|
||||
|
||||
samples.push({ t: now, b: cur });
|
||||
|
||||
// Drop samples older than MAX_WINDOW_SECONDS; keep at least 2 so we
|
||||
// can still compute a rate when the counter hasn't moved in a while.
|
||||
const cutoff = now - MAX_WINDOW_SECONDS;
|
||||
while (samples.length > 2 && samples[0].t < cutoff) {
|
||||
samples.shift();
|
||||
}
|
||||
|
||||
if (samples.length < MIN_SAMPLES) {
|
||||
setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false });
|
||||
return;
|
||||
}
|
||||
|
||||
const first = samples[0];
|
||||
const last = samples[samples.length - 1];
|
||||
const dt = last.t - first.t;
|
||||
const db = last.b - first.b;
|
||||
if (dt < MIN_WINDOW_SECONDS || db <= 0) {
|
||||
setState({ rateBytesPerSecond: 0, etaSeconds: 0, stable: false });
|
||||
return;
|
||||
}
|
||||
|
||||
const rate = db / dt;
|
||||
const eta =
|
||||
total > 0 && cur < total && rate > 0 ? (total - cur) / rate : 0;
|
||||
|
||||
setState({
|
||||
rateBytesPerSecond: rate,
|
||||
etaSeconds: eta,
|
||||
stable: true,
|
||||
});
|
||||
}, [bytes, totalBytes]);
|
||||
|
||||
return state;
|
||||
}
|
||||
44
studio/frontend/src/features/chat/utils/format-transfer.ts
Normal file
44
studio/frontend/src/features/chat/utils/format-transfer.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
/**
|
||||
* Format a byte-per-second rate as a human-readable string.
|
||||
*
|
||||
* 512 → "512 B/s"
|
||||
* 1_234_567 → "1.2 MB/s"
|
||||
* 1_234_567_890 → "1.15 GB/s"
|
||||
*
|
||||
* Returns `"--"` for non-finite or non-positive inputs so callers can
|
||||
* render the label safely before the first stable sample arrives.
|
||||
*/
|
||||
export function formatRate(bytesPerSecond: number): string {
|
||||
if (!Number.isFinite(bytesPerSecond) || bytesPerSecond <= 0) return "--";
|
||||
const bps = bytesPerSecond;
|
||||
if (bps < 1024) return `${bps.toFixed(0)} B/s`;
|
||||
if (bps < 1024 ** 2) return `${(bps / 1024).toFixed(1)} KB/s`;
|
||||
if (bps < 1024 ** 3) return `${(bps / 1024 ** 2).toFixed(1)} MB/s`;
|
||||
return `${(bps / 1024 ** 3).toFixed(2)} GB/s`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format an ETA (in seconds) as a short human-readable string.
|
||||
*
|
||||
* 47 → "47s"
|
||||
* 125 → "2m 5s"
|
||||
* 3725 → "1h 2m"
|
||||
*
|
||||
* Returns `"--"` for non-finite or non-positive inputs.
|
||||
*/
|
||||
export function formatEta(seconds: number): string {
|
||||
if (!Number.isFinite(seconds) || seconds <= 0) return "--";
|
||||
const s = Math.round(seconds);
|
||||
if (s < 60) return `${s}s`;
|
||||
if (s < 3600) {
|
||||
const m = Math.floor(s / 60);
|
||||
const rem = s % 60;
|
||||
return rem > 0 ? `${m}m ${rem}s` : `${m}m`;
|
||||
}
|
||||
const h = Math.floor(s / 3600);
|
||||
const m = Math.floor((s % 3600) / 60);
|
||||
return m > 0 ? `${h}h ${m}m` : `${h}h`;
|
||||
}
|
||||
|
|
@ -23,6 +23,8 @@ import {
|
|||
getDownloadProgress,
|
||||
type DownloadProgressResponse,
|
||||
} from "@/features/chat/api/chat-api";
|
||||
import { useTransferStats } from "@/features/chat/hooks/use-transfer-stats";
|
||||
import { formatEta, formatRate } from "@/features/chat/utils/format-transfer";
|
||||
import {
|
||||
useTrainingActions,
|
||||
useTrainingConfigStore,
|
||||
|
|
@ -151,6 +153,11 @@ type DownloadRowProps = {
|
|||
};
|
||||
|
||||
function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
|
||||
// Compute a rolling-window rate + ETA from the same cumulative-byte
|
||||
// series the poll hook already produces, so we can show
|
||||
// "5.2 / 20.7 GB • 85.3 MB/s • 3m 12s left" instead of just the pair.
|
||||
const stats = useTransferStats(state.downloadedBytes, state.totalBytes);
|
||||
|
||||
if (state.downloadedBytes <= 0 && !state.cachePath) return null;
|
||||
const isComplete = state.totalBytes > 0 && state.percent >= 100;
|
||||
const statusLabel = isComplete
|
||||
|
|
@ -160,11 +167,16 @@ function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
|
|||
: state.downloadedBytes === 0
|
||||
? "Preparing"
|
||||
: null;
|
||||
const showRate = stats.stable && !isComplete;
|
||||
const rateSuffix = showRate ? ` • ${formatRate(stats.rateBytesPerSecond)}` : "";
|
||||
const etaStr =
|
||||
showRate && state.totalBytes > 0 ? formatEta(stats.etaSeconds) : "--";
|
||||
const etaSuffix = etaStr !== "--" ? ` • ${etaStr} left` : "";
|
||||
const sizeLabel =
|
||||
state.totalBytes > 0
|
||||
? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}`
|
||||
? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}${rateSuffix}${etaSuffix}`
|
||||
: state.downloadedBytes > 0
|
||||
? `${formatBytes(state.downloadedBytes)} downloaded`
|
||||
? `${formatBytes(state.downloadedBytes)} downloaded${rateSuffix}`
|
||||
: null;
|
||||
return (
|
||||
<div className="flex flex-col gap-1.5 rounded-md border border-border/50 bg-muted/20 px-3 py-2">
|
||||
|
|
|
|||
Loading…
Reference in a new issue