studio: show HF model download progress in training start overlay (#4894)

* studio: show HF model download progress in training start overlay

During the training setup phase, the overlay only displayed a static
"Loading model..." line while model weights were being downloaded from
Hugging Face. On slow connections this looked like the app had frozen.

This adds a small self-contained progress block inside the existing
TrainingStartOverlay that polls the existing
GET /api/models/download-progress endpoint and renders a Progress bar
with bytes downloaded, total bytes, and percent complete.

Notes:

- Frontend only change. No backend, worker, SSE, or runtime store edits.
- Reuses the existing getDownloadProgress client wrapper and the
  existing /api/models/download-progress endpoint that already scans
  the HF blob cache for completed and .incomplete files.
- selectedModel is read directly from useTrainingConfigStore inside the
  overlay, so no prop drilling and live-training-view.tsx is unchanged.
- Polling runs at 1500 ms and is gated on the HF repo regex
  (^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$), the same regex the backend uses,
  so local paths and empty form state never hit the endpoint.
- Polling stops once progress reaches 1.0 so the bar can stay at 100%
  until the overlay hides on the first training step.
- Network errors are silently swallowed, matching the chat side flow
  (the bar simply freezes at the last value).
- When downloadedBytes is 0 the block is hidden entirely, so cached
  models do not flash a progress bar.
- When the HF API cannot determine the total size, the block falls
  back to "X downloaded" with no percent and no bar.

Verified with bun run build (tsc -b plus vite build, no TypeScript
errors).

* training overlay: track dataset download + show on-disk realpath

Adds a dedicated "Downloading dataset..." section to the training-start
overlay alongside the existing model-weights one, so an HF dataset that
is downloading mid-startup is no longer mislabeled as model weights or
hidden entirely. The new GET /api/datasets/download-progress endpoint
mirrors /api/models/download-progress against the datasets-- prefix in
HF_HUB_CACHE.

Both endpoints now also return cache_path, the resolved on-disk
realpath of the snapshot directory (or the cache repo root if no
snapshot is materialized yet). The overlay surfaces this under each
download row so users can immediately see where the model and dataset
landed without digging through server logs.

The frontend's existing useModelDownloadProgress hook is generalized
to a single useHfDownloadProgress(repoId, fetcher) hook that the
model and dataset variants both delegate to, keeping polling, gating,
and completion semantics in one place.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: Polish training start overlay download progress UI (#4957)

* studio: polish training start overlay download progress visuals

* Fix formatCachePath cross-platform support and redundant sizeLabel

- Extend formatCachePath regex to also shorten macOS /Users/<user> paths to ~
- Suppress sizeLabel when no byte info is available (cachePath-only state),
  since the "Preparing" badge already conveys the status

* Fix misleading status badge when download total is unknown

- Hide badge when totalBytes is 0 but downloadedBytes > 0, since we cannot
  determine if the download is still in progress or already complete (happens
  when HF size metadata lookup fails for gated/private repos)
- Keep "Preparing" badge for the zero-bytes cachePath-only state
- Add Windows native path shortening to formatCachePath (C:\Users\<name>)

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: studio-install <studio@local.install>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-14 08:54:01 -07:00 committed by GitHub
parent 44082cf88e
commit eca592effe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 400 additions and 10 deletions

View file

@ -11,10 +11,55 @@ import json
import sys
from pathlib import Path
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
import re as _re
import structlog
from loggers import get_logger
_VALID_REPO_ID = _re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
def _is_valid_repo_id(repo_id: str) -> bool:
return bool(_VALID_REPO_ID.fullmatch(repo_id))
_dataset_size_cache: dict[str, int] = {}
def _get_dataset_size_cached(repo_id: str) -> int:
if repo_id in _dataset_size_cache:
return _dataset_size_cache[repo_id]
try:
from huggingface_hub import dataset_info as hf_dataset_info
info = hf_dataset_info(repo_id, token = None, files_metadata = True)
total = sum(s.size for s in info.siblings if getattr(s, "size", None))
_dataset_size_cache[repo_id] = total
return total
except Exception:
return 0
def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]:
"""Pick the most useful on-disk path for a HF cache repo dir.
Mirrors the helper in routes/models.py: prefer the most-recent
snapshot dir, fall back to the cache repo root, return resolved
realpath. Duplicated here to keep routes/datasets.py self-contained.
"""
try:
snapshots_dir = repo_dir / "snapshots"
if snapshots_dir.is_dir():
snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()]
if snaps:
latest = max(snaps, key = lambda s: s.stat().st_mtime)
return str(latest.resolve())
return str(repo_dir.resolve())
except Exception:
return None
# Add backend directory to path
backend_path = Path(__file__).parent.parent.parent
if str(backend_path) not in sys.path:
@ -308,6 +353,89 @@ def list_local_datasets(
return LocalDatasetsResponse(datasets = _build_local_dataset_items())
@router.get("/download-progress")
async def get_dataset_download_progress(
    repo_id: str = Query(
        ..., description = "HuggingFace dataset repo ID, e.g. 'unsloth/LaTeX_OCR'"
    ),
    current_subject: str = Depends(get_current_subject),
):
    """Return download progress for a HuggingFace dataset repo.

    Mirrors ``GET /api/models/download-progress`` but scans the
    ``datasets--owner--name`` cache directory under HF_HUB_CACHE.
    Modern ``datasets``/``huggingface_hub`` caches both raw model and
    raw dataset blobs in HF_HUB_CACHE; the ``datasets`` library writes
    its processed Arrow shards elsewhere, but the in-progress *download*
    bytes are observable here. Returns ``cache_path`` so the UI can
    show users where the dataset blobs landed on disk.

    Response keys: ``downloaded_bytes``, ``expected_bytes``,
    ``progress`` (0..1 ratio), ``cache_path`` (str or None).
    Never raises: invalid repo IDs and any internal error return the
    all-zero response below.
    """
    # Default "nothing downloaded" response, also used as the error fallback.
    _empty = {
        "downloaded_bytes": 0,
        "expected_bytes": 0,
        "progress": 0,
        "cache_path": None,
    }
    try:
        # Reject anything that is not an owner/name HF repo ID (local
        # paths, empty strings) before touching the filesystem.
        if not _is_valid_repo_id(repo_id):
            return _empty
        # Lazy import: if huggingface_hub is unavailable, the except
        # below converts that into the empty response.
        from huggingface_hub import constants as hf_constants

        cache_dir = Path(hf_constants.HF_HUB_CACHE)
        # HF cache layout: HF_HUB_CACHE/datasets--owner--name/...
        # (compared case-insensitively on both sides).
        target = f"datasets--{repo_id.replace('/', '--')}".lower()
        completed_bytes = 0
        in_progress_bytes = 0
        cache_path: Optional[str] = None
        if cache_dir.is_dir():
            for entry in cache_dir.iterdir():
                if entry.name.lower() != target:
                    continue
                cache_path = _resolve_hf_cache_realpath(entry)
                blobs_dir = entry / "blobs"
                if not blobs_dir.is_dir():
                    # Repo dir exists but no blobs yet ("preparing" state).
                    break
                for f in blobs_dir.iterdir():
                    if not f.is_file():
                        continue
                    # huggingface_hub writes partial downloads as
                    # <hash>.incomplete and renames them when done.
                    if f.name.endswith(".incomplete"):
                        in_progress_bytes += f.stat().st_size
                    else:
                        completed_bytes += f.stat().st_size
                # At most one cache entry can match the target name.
                break
        downloaded_bytes = completed_bytes + in_progress_bytes
        if downloaded_bytes == 0:
            # Nothing on disk yet, but still surface the cache dir if found.
            return {**_empty, "cache_path": cache_path}
        expected_bytes = _get_dataset_size_cached(repo_id)
        if expected_bytes <= 0:
            # Total size unknown (e.g. metadata lookup failed): report raw
            # bytes with no computable progress ratio.
            return {
                "downloaded_bytes": downloaded_bytes,
                "expected_bytes": 0,
                "progress": 0,
                "cache_path": cache_path,
            }
        # Same 95% completion threshold as the model endpoint -- HF blob
        # dedup makes completed_bytes drift slightly under expected_bytes,
        # and inter-file gaps would otherwise look like "done".
        if completed_bytes >= expected_bytes * 0.95:
            progress = 1.0
        else:
            # Cap below 1.0 so the UI never shows 100% prematurely.
            progress = min(downloaded_bytes / expected_bytes, 0.99)
        return {
            "downloaded_bytes": downloaded_bytes,
            "expected_bytes": expected_bytes,
            "progress": round(progress, 3),
            "cache_path": cache_path,
        }
    except Exception as e:
        logger.warning(f"Error checking dataset download progress for {repo_id}: {e}")
        return _empty
@router.post("/check-format", response_model = CheckFormatResponse)
def check_format(
request: CheckFormatRequest,

View file

@ -1091,6 +1091,25 @@ async def get_gguf_download_progress(
return {"downloaded_bytes": 0, "expected_bytes": expected_bytes, "progress": 0}
def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]:
"""Pick the most useful on-disk path for a HF cache repo.
Prefers the most-recent snapshot dir (what `from_pretrained` actually
points at). Falls back to the cache repo root. Returns the resolved
realpath so symlinks under snapshots/ are followed back to blobs/.
"""
try:
snapshots_dir = repo_dir / "snapshots"
if snapshots_dir.is_dir():
snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()]
if snaps:
latest = max(snaps, key = lambda s: s.stat().st_mtime)
return str(latest.resolve())
return str(repo_dir.resolve())
except Exception:
return None
@router.get("/download-progress")
async def get_download_progress(
repo_id: str = Query(..., description = "HuggingFace repo ID"),
@ -1101,8 +1120,16 @@ async def get_download_progress(
Checks the local HF cache for completed blobs and in-progress
(.incomplete) downloads. Uses the HF API to determine the expected
total size on the first call, then caches it for subsequent polls.
Also returns ``cache_path``: the realpath of the snapshot directory
(or the cache repo root if no snapshot exists yet) so the UI can
show users where the weights actually live on disk.
"""
_empty = {"downloaded_bytes": 0, "expected_bytes": 0, "progress": 0}
_empty = {
"downloaded_bytes": 0,
"expected_bytes": 0,
"progress": 0,
"cache_path": None,
}
try:
if not _is_valid_repo_id(repo_id):
return _empty
@ -1113,10 +1140,12 @@ async def get_download_progress(
target = f"models--{repo_id.replace('/', '--')}".lower()
completed_bytes = 0
in_progress_bytes = 0
cache_path: Optional[str] = None
for entry in cache_dir.iterdir():
if entry.name.lower() != target:
continue
cache_path = _resolve_hf_cache_realpath(entry)
blobs_dir = entry / "blobs"
if not blobs_dir.is_dir():
break
@ -1131,7 +1160,7 @@ async def get_download_progress(
downloaded_bytes = completed_bytes + in_progress_bytes
if downloaded_bytes == 0:
return _empty
return {**_empty, "cache_path": cache_path}
# Get expected size from HF API (cached per repo_id)
expected_bytes = _get_repo_size_cached(repo_id)
@ -1141,6 +1170,7 @@ async def get_download_progress(
"downloaded_bytes": downloaded_bytes,
"expected_bytes": 0,
"progress": 0,
"cache_path": cache_path,
}
# Use 95% threshold for completion (blob deduplication can make
@ -1156,6 +1186,7 @@ async def get_download_progress(
"downloaded_bytes": downloaded_bytes,
"expected_bytes": expected_bytes,
"progress": round(progress, 3),
"cache_path": cache_path,
}
except Exception as e:
logger.warning(f"Error checking download progress for {repo_id}: {e}")

View file

@ -1,6 +1,6 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"use client";
import { Progress as ProgressPrimitive } from "radix-ui";
@ -10,9 +10,12 @@ import { cn } from "@/lib/utils";
function Progress({
className,
indicatorClassName,
value,
...props
}: React.ComponentProps<typeof ProgressPrimitive.Root>) {
}: React.ComponentProps<typeof ProgressPrimitive.Root> & {
indicatorClassName?: string;
}) {
return (
<ProgressPrimitive.Root
data-slot="progress"
@ -24,7 +27,7 @@ function Progress({
>
<ProgressPrimitive.Indicator
data-slot="progress-indicator"
className="bg-primary size-full flex-1 transition-all"
className={cn("bg-primary size-full flex-1 transition-all", indicatorClassName)}
style={{ transform: `translateX(-${100 - (value || 0)}%)` }}
/>
</ProgressPrimitive.Root>

View file

@ -117,14 +117,34 @@ export async function getGgufDownloadProgress(
return parseJsonOrThrow(response);
}
/** Shape returned by both HF download-progress endpoints. */
export interface DownloadProgressResponse {
  /** Completed blob bytes plus in-progress (.incomplete) bytes on disk. */
  downloaded_bytes: number;
  /** Total repo size from the HF API; 0 when the size lookup failed. */
  expected_bytes: number;
  /** 0..1 ratio; 0 when nothing is downloaded or the total is unknown. */
  progress: number;
  /**
   * Resolved on-disk path of the snapshot dir (or cache repo root if no
   * snapshot exists yet). Null when nothing has been written to the
   * cache for this repo.
   */
  cache_path: string | null;
}
export async function getDownloadProgress(
repoId: string,
): Promise<{ downloaded_bytes: number; expected_bytes: number; progress: number }> {
): Promise<DownloadProgressResponse> {
const params = new URLSearchParams({ repo_id: repoId });
const response = await authFetch(`/api/models/download-progress?${params}`);
return parseJsonOrThrow(response);
}
/** Fetch dataset-blob download progress for an HF repo from the backend. */
export async function getDatasetDownloadProgress(
  repoId: string,
): Promise<DownloadProgressResponse> {
  const query = new URLSearchParams({ repo_id: repoId });
  const response = await authFetch(`/api/datasets/download-progress?${query}`);
  return parseJsonOrThrow(response);
}
export interface LocalModelInfo {
id: string;
display_name: string;

View file

@ -12,16 +12,200 @@ import {
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
import { Progress } from "@/components/ui/progress";
import {
AnimatedSpan,
Terminal,
TypingAnimation,
} from "@/components/ui/terminal";
import { useTrainingActions, useTrainingRuntimeStore } from "@/features/training";
import {
getDatasetDownloadProgress,
getDownloadProgress,
type DownloadProgressResponse,
} from "@/features/chat/api/chat-api";
import {
useTrainingActions,
useTrainingConfigStore,
useTrainingRuntimeStore,
} from "@/features/training";
import { Cancel01Icon } from "@hugeicons/core-free-icons";
import { HugeiconsIcon } from "@hugeicons/react";
import { useEffect, useState, type ReactElement } from "react";
// Mirrors the backend's repo-ID validation; gates polling so local paths
// and empty form state never hit the download-progress endpoints.
const HF_REPO_REGEX = /^[A-Za-z0-9._-]+\/[A-Za-z0-9._-]+$/;
/**
 * Human-readable byte count using binary multiples (B, KB, MB, GB, TB).
 * Non-positive inputs render as "0 B". Previously anything >= 1 TiB
 * rendered as an oversized GB figure (e.g. "2048.00 GB"); a TB tier
 * fixes that while leaving all sub-TiB outputs unchanged.
 */
function formatBytes(n: number): string {
  if (n <= 0) return "0 B";
  if (n < 1024) return `${n} B`;
  if (n < 1024 ** 2) return `${(n / 1024).toFixed(1)} KB`;
  if (n < 1024 ** 3) return `${(n / 1024 ** 2).toFixed(1)} MB`;
  if (n < 1024 ** 4) return `${(n / 1024 ** 3).toFixed(2)} GB`;
  return `${(n / 1024 ** 4).toFixed(2)} TB`;
}
/**
 * Shorten the user's home-directory prefix to "~" for display.
 * Handles Linux (/home/<user>), macOS (/Users/<user>), and native
 * Windows (C:\Users\<user>) paths; other paths pass through unchanged.
 */
function formatCachePath(path: string): string {
  const homePrefix =
    /^(?:\/(?:home|Users)\/[^/]+|[A-Za-z]:[/\\]Users[/\\][^/\\]+)/;
  return path.replace(homePrefix, "~");
}
// Normalized client-side view of a DownloadProgressResponse.
type DownloadState = {
  downloadedBytes: number;
  totalBytes: number;
  // 0-100 integer; always 0 when totalBytes is unknown.
  percent: number;
  // On-disk snapshot/cache dir reported by the backend, if any.
  cachePath: string | null;
};

const EMPTY_DOWNLOAD_STATE: DownloadState = {
  downloadedBytes: 0,
  totalBytes: 0,
  percent: 0,
  cachePath: null,
};

// Either getDownloadProgress (models) or getDatasetDownloadProgress.
type Fetcher = (repoId: string) => Promise<DownloadProgressResponse>;
/**
 * Polls a HF repo's download progress on a 1.5s tick. Used for both
 * model weights (`/api/models/download-progress`) and dataset blobs
 * (`/api/datasets/download-progress`) by swapping the fetcher.
 *
 * Stops polling once `progress >= 1.0` -- the bar freezes at the final
 * value rather than disappearing, mirroring the existing chat flow.
 */
function useHfDownloadProgress(
  repoId: string | null,
  fetcher: Fetcher,
): DownloadState {
  const phase = useTrainingRuntimeStore((s) => s.phase);
  const isStarting = useTrainingRuntimeStore((s) => s.isStarting);
  const [state, setState] = useState<DownloadState>(EMPTY_DOWNLOAD_STATE);

  // Poll only during the startup/download/load phases; outside them the
  // overlay is not shown and polling would be wasted requests.
  const shouldPoll =
    isStarting ||
    phase === "configuring" ||
    phase === "downloading_model" ||
    phase === "downloading_dataset" ||
    phase === "loading_model" ||
    phase === "loading_dataset";

  useEffect(() => {
    // Gate: missing repo, not an owner/name HF repo ID, or wrong phase.
    if (!repoId || !HF_REPO_REGEX.test(repoId) || !shouldPoll) {
      return;
    }
    let cancelled = false; // set by the effect cleanup below
    let finished = false; // set once progress reaches 1.0
    let interval: ReturnType<typeof setInterval> | null = null;

    const poll = async () => {
      if (cancelled || finished) return;
      try {
        const prog = await fetcher(repoId);
        // Response may arrive after cleanup; drop it instead of setState.
        if (cancelled) return;
        const downloaded = prog.downloaded_bytes ?? 0;
        const total = prog.expected_bytes ?? 0;
        const ratio = prog.progress ?? 0;
        // Percent is only meaningful when the total size is known.
        const pct =
          total > 0 ? Math.min(100, Math.round(ratio * 100)) : 0;
        setState({
          downloadedBytes: downloaded,
          totalBytes: total,
          percent: pct,
          cachePath: prog.cache_path ?? null,
        });
        if (ratio >= 1.0) {
          // Done: stop the timer but keep the final state rendered.
          finished = true;
          if (interval) {
            clearInterval(interval);
            interval = null;
          }
        }
      } catch {
        // Silently swallow; bar freezes at last value (matches chat flow).
      }
    };

    void poll(); // immediate first read, then the interval takes over
    interval = setInterval(poll, 1500);
    return () => {
      cancelled = true;
      if (interval) clearInterval(interval);
    };
  }, [repoId, shouldPoll, fetcher]);

  return state;
}
// Model-weights variant: polls /api/models/download-progress.
function useModelDownloadProgress(modelName: string | null): DownloadState {
  return useHfDownloadProgress(modelName, getDownloadProgress);
}

// Dataset variant: polls /api/datasets/download-progress.
function useDatasetDownloadProgress(datasetName: string | null): DownloadState {
  return useHfDownloadProgress(datasetName, getDatasetDownloadProgress);
}
type DownloadRowProps = {
  label: string;
  state: DownloadState;
};

// One row of the overlay's download section: label + status badge,
// percent, byte counts, progress bar, and the on-disk cache path.
function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
  // Nothing observed for this repo yet: render nothing, so fully-cached
  // repos never flash a progress row.
  if (state.downloadedBytes <= 0 && !state.cachePath) return null;
  const isComplete = state.totalBytes > 0 && state.percent >= 100;
  // Badge text. Null when totalBytes is 0 but bytes exist on disk -- with
  // no known total we cannot tell "in progress" from "already complete".
  const statusLabel = isComplete
    ? "Ready"
    : state.totalBytes > 0
      ? "Downloading"
      : state.downloadedBytes === 0
        ? "Preparing"
        : null;
  // "X / Y" when the total is known, "X downloaded" otherwise; omitted
  // entirely in the zero-byte cachePath-only state.
  const sizeLabel =
    state.totalBytes > 0
      ? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}`
      : state.downloadedBytes > 0
        ? `${formatBytes(state.downloadedBytes)} downloaded`
        : null;
  return (
    <div className="flex flex-col gap-1.5 rounded-md border border-border/50 bg-muted/20 px-3 py-2">
      <div className="flex items-center justify-between gap-3">
        <div className="flex items-center gap-2">
          <span className="text-xs text-foreground/90">{label}</span>
          {statusLabel ? (
            <span
              className={`rounded-full px-1.5 py-0.5 text-[10px] font-medium ${isComplete ? "bg-emerald-100 text-emerald-700 ring-1 ring-emerald-200/80 dark:bg-emerald-500/15 dark:text-emerald-300 dark:ring-emerald-500/30" : "bg-muted text-muted-foreground"}`}
            >
              {statusLabel}
            </span>
          ) : null}
        </div>
        {/* Percent only when the total is known; empty span keeps layout. */}
        <span className="text-xs tabular-nums text-muted-foreground">
          {state.totalBytes > 0 ? `${state.percent}%` : ""}
        </span>
      </div>
      {sizeLabel ? (
        <div className="text-[11px] tabular-nums text-muted-foreground">
          {sizeLabel}
        </div>
      ) : null}
      {/* Progress bar is only renderable when a total exists. */}
      {state.totalBytes > 0 ? (
        <Progress
          value={state.percent}
          indicatorClassName="bg-[linear-gradient(90deg,oklch(0.66_0.142_166.6)_0%,oklch(0.705_0.132_166.6)_55%,oklch(0.75_0.122_166.6)_100%)]"
        />
      ) : null}
      {/* Full path in the title attribute; display copy is ~-shortened. */}
      {state.cachePath ? (
        <div
          className="truncate rounded bg-muted/50 px-2 py-1 text-[10px] text-muted-foreground/70"
          title={state.cachePath}
        >
          {formatCachePath(state.cachePath)}
        </div>
      ) : null}
    </div>
  );
}
type TrainingStartOverlayProps = {
message: string
currentStep: number
@ -33,6 +217,14 @@ export function TrainingStartOverlay({
}: TrainingStartOverlayProps): ReactElement {
const { stopTrainingRun, dismissTrainingRun } = useTrainingActions();
const isStarting = useTrainingRuntimeStore((s) => s.isStarting);
const selectedModel = useTrainingConfigStore((s) => s.selectedModel);
const datasetSource = useTrainingConfigStore((s) => s.datasetSource);
const dataset = useTrainingConfigStore((s) => s.dataset);
// Only HF datasets have a download phase to track. Uploaded files are
// already on disk by the time the overlay shows up.
const hfDatasetName = datasetSource === "huggingface" ? dataset : null;
const modelDownload = useModelDownloadProgress(selectedModel);
const datasetDownload = useDatasetDownloadProgress(hfDatasetName);
const [cancelDialogOpen, setCancelDialogOpen] = useState(false);
const [cancelRequested, setCancelRequested] = useState(false);
@ -112,6 +304,22 @@ export function TrainingStartOverlay({
<AnimatedSpan className="mt-2 text-muted-foreground">
{`> ${message || "starting training..."} | waiting for first step... (${currentStep})`}
</AnimatedSpan>
{datasetDownload.downloadedBytes > 0 || datasetDownload.cachePath ? (
<AnimatedSpan className="mt-3">
<DownloadRow
label="Dataset"
state={datasetDownload}
/>
</AnimatedSpan>
) : null}
{modelDownload.downloadedBytes > 0 || modelDownload.cachePath ? (
<AnimatedSpan className="mt-3">
<DownloadRow
label="Model weights"
state={modelDownload}
/>
</AnimatedSpan>
) : null}
</Terminal>
</div>
</div>