mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
feat: track files being staged (#9275)
This changeset makes visible when files are being staged, so users are aware that the model "isn't ready yet" for requests. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
0e9d1a6588
commit
39c6b3ed66
9 changed files with 425 additions and 43 deletions
|
|
@ -1936,6 +1936,56 @@
|
|||
40% { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* Staging progress indicator (replaces thinking dots during model transfer) */
|
||||
.chat-staging-progress {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
min-width: 200px;
|
||||
max-width: 320px;
|
||||
}
|
||||
.chat-staging-label {
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
.chat-staging-label i {
|
||||
color: var(--color-primary);
|
||||
}
|
||||
.chat-staging-detail {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.chat-staging-bar-container {
|
||||
flex: 1;
|
||||
height: 4px;
|
||||
background: var(--color-bg-tertiary);
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.chat-staging-bar {
|
||||
height: 100%;
|
||||
background: var(--color-primary);
|
||||
border-radius: 2px;
|
||||
transition: width 300ms ease;
|
||||
}
|
||||
.chat-staging-pct {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-muted);
|
||||
min-width: 32px;
|
||||
text-align: right;
|
||||
}
|
||||
.chat-staging-file {
|
||||
font-size: 0.7rem;
|
||||
color: var(--color-text-muted);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* Message completion flash */
|
||||
.chat-message-bubble {
|
||||
transition: border-color 300ms ease;
|
||||
|
|
|
|||
|
|
@ -27,6 +27,11 @@ export default function OperationsBar() {
|
|||
({op.error})
|
||||
</span>
|
||||
</>
|
||||
) : op.taskType === 'staging' ? (
|
||||
<>
|
||||
<i className="fas fa-cloud-arrow-up" style={{ marginRight: 'var(--spacing-xs)' }} />
|
||||
Staging model: {op.name}{op.nodeName ? ` → ${op.nodeName}` : ''}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{op.isDeletion ? 'Removing' : 'Installing'}{' '}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import UnifiedMCPDropdown from '../components/UnifiedMCPDropdown'
|
|||
import { loadClientMCPServers } from '../utils/mcpClientStorage'
|
||||
import ConfirmDialog from '../components/ConfirmDialog'
|
||||
import { useAuth } from '../context/AuthContext'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
import { relativeTime } from '../utils/format'
|
||||
|
||||
function getLastMessagePreview(chat) {
|
||||
|
|
@ -277,6 +278,7 @@ export default function Chat() {
|
|||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
const { isAdmin } = useAuth()
|
||||
const { operations } = useOperations()
|
||||
const {
|
||||
chats, activeChat, activeChatId, isStreaming, streamingChatId, streamingContent,
|
||||
streamingReasoning, streamingToolCalls, tokensPerSecond, maxTokensPerSecond,
|
||||
|
|
@ -284,6 +286,12 @@ export default function Chat() {
|
|||
sendMessage, stopGeneration, clearHistory, getContextUsagePercent, addMessage,
|
||||
} = useChat(urlModel || '')
|
||||
|
||||
// Detect active staging operation for the current chat's model
|
||||
const stagingOp = useMemo(() => {
|
||||
if (!isStreaming || !activeChat?.model) return null
|
||||
return operations.find(op => op.taskType === 'staging' && op.name === activeChat.model) || null
|
||||
}, [operations, isStreaming, activeChat?.model])
|
||||
|
||||
const [input, setInput] = useState('')
|
||||
const [files, setFiles] = useState([])
|
||||
const [showSettings, setShowSettings] = useState(false)
|
||||
|
|
@ -1187,9 +1195,28 @@ export default function Chat() {
|
|||
</div>
|
||||
<div className="chat-message-bubble">
|
||||
<div className="chat-message-content chat-thinking-indicator">
|
||||
<span className="chat-thinking-dots">
|
||||
<span /><span /><span />
|
||||
</span>
|
||||
{stagingOp ? (
|
||||
<div className="chat-staging-progress">
|
||||
<div className="chat-staging-label">
|
||||
<i className="fas fa-cloud-arrow-up" /> Transferring model{stagingOp.nodeName ? ` to ${stagingOp.nodeName}` : ''}...
|
||||
</div>
|
||||
{stagingOp.progress > 0 && (
|
||||
<div className="chat-staging-detail">
|
||||
<div className="chat-staging-bar-container">
|
||||
<div className="chat-staging-bar" style={{ width: `${stagingOp.progress}%` }} />
|
||||
</div>
|
||||
<span className="chat-staging-pct">{Math.round(stagingOp.progress)}%</span>
|
||||
</div>
|
||||
)}
|
||||
{stagingOp.message && (
|
||||
<div className="chat-staging-file">{stagingOp.message}</div>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<span className="chat-thinking-dots">
|
||||
<span /><span /><span />
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -153,6 +153,27 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
|
|||
operations = append(operations, opData)
|
||||
}
|
||||
|
||||
// Append active file staging operations (distributed mode only)
|
||||
if d := applicationInstance.Distributed(); d != nil && d.Router != nil {
|
||||
for modelID, status := range d.Router.StagingTracker().GetAll() {
|
||||
operations = append(operations, map[string]any{
|
||||
"id": "staging:" + modelID,
|
||||
"name": modelID,
|
||||
"fullName": modelID,
|
||||
"jobID": "staging:" + modelID,
|
||||
"progress": int(status.Progress),
|
||||
"taskType": "staging",
|
||||
"isDeletion": false,
|
||||
"isBackend": false,
|
||||
"isQueued": false,
|
||||
"isCancelled": false,
|
||||
"cancellable": false,
|
||||
"message": status.Message,
|
||||
"nodeName": status.NodeName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort operations by progress (ascending), then by ID for stable display order
|
||||
slices.SortFunc(operations, func(a, b map[string]any) int {
|
||||
progressA := a["progress"].(int)
|
||||
|
|
|
|||
|
|
@ -146,10 +146,11 @@ func (h *HTTPFileStager) doUpload(ctx context.Context, addr, nodeID, localPath,
|
|||
defer f.Close()
|
||||
|
||||
var body io.Reader = f
|
||||
// For files > 100MB, wrap with progress logging
|
||||
cb := StagingProgressFromContext(ctx)
|
||||
// For files > 100MB or when a progress callback is set, wrap with progress reporting
|
||||
const progressThreshold = 100 << 20
|
||||
if fileSize > progressThreshold {
|
||||
body = newProgressReader(f, fileSize, filepath.Base(localPath), nodeID)
|
||||
if fileSize > progressThreshold || cb != nil {
|
||||
body = newProgressReader(f, fileSize, filepath.Base(localPath), nodeID, cb)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body)
|
||||
|
|
@ -268,26 +269,30 @@ func (h *HTTPFileStager) probeExisting(ctx context.Context, addr, localPath, key
|
|||
}
|
||||
|
||||
// progressReader wraps an io.Reader and logs upload progress periodically.
|
||||
// If a StagingProgressCallback is present in the context, it also calls it
|
||||
// for UI-visible progress updates.
|
||||
type progressReader struct {
|
||||
reader io.Reader
|
||||
total int64
|
||||
read int64
|
||||
file string
|
||||
node string
|
||||
lastLog time.Time
|
||||
lastPct int
|
||||
start time.Time
|
||||
mu sync.Mutex
|
||||
reader io.Reader
|
||||
total int64
|
||||
read int64
|
||||
file string
|
||||
node string
|
||||
lastLog time.Time
|
||||
lastPct int
|
||||
start time.Time
|
||||
mu sync.Mutex
|
||||
progressCb StagingProgressCallback
|
||||
}
|
||||
|
||||
func newProgressReader(r io.Reader, total int64, file, node string) *progressReader {
|
||||
func newProgressReader(r io.Reader, total int64, file, node string, cb StagingProgressCallback) *progressReader {
|
||||
return &progressReader{
|
||||
reader: r,
|
||||
total: total,
|
||||
file: file,
|
||||
node: node,
|
||||
start: time.Now(),
|
||||
lastLog: time.Now(),
|
||||
reader: r,
|
||||
total: total,
|
||||
file: file,
|
||||
node: node,
|
||||
start: time.Now(),
|
||||
lastLog: time.Now(),
|
||||
progressCb: cb,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -313,6 +318,10 @@ func (pr *progressReader) Read(p []byte) (int, error) {
|
|||
pr.lastLog = now
|
||||
pr.lastPct = pct
|
||||
}
|
||||
// Call external progress callback for UI visibility
|
||||
if pr.progressCb != nil {
|
||||
pr.progressCb(pr.file, pr.read, pr.total)
|
||||
}
|
||||
pr.mu.Unlock()
|
||||
}
|
||||
return n, err
|
||||
|
|
@ -385,7 +394,19 @@ func (h *HTTPFileStager) FetchRemoteByKey(ctx context.Context, nodeID, key, loca
|
|||
}
|
||||
defer f.Close()
|
||||
|
||||
written, err := io.Copy(f, resp.Body)
|
||||
// Wrap response body with progress reporting if callback is set or file is large
|
||||
var src io.Reader = resp.Body
|
||||
cb := StagingProgressFromContext(ctx)
|
||||
totalSize := resp.ContentLength
|
||||
const progressThreshold = 100 << 20
|
||||
if totalSize > progressThreshold || cb != nil {
|
||||
if totalSize <= 0 {
|
||||
totalSize = 0 // unknown size — progress reader will still report bytes
|
||||
}
|
||||
src = newProgressReader(resp.Body, totalSize, filepath.Base(key), nodeID, cb)
|
||||
}
|
||||
|
||||
written, err := io.Copy(f, src)
|
||||
if err != nil {
|
||||
os.Remove(localDst)
|
||||
return fmt.Errorf("writing to %s: %w", localDst, err)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,14 @@ func (s *S3NATSFileStager) EnsureRemote(ctx context.Context, nodeID, localPath,
|
|||
// Upload to S3 if not already present
|
||||
exists, _ := s.fm.Exists(ctx, key)
|
||||
if !exists {
|
||||
if err := s.fm.Upload(ctx, key, localPath); err != nil {
|
||||
// Wrap with progress reporting if a staging callback is available
|
||||
var progressFn storage.UploadProgressFunc
|
||||
if cb := StagingProgressFromContext(ctx); cb != nil {
|
||||
progressFn = func(fileName string, bytesWritten, totalBytes int64) {
|
||||
cb(fileName, bytesWritten, totalBytes)
|
||||
}
|
||||
}
|
||||
if err := s.fm.UploadWithProgress(ctx, key, localPath, progressFn); err != nil {
|
||||
return "", fmt.Errorf("uploading %s to S3: %w", localPath, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,12 +42,13 @@ type SmartRouterOptions struct {
|
|||
// SmartRouter routes inference requests to the best available backend node.
|
||||
// It uses the ModelRouter interface (backed by NodeRegistry in production) for routing decisions.
|
||||
type SmartRouter struct {
|
||||
registry ModelRouter
|
||||
unloader NodeCommandSender // optional, for NATS-driven load/unload
|
||||
fileStager FileStager // optional, for distributed file transfer
|
||||
galleriesJSON string // backend gallery config for dynamic installation
|
||||
clientFactory BackendClientFactory // creates gRPC backend clients
|
||||
db *gorm.DB // for advisory locks during routing
|
||||
registry ModelRouter
|
||||
unloader NodeCommandSender // optional, for NATS-driven load/unload
|
||||
fileStager FileStager // optional, for distributed file transfer
|
||||
galleriesJSON string // backend gallery config for dynamic installation
|
||||
clientFactory BackendClientFactory // creates gRPC backend clients
|
||||
db *gorm.DB // for advisory locks during routing
|
||||
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
|
||||
}
|
||||
|
||||
// NewSmartRouter creates a new SmartRouter backed by the given ModelRouter.
|
||||
|
|
@ -58,18 +59,22 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
|||
factory = &tokenClientFactory{token: opts.AuthToken}
|
||||
}
|
||||
return &SmartRouter{
|
||||
registry: registry,
|
||||
unloader: opts.Unloader,
|
||||
fileStager: opts.FileStager,
|
||||
galleriesJSON: opts.GalleriesJSON,
|
||||
clientFactory: factory,
|
||||
db: opts.DB,
|
||||
registry: registry,
|
||||
unloader: opts.Unloader,
|
||||
fileStager: opts.FileStager,
|
||||
galleriesJSON: opts.GalleriesJSON,
|
||||
clientFactory: factory,
|
||||
db: opts.DB,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
}
|
||||
}
|
||||
|
||||
// Unloader returns the remote unloader adapter for external use.
|
||||
func (r *SmartRouter) Unloader() NodeCommandSender { return r.unloader }
|
||||
|
||||
// StagingTracker returns the staging progress tracker for UI visibility.
|
||||
func (r *SmartRouter) StagingTracker() *StagingTracker { return r.stagingTracker }
|
||||
|
||||
// scheduleLoadResult holds the result of scheduling and loading a model on a node.
|
||||
type scheduleLoadResult struct {
|
||||
Node *BackendNode
|
||||
|
|
@ -568,6 +573,33 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
{"AudioPath", &opts.AudioPath},
|
||||
}
|
||||
|
||||
// Count stageable files for progress tracking
|
||||
totalFiles := 0
|
||||
for _, f := range fields {
|
||||
if *f.val != "" {
|
||||
if _, err := os.Stat(*f.val); err == nil {
|
||||
totalFiles++
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, adapter := range opts.LoraAdapters {
|
||||
if adapter != "" {
|
||||
if _, err := os.Stat(adapter); err == nil {
|
||||
totalFiles++
|
||||
}
|
||||
}
|
||||
}
|
||||
if opts.LoraBase != "" {
|
||||
if _, err := os.Stat(opts.LoraBase); err == nil {
|
||||
totalFiles++
|
||||
}
|
||||
}
|
||||
|
||||
// Start tracking staging progress
|
||||
r.stagingTracker.Start(trackingKey, node.Name, totalFiles)
|
||||
defer r.stagingTracker.Complete(trackingKey)
|
||||
|
||||
fileIdx := 0
|
||||
for _, f := range fields {
|
||||
if *f.val == "" {
|
||||
continue
|
||||
|
|
@ -578,9 +610,17 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
fileIdx++
|
||||
localPath := *f.val
|
||||
key := keyMapper.Key(localPath)
|
||||
remotePath, err := r.fileStager.EnsureRemote(ctx, node.ID, localPath, key)
|
||||
|
||||
// Attach progress callback to context for byte-level tracking
|
||||
fileName := filepath.Base(localPath)
|
||||
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, fileIdx, totalFiles)
|
||||
|
||||
xlog.Info("Staging file", "model", trackingKey, "node", node.Name, "field", f.name, "file", fileName, "fileIndex", fileIdx, "totalFiles", totalFiles)
|
||||
|
||||
remotePath, err := r.fileStager.EnsureRemote(stageCtx, node.ID, localPath, key)
|
||||
if err != nil {
|
||||
// ModelFile is required — fail the whole operation
|
||||
if f.name == "ModelFile" {
|
||||
|
|
@ -592,6 +632,8 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
*f.val = ""
|
||||
continue
|
||||
}
|
||||
|
||||
r.stagingTracker.FileComplete(trackingKey, fileIdx, totalFiles)
|
||||
xlog.Debug("Staged model field", "field", f.name, "remotePath", remotePath)
|
||||
*f.val = remotePath
|
||||
|
||||
|
|
@ -609,7 +651,7 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
}
|
||||
|
||||
// Handle LoraAdapters (array) — rewritten to absolute remote paths
|
||||
staged := make([]string, 0, len(opts.LoraAdapters))
|
||||
stagedAdapters := make([]string, 0, len(opts.LoraAdapters))
|
||||
for _, adapter := range opts.LoraAdapters {
|
||||
if adapter == "" {
|
||||
continue
|
||||
|
|
@ -618,21 +660,31 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
xlog.Debug("Skipping staging for non-existent lora adapter", "path", adapter)
|
||||
continue
|
||||
}
|
||||
fileIdx++
|
||||
fileName := filepath.Base(adapter)
|
||||
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, fileIdx, totalFiles)
|
||||
|
||||
key := keyMapper.Key(adapter)
|
||||
remotePath, err := r.fileStager.EnsureRemote(ctx, node.ID, adapter, key)
|
||||
remotePath, err := r.fileStager.EnsureRemote(stageCtx, node.ID, adapter, key)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to stage lora adapter, skipping", "path", adapter, "error", err)
|
||||
continue
|
||||
}
|
||||
staged = append(staged, remotePath)
|
||||
r.stagingTracker.FileComplete(trackingKey, fileIdx, totalFiles)
|
||||
stagedAdapters = append(stagedAdapters, remotePath)
|
||||
}
|
||||
opts.LoraAdapters = staged
|
||||
opts.LoraAdapters = stagedAdapters
|
||||
|
||||
// Handle LoraBase field — rewritten to absolute remote path
|
||||
if opts.LoraBase != "" {
|
||||
if _, err := os.Stat(opts.LoraBase); err == nil {
|
||||
fileIdx++
|
||||
fileName := filepath.Base(opts.LoraBase)
|
||||
stageCtx := r.withStagingCallback(ctx, trackingKey, fileName, fileIdx, totalFiles)
|
||||
|
||||
key := keyMapper.Key(opts.LoraBase)
|
||||
if remotePath, err := r.fileStager.EnsureRemote(ctx, node.ID, opts.LoraBase, key); err == nil {
|
||||
if remotePath, err := r.fileStager.EnsureRemote(stageCtx, node.ID, opts.LoraBase, key); err == nil {
|
||||
r.stagingTracker.FileComplete(trackingKey, fileIdx, totalFiles)
|
||||
opts.LoraBase = remotePath
|
||||
} else {
|
||||
xlog.Warn("Failed to stage LoraBase, clearing field", "path", opts.LoraBase, "error", err)
|
||||
|
|
@ -649,6 +701,20 @@ func (r *SmartRouter) stageModelFiles(ctx context.Context, node *BackendNode, op
|
|||
return opts, nil
|
||||
}
|
||||
|
||||
// withStagingCallback creates a context with a progress callback that updates the staging tracker.
|
||||
func (r *SmartRouter) withStagingCallback(ctx context.Context, trackingKey, fileName string, fileIdx, totalFiles int) context.Context {
|
||||
start := time.Now()
|
||||
return WithStagingProgress(ctx, func(fn string, bytesSent, totalBytes int64) {
|
||||
var speed string
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 0 {
|
||||
bytesPerSec := float64(bytesSent) / elapsed.Seconds()
|
||||
speed = humanFileSize(int64(bytesPerSec)) + "/s"
|
||||
}
|
||||
r.stagingTracker.UpdateFile(trackingKey, fn, fileIdx, bytesSent, totalBytes, speed)
|
||||
})
|
||||
}
|
||||
|
||||
// stageCompanionFiles stages known companion files that exist alongside
|
||||
// localPath. For example, piper TTS implicitly loads ".onnx.json" next to
|
||||
// the ".onnx" model file. Errors are logged but not propagated.
|
||||
|
|
|
|||
145
core/services/nodes/staging_progress.go
Normal file
145
core/services/nodes/staging_progress.go
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StagingStatus represents the current progress of a model staging operation.
|
||||
type StagingStatus struct {
|
||||
ModelID string `json:"model_id"`
|
||||
NodeName string `json:"node_name"`
|
||||
FileName string `json:"file_name"`
|
||||
BytesSent int64 `json:"bytes_sent"`
|
||||
TotalBytes int64 `json:"total_bytes"`
|
||||
Progress float64 `json:"progress"` // 0-100 overall progress
|
||||
Speed string `json:"speed"`
|
||||
FileIndex int `json:"file_index"`
|
||||
TotalFiles int `json:"total_files"`
|
||||
Message string `json:"message"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
}
|
||||
|
||||
// StagingTracker tracks active file staging operations in-memory.
|
||||
// Used by SmartRouter to publish progress and by /api/operations to surface it.
|
||||
type StagingTracker struct {
|
||||
mu sync.RWMutex
|
||||
active map[string]*StagingStatus
|
||||
}
|
||||
|
||||
// NewStagingTracker creates a new tracker.
|
||||
func NewStagingTracker() *StagingTracker {
|
||||
return &StagingTracker{
|
||||
active: make(map[string]*StagingStatus),
|
||||
}
|
||||
}
|
||||
|
||||
// Start registers a new staging operation for the given model.
|
||||
func (t *StagingTracker) Start(modelID, nodeName string, totalFiles int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.active[modelID] = &StagingStatus{
|
||||
ModelID: modelID,
|
||||
NodeName: nodeName,
|
||||
TotalFiles: totalFiles,
|
||||
StartedAt: time.Now(),
|
||||
Message: "Preparing to stage model files",
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateFile updates the tracker with current file transfer progress.
|
||||
func (t *StagingTracker) UpdateFile(modelID, fileName string, fileIndex int, bytesSent, totalBytes int64, speed string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
s, ok := t.active[modelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.FileName = fileName
|
||||
s.FileIndex = fileIndex
|
||||
s.BytesSent = bytesSent
|
||||
s.TotalBytes = totalBytes
|
||||
s.Speed = speed
|
||||
|
||||
// Calculate overall progress across all files
|
||||
if s.TotalFiles > 0 && totalBytes > 0 {
|
||||
filePct := float64(bytesSent) / float64(totalBytes) * 100
|
||||
s.Progress = (float64(fileIndex-1)*100 + filePct) / float64(s.TotalFiles)
|
||||
}
|
||||
|
||||
// Build human-readable message
|
||||
if totalBytes > 0 {
|
||||
s.Message = fmt.Sprintf("%s (%s / %s", fileName, humanFileSize(bytesSent), humanFileSize(totalBytes))
|
||||
if speed != "" {
|
||||
s.Message += ", " + speed
|
||||
}
|
||||
s.Message += ")"
|
||||
} else {
|
||||
s.Message = fmt.Sprintf("Staging %s", fileName)
|
||||
}
|
||||
}
|
||||
|
||||
// FileComplete marks a single file as done within a staging operation.
|
||||
func (t *StagingTracker) FileComplete(modelID string, fileIndex, totalFiles int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
s, ok := t.active[modelID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if totalFiles > 0 {
|
||||
s.Progress = float64(fileIndex) / float64(totalFiles) * 100
|
||||
}
|
||||
s.BytesSent = 0
|
||||
s.TotalBytes = 0
|
||||
s.Speed = ""
|
||||
}
|
||||
|
||||
// Complete removes a staging operation (it's done).
|
||||
func (t *StagingTracker) Complete(modelID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
delete(t.active, modelID)
|
||||
}
|
||||
|
||||
// GetAll returns a snapshot of all active staging operations.
|
||||
func (t *StagingTracker) GetAll() map[string]StagingStatus {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
result := make(map[string]StagingStatus, len(t.active))
|
||||
for k, v := range t.active {
|
||||
result[k] = *v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Get returns the status of a specific staging operation, or nil if not active.
|
||||
func (t *StagingTracker) Get(modelID string) *StagingStatus {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
s, ok := t.active[modelID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
copy := *s
|
||||
return ©
|
||||
}
|
||||
|
||||
// StagingProgressCallback is called by file stagers to report byte-level progress.
|
||||
type StagingProgressCallback func(fileName string, bytesSent, totalBytes int64)
|
||||
|
||||
type stagingProgressKey struct{}
|
||||
|
||||
// WithStagingProgress attaches a progress callback to a context.
|
||||
func WithStagingProgress(ctx context.Context, cb StagingProgressCallback) context.Context {
|
||||
return context.WithValue(ctx, stagingProgressKey{}, cb)
|
||||
}
|
||||
|
||||
// StagingProgressFromContext extracts a progress callback from a context.
|
||||
// Returns nil if no callback is set.
|
||||
func StagingProgressFromContext(ctx context.Context) StagingProgressCallback {
|
||||
cb, _ := ctx.Value(stagingProgressKey{}).(StagingProgressCallback)
|
||||
return cb
|
||||
}
|
||||
|
|
@ -39,6 +39,14 @@ func NewFileManager(store ObjectStore, cacheDir string) (*FileManager, error) {
|
|||
// Upload stores a file in object storage under the given key.
|
||||
// The file is read from the local path.
|
||||
func (fm *FileManager) Upload(ctx context.Context, key, localPath string) error {
|
||||
return fm.UploadWithProgress(ctx, key, localPath, nil)
|
||||
}
|
||||
|
||||
// UploadProgressFunc is called periodically during upload with the file name and bytes written/total.
|
||||
type UploadProgressFunc func(fileName string, bytesWritten, totalBytes int64)
|
||||
|
||||
// UploadWithProgress stores a file in object storage, calling progressFn with byte-level updates.
|
||||
func (fm *FileManager) UploadWithProgress(ctx context.Context, key, localPath string, progressFn UploadProgressFunc) error {
|
||||
if fm.store == nil {
|
||||
return nil // no-op in single-node mode
|
||||
}
|
||||
|
|
@ -49,7 +57,21 @@ func (fm *FileManager) Upload(ctx context.Context, key, localPath string) error
|
|||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := fm.store.Put(ctx, key, f); err != nil {
|
||||
var r io.Reader = f
|
||||
if progressFn != nil {
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", localPath, err)
|
||||
}
|
||||
r = &uploadProgressReader{
|
||||
reader: f,
|
||||
total: fi.Size(),
|
||||
fileName: filepath.Base(localPath),
|
||||
progressFn: progressFn,
|
||||
}
|
||||
}
|
||||
|
||||
if err := fm.store.Put(ctx, key, r); err != nil {
|
||||
return fmt.Errorf("uploading %s to %s: %w", localPath, key, err)
|
||||
}
|
||||
|
||||
|
|
@ -57,6 +79,24 @@ func (fm *FileManager) Upload(ctx context.Context, key, localPath string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
// uploadProgressReader wraps an io.Reader and calls a progress function.
|
||||
type uploadProgressReader struct {
|
||||
reader io.Reader
|
||||
total int64
|
||||
written int64
|
||||
fileName string
|
||||
progressFn UploadProgressFunc
|
||||
}
|
||||
|
||||
func (r *uploadProgressReader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.written += int64(n)
|
||||
r.progressFn(r.fileName, r.written, r.total)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Download retrieves a file from object storage and caches it locally.
|
||||
// Returns the local file path. If the file is already cached, returns immediately.
|
||||
func (fm *FileManager) Download(ctx context.Context, key string) (string, error) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue