mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
studio: show HF model download progress in training start overlay (#4894)
* studio: show HF model download progress in training start overlay During the training setup phase, the overlay only displayed a static "Loading model..." line while model weights were being downloaded from Hugging Face. On slow connections this looked like the app had frozen. This adds a small self-contained progress block inside the existing TrainingStartOverlay that polls the existing GET /api/models/download-progress endpoint and renders a Progress bar with bytes downloaded, total bytes, and percent complete. Notes: - Frontend only change. No backend, worker, SSE, or runtime store edits. - Reuses the existing getDownloadProgress client wrapper and the existing /api/models/download-progress endpoint that already scans the HF blob cache for completed and .incomplete files. - selectedModel is read directly from useTrainingConfigStore inside the overlay, so no prop drilling and live-training-view.tsx is unchanged. - Polling runs at 1500 ms and is gated on the HF repo regex (^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$), the same regex the backend uses, so local paths and empty form state never hit the endpoint. - Polling stops once progress reaches 1.0 so the bar can stay at 100 until the overlay hides on the first training step. - Network errors are silently swallowed, matching the chat side flow (the bar simply freezes at the last value). - When downloadedBytes is 0 the block is hidden entirely, so cached models do not flash a progress bar. - When the HF API cannot determine the total size, the block falls back to "X downloaded" with no percent and no bar. Verified with bun run build (tsc -b plus vite build, no TypeScript errors). * training overlay: track dataset download + show on-disk realpath Adds a dedicated "Downloading dataset..." section to the training-start overlay alongside the existing model-weights one, so an HF dataset that is downloading mid-startup is no longer mislabeled as model weights or hidden entirely. The new GET /api/datasets/download-progress endpoint mirrors /api/models/download-progress against the datasets-- prefix in HF_HUB_CACHE. Both endpoints now also return cache_path, the resolved on-disk realpath of the snapshot directory (or the cache repo root if no snapshot is materialized yet). The overlay surfaces this under each download row so users can immediately see where the model and dataset landed without digging through server logs. The frontend's existing useModelDownloadProgress hook is generalized to a single useHfDownloadProgress(repoId, fetcher) hook that the model and dataset variants both delegate to, keeping polling, gating, and completion semantics in one place. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Studio: Polish training start overlay download progress UI (#4957) * studio: polish training start overlay download progress visuals * Fix formatCachePath cross-platform support and redundant sizeLabel - Extend formatCachePath regex to also shorten macOS /Users/<user> paths to ~ - Suppress sizeLabel when no byte info is available (cachePath-only state), since the "Preparing" badge already conveys the status * Fix misleading status badge when download total is unknown - Hide badge when totalBytes is 0 but downloadedBytes > 0, since we cannot determine if the download is still in progress or already complete (happens when HF size metadata lookup fails for gated/private repos) - Keep "Preparing" badge for the zero-bytes cachePath-only state - Add Windows native path shortening to formatCachePath (C:\Users\<name>) --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> --------- Co-authored-by: studio-install <studio@local.install> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
This commit is contained in:
parent
44082cf88e
commit
eca592effe
5 changed files with 400 additions and 10 deletions
|
|
@ -11,10 +11,55 @@ import json
|
|||
import sys
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
|
||||
import re as _re
|
||||
import structlog
|
||||
from loggers import get_logger
|
||||
|
||||
_VALID_REPO_ID = _re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
def _is_valid_repo_id(repo_id: str) -> bool:
|
||||
return bool(_VALID_REPO_ID.fullmatch(repo_id))
|
||||
|
||||
|
||||
_dataset_size_cache: dict[str, int] = {}
|
||||
|
||||
|
||||
def _get_dataset_size_cached(repo_id: str) -> int:
|
||||
if repo_id in _dataset_size_cache:
|
||||
return _dataset_size_cache[repo_id]
|
||||
try:
|
||||
from huggingface_hub import dataset_info as hf_dataset_info
|
||||
|
||||
info = hf_dataset_info(repo_id, token = None, files_metadata = True)
|
||||
total = sum(s.size for s in info.siblings if getattr(s, "size", None))
|
||||
_dataset_size_cache[repo_id] = total
|
||||
return total
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]:
|
||||
"""Pick the most useful on-disk path for a HF cache repo dir.
|
||||
|
||||
Mirrors the helper in routes/models.py: prefer the most-recent
|
||||
snapshot dir, fall back to the cache repo root, return resolved
|
||||
realpath. Duplicated here to keep routes/datasets.py self-contained.
|
||||
"""
|
||||
try:
|
||||
snapshots_dir = repo_dir / "snapshots"
|
||||
if snapshots_dir.is_dir():
|
||||
snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()]
|
||||
if snaps:
|
||||
latest = max(snaps, key = lambda s: s.stat().st_mtime)
|
||||
return str(latest.resolve())
|
||||
return str(repo_dir.resolve())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# Add backend directory to path
|
||||
backend_path = Path(__file__).parent.parent.parent
|
||||
if str(backend_path) not in sys.path:
|
||||
|
|
@ -308,6 +353,89 @@ def list_local_datasets(
|
|||
return LocalDatasetsResponse(datasets = _build_local_dataset_items())
|
||||
|
||||
|
||||
@router.get("/download-progress")
|
||||
async def get_dataset_download_progress(
|
||||
repo_id: str = Query(
|
||||
..., description = "HuggingFace dataset repo ID, e.g. 'unsloth/LaTeX_OCR'"
|
||||
),
|
||||
current_subject: str = Depends(get_current_subject),
|
||||
):
|
||||
"""Return download progress for a HuggingFace dataset repo.
|
||||
|
||||
Mirrors ``GET /api/models/download-progress`` but scans the
|
||||
``datasets--owner--name`` cache directory under HF_HUB_CACHE.
|
||||
Modern ``datasets``/``huggingface_hub`` caches both raw model and
|
||||
raw dataset blobs in HF_HUB_CACHE; the ``datasets`` library writes
|
||||
its processed Arrow shards elsewhere, but the in-progress *download*
|
||||
bytes are observable here. Returns ``cache_path`` so the UI can
|
||||
show users where the dataset blobs landed on disk.
|
||||
"""
|
||||
_empty = {
|
||||
"downloaded_bytes": 0,
|
||||
"expected_bytes": 0,
|
||||
"progress": 0,
|
||||
"cache_path": None,
|
||||
}
|
||||
try:
|
||||
if not _is_valid_repo_id(repo_id):
|
||||
return _empty
|
||||
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
cache_dir = Path(hf_constants.HF_HUB_CACHE)
|
||||
target = f"datasets--{repo_id.replace('/', '--')}".lower()
|
||||
completed_bytes = 0
|
||||
in_progress_bytes = 0
|
||||
cache_path: Optional[str] = None
|
||||
|
||||
if cache_dir.is_dir():
|
||||
for entry in cache_dir.iterdir():
|
||||
if entry.name.lower() != target:
|
||||
continue
|
||||
cache_path = _resolve_hf_cache_realpath(entry)
|
||||
blobs_dir = entry / "blobs"
|
||||
if not blobs_dir.is_dir():
|
||||
break
|
||||
for f in blobs_dir.iterdir():
|
||||
if not f.is_file():
|
||||
continue
|
||||
if f.name.endswith(".incomplete"):
|
||||
in_progress_bytes += f.stat().st_size
|
||||
else:
|
||||
completed_bytes += f.stat().st_size
|
||||
break
|
||||
|
||||
downloaded_bytes = completed_bytes + in_progress_bytes
|
||||
if downloaded_bytes == 0:
|
||||
return {**_empty, "cache_path": cache_path}
|
||||
|
||||
expected_bytes = _get_dataset_size_cached(repo_id)
|
||||
if expected_bytes <= 0:
|
||||
return {
|
||||
"downloaded_bytes": downloaded_bytes,
|
||||
"expected_bytes": 0,
|
||||
"progress": 0,
|
||||
"cache_path": cache_path,
|
||||
}
|
||||
|
||||
# Same 95% completion threshold as the model endpoint -- HF blob
|
||||
# dedup makes completed_bytes drift slightly under expected_bytes,
|
||||
# and inter-file gaps would otherwise look like "done".
|
||||
if completed_bytes >= expected_bytes * 0.95:
|
||||
progress = 1.0
|
||||
else:
|
||||
progress = min(downloaded_bytes / expected_bytes, 0.99)
|
||||
return {
|
||||
"downloaded_bytes": downloaded_bytes,
|
||||
"expected_bytes": expected_bytes,
|
||||
"progress": round(progress, 3),
|
||||
"cache_path": cache_path,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking dataset download progress for {repo_id}: {e}")
|
||||
return _empty
|
||||
|
||||
|
||||
@router.post("/check-format", response_model = CheckFormatResponse)
|
||||
def check_format(
|
||||
request: CheckFormatRequest,
|
||||
|
|
|
|||
|
|
@ -1091,6 +1091,25 @@ async def get_gguf_download_progress(
|
|||
return {"downloaded_bytes": 0, "expected_bytes": expected_bytes, "progress": 0}
|
||||
|
||||
|
||||
def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]:
|
||||
"""Pick the most useful on-disk path for a HF cache repo.
|
||||
|
||||
Prefers the most-recent snapshot dir (what `from_pretrained` actually
|
||||
points at). Falls back to the cache repo root. Returns the resolved
|
||||
realpath so symlinks under snapshots/ are followed back to blobs/.
|
||||
"""
|
||||
try:
|
||||
snapshots_dir = repo_dir / "snapshots"
|
||||
if snapshots_dir.is_dir():
|
||||
snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()]
|
||||
if snaps:
|
||||
latest = max(snaps, key = lambda s: s.stat().st_mtime)
|
||||
return str(latest.resolve())
|
||||
return str(repo_dir.resolve())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/download-progress")
|
||||
async def get_download_progress(
|
||||
repo_id: str = Query(..., description = "HuggingFace repo ID"),
|
||||
|
|
@ -1101,8 +1120,16 @@ async def get_download_progress(
|
|||
Checks the local HF cache for completed blobs and in-progress
|
||||
(.incomplete) downloads. Uses the HF API to determine the expected
|
||||
total size on the first call, then caches it for subsequent polls.
|
||||
Also returns ``cache_path``: the realpath of the snapshot directory
|
||||
(or the cache repo root if no snapshot exists yet) so the UI can
|
||||
show users where the weights actually live on disk.
|
||||
"""
|
||||
_empty = {"downloaded_bytes": 0, "expected_bytes": 0, "progress": 0}
|
||||
_empty = {
|
||||
"downloaded_bytes": 0,
|
||||
"expected_bytes": 0,
|
||||
"progress": 0,
|
||||
"cache_path": None,
|
||||
}
|
||||
try:
|
||||
if not _is_valid_repo_id(repo_id):
|
||||
return _empty
|
||||
|
|
@ -1113,10 +1140,12 @@ async def get_download_progress(
|
|||
target = f"models--{repo_id.replace('/', '--')}".lower()
|
||||
completed_bytes = 0
|
||||
in_progress_bytes = 0
|
||||
cache_path: Optional[str] = None
|
||||
|
||||
for entry in cache_dir.iterdir():
|
||||
if entry.name.lower() != target:
|
||||
continue
|
||||
cache_path = _resolve_hf_cache_realpath(entry)
|
||||
blobs_dir = entry / "blobs"
|
||||
if not blobs_dir.is_dir():
|
||||
break
|
||||
|
|
@ -1131,7 +1160,7 @@ async def get_download_progress(
|
|||
|
||||
downloaded_bytes = completed_bytes + in_progress_bytes
|
||||
if downloaded_bytes == 0:
|
||||
return _empty
|
||||
return {**_empty, "cache_path": cache_path}
|
||||
|
||||
# Get expected size from HF API (cached per repo_id)
|
||||
expected_bytes = _get_repo_size_cached(repo_id)
|
||||
|
|
@ -1141,6 +1170,7 @@ async def get_download_progress(
|
|||
"downloaded_bytes": downloaded_bytes,
|
||||
"expected_bytes": 0,
|
||||
"progress": 0,
|
||||
"cache_path": cache_path,
|
||||
}
|
||||
|
||||
# Use 95% threshold for completion (blob deduplication can make
|
||||
|
|
@ -1156,6 +1186,7 @@ async def get_download_progress(
|
|||
"downloaded_bytes": downloaded_bytes,
|
||||
"expected_bytes": expected_bytes,
|
||||
"progress": round(progress, 3),
|
||||
"cache_path": cache_path,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking download progress for {repo_id}: {e}")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
"use client";
|
||||
|
||||
import { Progress as ProgressPrimitive } from "radix-ui";
|
||||
|
|
@ -10,9 +10,12 @@ import { cn } from "@/lib/utils";
|
|||
|
||||
function Progress({
|
||||
className,
|
||||
indicatorClassName,
|
||||
value,
|
||||
...props
|
||||
}: React.ComponentProps<typeof ProgressPrimitive.Root>) {
|
||||
}: React.ComponentProps<typeof ProgressPrimitive.Root> & {
|
||||
indicatorClassName?: string;
|
||||
}) {
|
||||
return (
|
||||
<ProgressPrimitive.Root
|
||||
data-slot="progress"
|
||||
|
|
@ -24,7 +27,7 @@ function Progress({
|
|||
>
|
||||
<ProgressPrimitive.Indicator
|
||||
data-slot="progress-indicator"
|
||||
className="bg-primary size-full flex-1 transition-all"
|
||||
className={cn("bg-primary size-full flex-1 transition-all", indicatorClassName)}
|
||||
style={{ transform: `translateX(-${100 - (value || 0)}%)` }}
|
||||
/>
|
||||
</ProgressPrimitive.Root>
|
||||
|
|
|
|||
|
|
@ -117,14 +117,34 @@ export async function getGgufDownloadProgress(
|
|||
return parseJsonOrThrow(response);
|
||||
}
|
||||
|
||||
export interface DownloadProgressResponse {
|
||||
downloaded_bytes: number;
|
||||
expected_bytes: number;
|
||||
progress: number;
|
||||
/**
|
||||
* Resolved on-disk path of the snapshot dir (or cache repo root if no
|
||||
* snapshot exists yet). Null when nothing has been written to the
|
||||
* cache for this repo.
|
||||
*/
|
||||
cache_path: string | null;
|
||||
}
|
||||
|
||||
export async function getDownloadProgress(
|
||||
repoId: string,
|
||||
): Promise<{ downloaded_bytes: number; expected_bytes: number; progress: number }> {
|
||||
): Promise<DownloadProgressResponse> {
|
||||
const params = new URLSearchParams({ repo_id: repoId });
|
||||
const response = await authFetch(`/api/models/download-progress?${params}`);
|
||||
return parseJsonOrThrow(response);
|
||||
}
|
||||
|
||||
export async function getDatasetDownloadProgress(
|
||||
repoId: string,
|
||||
): Promise<DownloadProgressResponse> {
|
||||
const params = new URLSearchParams({ repo_id: repoId });
|
||||
const response = await authFetch(`/api/datasets/download-progress?${params}`);
|
||||
return parseJsonOrThrow(response);
|
||||
}
|
||||
|
||||
export interface LocalModelInfo {
|
||||
id: string;
|
||||
display_name: string;
|
||||
|
|
|
|||
|
|
@ -12,16 +12,200 @@ import {
|
|||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Progress } from "@/components/ui/progress";
|
||||
import {
|
||||
AnimatedSpan,
|
||||
Terminal,
|
||||
TypingAnimation,
|
||||
} from "@/components/ui/terminal";
|
||||
import { useTrainingActions, useTrainingRuntimeStore } from "@/features/training";
|
||||
import {
|
||||
getDatasetDownloadProgress,
|
||||
getDownloadProgress,
|
||||
type DownloadProgressResponse,
|
||||
} from "@/features/chat/api/chat-api";
|
||||
import {
|
||||
useTrainingActions,
|
||||
useTrainingConfigStore,
|
||||
useTrainingRuntimeStore,
|
||||
} from "@/features/training";
|
||||
import { Cancel01Icon } from "@hugeicons/core-free-icons";
|
||||
import { HugeiconsIcon } from "@hugeicons/react";
|
||||
import { useEffect, useState, type ReactElement } from "react";
|
||||
|
||||
const HF_REPO_REGEX = /^[A-Za-z0-9._-]+\/[A-Za-z0-9._-]+$/;
|
||||
|
||||
function formatBytes(n: number): string {
|
||||
if (n <= 0) return "0 B";
|
||||
if (n < 1024) return `${n} B`;
|
||||
if (n < 1024 ** 2) return `${(n / 1024).toFixed(1)} KB`;
|
||||
if (n < 1024 ** 3) return `${(n / 1024 ** 2).toFixed(1)} MB`;
|
||||
return `${(n / 1024 ** 3).toFixed(2)} GB`;
|
||||
}
|
||||
|
||||
function formatCachePath(path: string): string {
|
||||
return path
|
||||
.replace(/^\/(?:home|Users)\/[^/]+/, "~")
|
||||
.replace(/^[A-Za-z]:[/\\]Users[/\\][^/\\]+/, "~");
|
||||
}
|
||||
|
||||
type DownloadState = {
|
||||
downloadedBytes: number;
|
||||
totalBytes: number;
|
||||
percent: number;
|
||||
cachePath: string | null;
|
||||
};
|
||||
|
||||
const EMPTY_DOWNLOAD_STATE: DownloadState = {
|
||||
downloadedBytes: 0,
|
||||
totalBytes: 0,
|
||||
percent: 0,
|
||||
cachePath: null,
|
||||
};
|
||||
|
||||
type Fetcher = (repoId: string) => Promise<DownloadProgressResponse>;
|
||||
|
||||
/**
|
||||
* Polls a HF repo's download progress on a 1.5s tick. Used for both
|
||||
* model weights (`/api/models/download-progress`) and dataset blobs
|
||||
* (`/api/datasets/download-progress`) by swapping the fetcher.
|
||||
*
|
||||
* Stops polling once `progress >= 1.0` -- the bar freezes at the final
|
||||
* value rather than disappearing, mirroring the existing chat flow.
|
||||
*/
|
||||
function useHfDownloadProgress(
|
||||
repoId: string | null,
|
||||
fetcher: Fetcher,
|
||||
): DownloadState {
|
||||
const phase = useTrainingRuntimeStore((s) => s.phase);
|
||||
const isStarting = useTrainingRuntimeStore((s) => s.isStarting);
|
||||
const [state, setState] = useState<DownloadState>(EMPTY_DOWNLOAD_STATE);
|
||||
|
||||
const shouldPoll =
|
||||
isStarting ||
|
||||
phase === "configuring" ||
|
||||
phase === "downloading_model" ||
|
||||
phase === "downloading_dataset" ||
|
||||
phase === "loading_model" ||
|
||||
phase === "loading_dataset";
|
||||
|
||||
useEffect(() => {
|
||||
if (!repoId || !HF_REPO_REGEX.test(repoId) || !shouldPoll) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
let finished = false;
|
||||
let interval: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
const poll = async () => {
|
||||
if (cancelled || finished) return;
|
||||
try {
|
||||
const prog = await fetcher(repoId);
|
||||
if (cancelled) return;
|
||||
const downloaded = prog.downloaded_bytes ?? 0;
|
||||
const total = prog.expected_bytes ?? 0;
|
||||
const ratio = prog.progress ?? 0;
|
||||
const pct =
|
||||
total > 0 ? Math.min(100, Math.round(ratio * 100)) : 0;
|
||||
setState({
|
||||
downloadedBytes: downloaded,
|
||||
totalBytes: total,
|
||||
percent: pct,
|
||||
cachePath: prog.cache_path ?? null,
|
||||
});
|
||||
if (ratio >= 1.0) {
|
||||
finished = true;
|
||||
if (interval) {
|
||||
clearInterval(interval);
|
||||
interval = null;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Silently swallow; bar freezes at last value (matches chat flow).
|
||||
}
|
||||
};
|
||||
|
||||
void poll();
|
||||
interval = setInterval(poll, 1500);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (interval) clearInterval(interval);
|
||||
};
|
||||
}, [repoId, shouldPoll, fetcher]);
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
function useModelDownloadProgress(modelName: string | null): DownloadState {
|
||||
return useHfDownloadProgress(modelName, getDownloadProgress);
|
||||
}
|
||||
|
||||
function useDatasetDownloadProgress(datasetName: string | null): DownloadState {
|
||||
return useHfDownloadProgress(datasetName, getDatasetDownloadProgress);
|
||||
}
|
||||
|
||||
type DownloadRowProps = {
|
||||
label: string;
|
||||
state: DownloadState;
|
||||
};
|
||||
|
||||
function DownloadRow({ label, state }: DownloadRowProps): ReactElement | null {
|
||||
if (state.downloadedBytes <= 0 && !state.cachePath) return null;
|
||||
const isComplete = state.totalBytes > 0 && state.percent >= 100;
|
||||
const statusLabel = isComplete
|
||||
? "Ready"
|
||||
: state.totalBytes > 0
|
||||
? "Downloading"
|
||||
: state.downloadedBytes === 0
|
||||
? "Preparing"
|
||||
: null;
|
||||
const sizeLabel =
|
||||
state.totalBytes > 0
|
||||
? `${formatBytes(state.downloadedBytes)} / ${formatBytes(state.totalBytes)}`
|
||||
: state.downloadedBytes > 0
|
||||
? `${formatBytes(state.downloadedBytes)} downloaded`
|
||||
: null;
|
||||
return (
|
||||
<div className="flex flex-col gap-1.5 rounded-md border border-border/50 bg-muted/20 px-3 py-2">
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs text-foreground/90">{label}</span>
|
||||
{statusLabel ? (
|
||||
<span
|
||||
className={`rounded-full px-1.5 py-0.5 text-[10px] font-medium ${isComplete ? "bg-emerald-100 text-emerald-700 ring-1 ring-emerald-200/80 dark:bg-emerald-500/15 dark:text-emerald-300 dark:ring-emerald-500/30" : "bg-muted text-muted-foreground"}`}
|
||||
>
|
||||
{statusLabel}
|
||||
</span>
|
||||
) : null}
|
||||
</div>
|
||||
<span className="text-xs tabular-nums text-muted-foreground">
|
||||
{state.totalBytes > 0 ? `${state.percent}%` : ""}
|
||||
</span>
|
||||
</div>
|
||||
{sizeLabel ? (
|
||||
<div className="text-[11px] tabular-nums text-muted-foreground">
|
||||
{sizeLabel}
|
||||
</div>
|
||||
) : null}
|
||||
{state.totalBytes > 0 ? (
|
||||
<Progress
|
||||
value={state.percent}
|
||||
indicatorClassName="bg-[linear-gradient(90deg,oklch(0.66_0.142_166.6)_0%,oklch(0.705_0.132_166.6)_55%,oklch(0.75_0.122_166.6)_100%)]"
|
||||
/>
|
||||
) : null}
|
||||
{state.cachePath ? (
|
||||
<div
|
||||
className="truncate rounded bg-muted/50 px-2 py-1 text-[10px] text-muted-foreground/70"
|
||||
title={state.cachePath}
|
||||
>
|
||||
{formatCachePath(state.cachePath)}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type TrainingStartOverlayProps = {
|
||||
message: string
|
||||
currentStep: number
|
||||
|
|
@ -33,6 +217,14 @@ export function TrainingStartOverlay({
|
|||
}: TrainingStartOverlayProps): ReactElement {
|
||||
const { stopTrainingRun, dismissTrainingRun } = useTrainingActions();
|
||||
const isStarting = useTrainingRuntimeStore((s) => s.isStarting);
|
||||
const selectedModel = useTrainingConfigStore((s) => s.selectedModel);
|
||||
const datasetSource = useTrainingConfigStore((s) => s.datasetSource);
|
||||
const dataset = useTrainingConfigStore((s) => s.dataset);
|
||||
// Only HF datasets have a download phase to track. Uploaded files are
|
||||
// already on disk by the time the overlay shows up.
|
||||
const hfDatasetName = datasetSource === "huggingface" ? dataset : null;
|
||||
const modelDownload = useModelDownloadProgress(selectedModel);
|
||||
const datasetDownload = useDatasetDownloadProgress(hfDatasetName);
|
||||
const [cancelDialogOpen, setCancelDialogOpen] = useState(false);
|
||||
const [cancelRequested, setCancelRequested] = useState(false);
|
||||
|
||||
|
|
@ -112,6 +304,22 @@ export function TrainingStartOverlay({
|
|||
<AnimatedSpan className="mt-2 text-muted-foreground">
|
||||
{`> ${message || "starting training..."} | waiting for first step... (${currentStep})`}
|
||||
</AnimatedSpan>
|
||||
{datasetDownload.downloadedBytes > 0 || datasetDownload.cachePath ? (
|
||||
<AnimatedSpan className="mt-3">
|
||||
<DownloadRow
|
||||
label="Dataset"
|
||||
state={datasetDownload}
|
||||
/>
|
||||
</AnimatedSpan>
|
||||
) : null}
|
||||
{modelDownload.downloadedBytes > 0 || modelDownload.cachePath ? (
|
||||
<AnimatedSpan className="mt-3">
|
||||
<DownloadRow
|
||||
label="Model weights"
|
||||
state={modelDownload}
|
||||
/>
|
||||
</AnimatedSpan>
|
||||
) : null}
|
||||
</Terminal>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Reference in a new issue