Mirror of https://github.com/unslothai/unsloth, synced 2026-04-21 13:37:39 +00:00
studio: stream export worker output into the export dialog (#4897)
* studio: stream export worker output into the export dialog
The Export Model dialog only showed a spinner on the "Exporting..."
button while the worker subprocess was doing the actual heavy lifting.
For Merged to 16bit and GGUF / Llama.cpp exports this meant several
minutes (or more, for large models) of opaque silence, with no way to
tell whether save_pretrained_merged, convert_hf_to_gguf.py, or
llama-quantize was making progress.
This adds a live terminal-style output panel inside the export dialog,
rendered just above the Cancel / Start Export buttons and scrollable
with auto-follow-tail. It shows stdout and stderr from both the worker
process itself and any child process it spawns (GGUF converter,
llama-quantize), coloured by stream.
Backend
- core/export/worker.py: new _setup_log_capture(resp_queue) installed
before LogConfig.setup_logging. It saves the original stdout/stderr
fds, creates pipes, os.dup2's the write ends onto fds 1 and 2 (so
every child process inherits the redirected fds), and spins up two
daemon reader threads. Each thread reads bytes from a pipe, echoes
them back to the original fd (so the server console keeps working),
splits on \n and \r, and forwards each line to the resp queue as
{"type":"log","stream":"stdout|stderr","line":...,"ts":...}.
PYTHONUNBUFFERED=1 is set so nested Python converters flush
immediately.
- core/export/orchestrator.py:
- Thread-safe ring buffer (collections.deque, maxlen 4000) with a
monotonically increasing seq counter. clear_logs(),
get_logs_since(cursor), get_current_log_seq(), is_export_active().
- _wait_response handles rtype == "log" by appending to the buffer
and continuing the wait loop. Status messages are also surfaced as
a "status" stream so users see high-level progress alongside raw
subprocess output.
- load_checkpoint, _run_export, and cleanup_memory now wrap their
bodies with the existing self._lock (previously unused), clear the
log buffer at the start of each op, and flip _export_active in a
try/finally so the SSE endpoint can detect idle.
- routes/export.py:
- Wrapped every sync orchestrator call (load_checkpoint,
cleanup_memory, export_merged_model, export_base_model,
export_gguf, export_lora_adapter) in asyncio.to_thread so the
FastAPI event loop stays free during long exports. Without this
the new SSE endpoint could not be served concurrently with the
blocking export POST.
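The event-loop effect of that change can be illustrated without FastAPI: moving a blocking call onto a thread via asyncio.to_thread lets other coroutines (here a stand-in for the SSE heartbeat) keep running while it works. The function names are hypothetical.

```python
import asyncio
import time


def blocking_export() -> str:
    time.sleep(0.2)           # stands in for a long merged-save / GGUF convert
    return "exported"


async def heartbeat(results: list) -> None:
    # Simulates the SSE stream staying responsive during the export.
    for _ in range(3):
        await asyncio.sleep(0.05)
        results.append("tick")


async def main() -> list:
    results: list = []
    # Without to_thread, blocking_export would freeze the loop and
    # no heartbeats could be emitted until it returned.
    export_task = asyncio.create_task(asyncio.to_thread(blocking_export))
    await heartbeat(results)
    results.append(await export_task)
    return results


out = asyncio.run(main())
```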
- New GET /api/export/logs/stream SSE endpoint. Honors
Last-Event-ID and a since query param for reconnect, emits log /
heartbeat / complete / error events, uses the id field to carry
the log seq so clients can resume cleanly. On first connect
without an explicit cursor it starts from the current seq so old
lines from a previous run are not replayed.
Frontend
- features/export/api/export-api.ts: streamExportLogs() helper that
authFetches the SSE endpoint and parses id / event / data fields
manually (same pattern as streamTrainingProgress in train-api.ts).
- features/export/components/export-dialog.tsx:
- Local useExportLogs(exporting) hook that opens the SSE stream on
exporting transitions to true, accumulates up to 4000 lines in
component state, and aborts on cleanup.
- New scrollable output panel rendered above DialogFooter, only
shown for Merged to 16bit and GGUF / Llama.cpp (LoRA adapter is
a fast disk write with nothing to show). Dark terminal styling
(bg-black/85, emerald text, rose for stderr, sky for status),
max-height 14rem, auto-scrolls to the bottom on new output but
stops following if the user scrolls up. A small streaming / idle
indicator is shown next to the panel title.
- DialogContent widens from sm:max-w-lg to sm:max-w-2xl when the
output panel is visible so the logs have room to breathe.
Verified
- Python smoke test (tests/smoke_export_log_capture.py): spawns a
real mp.get_context("spawn") process, installs _setup_log_capture,
confirms that parent stdout prints, parent stderr prints, AND a
child subprocess invoked via subprocess.run (both its stdout and
stderr) are all captured in the resp queue. Passes.
- Orchestrator log helpers tested in isolation: _append_log,
get_logs_since (with and without a cursor), clear_logs not
resetting seq so reconnecting clients still progress. Passes.
- routes.export imports cleanly in the studio venv and /logs/stream
shows up in router.routes.
- bun run build: tsc -b plus vite build, no TypeScript errors.
No existing export behavior is changed. If the subprocess, the SSE
endpoint, or the frontend hook fails, the export itself still runs to
completion the same way it did before, with or without logs visible.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* export dialog: trim bootstrap noise, scope logs per screen, show realpath
Several follow-ups to the live export log work:
1. Worker bootstrap noise (transformers venv activation, Unsloth banner,
"Top GGUF/hub models" lists, vision detection, 2k-step weight load
bar) is dropped from the export-dialog stream. A threading.Event
gate in worker.py defaults closed and only opens once _handle_export
actually starts; until then the reader thread still echoes lines to
the saved console fd for debugging but does not push them onto the
resp_queue. The orchestrator already spawns a fresh subprocess for
every checkpoint load, so the gate is naturally reset between runs.
2. tqdm in non-tty mode defaults to a 10s mininterval, which makes
multi-step bars look frozen in the panel. Set TQDM_MININTERVAL=0.5
in the worker env so any tqdm-driven progress emits more often.
3. The dialog's useExportLogs hook now also clears its line buffer
when exportMethod or open changes, so re-opening the dialog into a
different action's screen no longer shows the previous action's
saved output. A useElapsedSeconds tick + "Working Xs" badge in the
log header gives users a visible sign that long single-step phases
(cache copies, GGUF conversion) are still running when no new lines
are arriving.
4. ExportBackend.export_{merged,base,gguf,lora} now return
(success, message, output_path); the worker forwards output_path on
each export_*_done response, the orchestrator's _run_export passes
it to routes/export.py, which surfaces it via
ExportOperationResponse.details.output_path. The dialog's Export
Complete screen renders the resolved on-disk realpath under "Saved
to" so users can find their exported model directly.
* fix(cli): unpack 3-tuple return from export backend
ExportOrchestrator.export_{merged,base,gguf,lora} now return
(success, message, output_path) so the studio dialog can show
the on-disk realpath. The CLI still unpacked 2 values, so every
`unsloth export --format ...` crashed with ValueError before
reporting completion. Update the four call sites and surface
output_path via a "Saved to:" echo.
* fix(studio): anchor export log SSE cursor at run start
The export dialog SSE defaulted its cursor to get_current_log_seq()
at connect time, so any line emitted between the POST that kicks
off the export and the client opening the stream was buffered with
seqs 1..k and then skipped (seq <= cursor). Long-running exports
looked silent during their first seconds.
Snapshot _log_seq into _run_start_seq inside clear_logs() and
expose it via get_run_start_seq(). The SSE default cursor now uses
that snapshot, so every line emitted since the current run began
is reachable regardless of when the client connects. Old runs
still can't leak in because their seqs are <= the snapshot.
* fix(studio): reconnect export log SSE on stream drop
useExportLogs launched streamExportLogs once per exporting
transition and recorded any drop in .catch(). Long GGUF exports
behind a proxy with an idle kill-timeout would silently lose the
stream for the rest of the run even though the backend already
supports Last-Event-ID resume. The "retry: 3000" directive emitted
by the backend is only meaningful to native EventSource; this
hook uses a manual fetch + ReadableStream parse so it had no
effect.
Wrap streamExportLogs in a retry loop that tracks lastSeq from
ExportLogEvent.id and passes it as since on reconnect. Backoff is
exponential with jitter, capped at 5s, reset on successful open.
The loop stops on explicit backend `complete` event or on effect
cleanup.
* fix(studio): register a second command so Typer keeps `export` as a subcommand
The CLI export unpacking tests wrap `unsloth_cli.commands.export.export`
in a fresh Typer app with a single registered command. Typer flattens a
single-command app into that command, so the test's
`runner.invoke(cli_app, ["export", ckpt, out, ...])` treats the leading
`"export"` token as an unexpected extra positional argument -- every
parametrized case failed with:
Got unexpected extra argument (.../out)
Register a harmless `noop` second command so Typer preserves subcommand
routing and the tests actually exercise the 3-tuple unpack path they
were written to guard.
Before: 4 failed
After: 4 passed
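The Typer behavior is easy to reproduce in isolation: with a second registered command, subcommand routing survives and the leading `"export"` token routes correctly. A sketch of the test-harness pattern, with hypothetical command bodies:

```python
import typer
from typer.testing import CliRunner

app = typer.Typer()


@app.command()
def export(path: str):
    typer.echo(f"exporting {path}")


@app.command()
def noop():
    """Harmless second command so Typer keeps subcommand routing."""


# With only `export` registered, Typer would flatten the app into that
# single command and "export" below would parse as a positional argument.
runner = CliRunner()
result = runner.invoke(app, ["export", "model.bin"])
```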
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: studio-install <studio@local.install>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
This commit is contained in:
parent eca592effe
commit 7252410ccc
10 changed files with 1522 additions and 117 deletions
@@ -310,7 +310,7 @@ class ExportBackend:
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
         private: bool = False,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """
         Export merged model (for PEFT models).
 
@@ -323,14 +323,21 @@ class ExportBackend:
             private: Whether to make the repo private
 
         Returns:
-            Tuple of (success: bool, message: str)
+            Tuple of (success, message, output_path). output_path is the
+            resolved absolute on-disk directory of the saved model when
+            ``save_directory`` was set, else None.
         """
         if not self.current_model or not self.current_tokenizer:
-            return False, "No model loaded. Please select a checkpoint first."
+            return False, "No model loaded. Please select a checkpoint first.", None
 
         if not self.is_peft:
-            return False, "This is not a PEFT model. Use 'Export Base Model' instead."
+            return (
+                False,
+                "This is not a PEFT model. Use 'Export Base Model' instead.",
+                None,
+            )
 
+        output_path: Optional[str] = None
         try:
             # Determine save method
             if format_type == "4-bit (FP4)":
@@ -354,6 +361,7 @@ class ExportBackend:
                 # Write export metadata so the Chat page can identify the base model
                 self._write_export_metadata(save_directory)
                 logger.info(f"Model saved successfully to {save_directory}")
+                output_path = str(Path(save_directory).resolve())
 
             # Push to hub if requested
             if push_to_hub:
@@ -361,6 +369,7 @@ class ExportBackend:
                     return (
                         False,
                         "Repository ID and Hugging Face token required for Hub upload",
+                        None,
                     )
 
                 logger.info(f"Pushing merged model to Hub: {repo_id}")
@@ -378,14 +387,14 @@ class ExportBackend:
                 )
                 logger.info(f"Model pushed successfully to {repo_id}")
 
-            return True, "Model exported successfully"
+            return True, "Model exported successfully", output_path
 
         except Exception as e:
             logger.error(f"Error exporting merged model: {e}")
             import traceback
 
             logger.error(traceback.format_exc())
-            return False, f"Export failed: {str(e)}"
+            return False, f"Export failed: {str(e)}", None
 
     def export_base_model(
         self,
@@ -395,22 +404,26 @@ class ExportBackend:
         hf_token: Optional[str] = None,
         private: bool = False,
         base_model_id: Optional[str] = None,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """
         Export base model (for non-PEFT models).
 
         Returns:
-            Tuple of (success: bool, message: str)
+            Tuple of (success, message, output_path). output_path is the
+            resolved absolute on-disk directory of the saved model when
+            ``save_directory`` was set, else None.
         """
         if not self.current_model or not self.current_tokenizer:
-            return False, "No model loaded. Please select a checkpoint first."
+            return False, "No model loaded. Please select a checkpoint first.", None
 
         if self.is_peft:
             return (
                 False,
                 "This is a PEFT model. Use 'Merged Model' export type instead.",
+                None,
             )
 
+        output_path: Optional[str] = None
         try:
             # Save locally if requested
             if save_directory:
@@ -424,6 +437,7 @@ class ExportBackend:
                 # Write export metadata so the Chat page can identify the base model
                 self._write_export_metadata(save_directory)
                 logger.info(f"Model saved successfully to {save_directory}")
+                output_path = str(Path(save_directory).resolve())
 
             # Push to hub if requested
             if push_to_hub:
@@ -431,6 +445,7 @@ class ExportBackend:
                     return (
                         False,
                         "Repository ID and Hugging Face token required for Hub upload",
+                        None,
                     )
 
                 logger.info(f"Pushing base model to Hub: {repo_id}")
@@ -472,16 +487,16 @@ class ExportBackend:
                     )
                     logger.info(f"Model pushed successfully to {repo_id}")
                 else:
-                    return False, "Local save directory required for Hub upload"
+                    return False, "Local save directory required for Hub upload", None
 
-            return True, "Model exported successfully"
+            return True, "Model exported successfully", output_path
 
         except Exception as e:
             logger.error(f"Error exporting base model: {e}")
             import traceback
 
             logger.error(traceback.format_exc())
-            return False, f"Export failed: {str(e)}"
+            return False, f"Export failed: {str(e)}", None
 
     def export_gguf(
         self,
@@ -490,7 +505,7 @@ class ExportBackend:
         push_to_hub: bool = False,
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """
         Export model in GGUF format.
 
@@ -502,11 +517,14 @@ class ExportBackend:
             hf_token: Hugging Face token
 
         Returns:
-            Tuple of (success: bool, message: str)
+            Tuple of (success, message, output_path). output_path is the
+            resolved absolute on-disk directory containing the .gguf
+            files when ``save_directory`` was set, else None.
         """
         if not self.current_model or not self.current_tokenizer:
-            return False, "No model loaded. Please select a checkpoint first."
+            return False, "No model loaded. Please select a checkpoint first.", None
 
+        output_path: Optional[str] = None
         try:
             # Convert quantization method to lowercase for unsloth
             quant_method = quantization_method.lower()
@@ -601,6 +619,7 @@ class ExportBackend:
                 abs_save_dir,
                 "\n ".join(os.path.basename(f) for f in final_ggufs) or "(none)",
             )
+            output_path = str(Path(abs_save_dir).resolve())
 
             # Push to hub if requested
             if push_to_hub:
@@ -608,6 +627,7 @@ class ExportBackend:
                 return (
                     False,
                     "Repository ID and Hugging Face token required for Hub upload",
+                    None,
                 )
 
             logger.info(f"Pushing GGUF model to Hub: {repo_id}")
@@ -620,14 +640,18 @@ class ExportBackend:
             )
             logger.info(f"GGUF model pushed successfully to {repo_id}")
 
-            return True, f"GGUF model exported successfully ({quantization_method})"
+            return (
+                True,
+                f"GGUF model exported successfully ({quantization_method})",
+                output_path,
+            )
 
         except Exception as e:
             logger.error(f"Error exporting GGUF model: {e}")
             import traceback
 
             logger.error(traceback.format_exc())
-            return False, f"GGUF export failed: {str(e)}"
+            return False, f"GGUF export failed: {str(e)}", None
 
     def export_lora_adapter(
         self,
@@ -636,19 +660,22 @@ class ExportBackend:
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
         private: bool = False,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """
         Export LoRA adapter only (not merged).
 
         Returns:
-            Tuple of (success: bool, message: str)
+            Tuple of (success, message, output_path). output_path is the
+            resolved absolute on-disk directory of the saved adapter
+            when ``save_directory`` was set, else None.
         """
         if not self.current_model or not self.current_tokenizer:
-            return False, "No model loaded. Please select a checkpoint first."
+            return False, "No model loaded. Please select a checkpoint first.", None
 
         if not self.is_peft:
-            return False, "This is not a PEFT model. No adapter to export."
+            return False, "This is not a PEFT model. No adapter to export.", None
 
+        output_path: Optional[str] = None
         try:
             # Save locally if requested
             if save_directory:
@@ -659,6 +686,7 @@ class ExportBackend:
                 self.current_model.save_pretrained(save_directory)
                 self.current_tokenizer.save_pretrained(save_directory)
                 logger.info(f"Adapter saved successfully to {save_directory}")
+                output_path = str(Path(save_directory).resolve())
 
             # Push to hub if requested
             if push_to_hub:
@@ -666,6 +694,7 @@ class ExportBackend:
                     return (
                         False,
                         "Repository ID and Hugging Face token required for Hub upload",
+                        None,
                     )
 
                 logger.info(f"Pushing LoRA adapter to Hub: {repo_id}")
@@ -676,14 +705,14 @@ class ExportBackend:
                 )
                 logger.info(f"Adapter pushed successfully to {repo_id}")
 
-            return True, "LoRA adapter exported successfully"
+            return True, "LoRA adapter exported successfully", output_path
 
         except Exception as e:
             logger.error(f"Error exporting LoRA adapter: {e}")
             import traceback
 
             logger.error(traceback.format_exc())
-            return False, f"Adapter export failed: {str(e)}"
+            return False, f"Adapter export failed: {str(e)}", None
 
 
 # Global export backend instance
@@ -16,19 +16,25 @@ Pattern follows core/inference/orchestrator.py.
 
 import atexit
-import structlog
+from collections import deque
+from loggers import get_logger
 import multiprocessing as mp
 import queue
 import threading
 import time
 from pathlib import Path
-from typing import Any, List, Optional, Tuple
+from typing import Any, Deque, Dict, List, Optional, Tuple
 from utils.paths import outputs_root
 
 logger = get_logger(__name__)
 
 _CTX = mp.get_context("spawn")
 
+# Maximum number of captured log lines kept in memory per export
+# orchestrator. Acts as scrollback for the live export log panel in the
+# UI. 4000 lines is ~1 MB worst-case at 256 chars/line.
+_LOG_BUFFER_MAXLEN = 4000
+
 
 class ExportOrchestrator:
     """
@@ -44,6 +50,9 @@ class ExportOrchestrator:
         self._proc: Optional[mp.Process] = None
         self._cmd_queue: Any = None
         self._resp_queue: Any = None
+        # Serializes export operations (load_checkpoint, export_*,
+        # cleanup) so concurrent HTTP requests can never interleave
+        # commands on the subprocess queue. Previously unused.
         self._lock = threading.Lock()
 
         # Local state mirrors (updated from subprocess responses)
@@ -51,9 +60,103 @@ class ExportOrchestrator:
         self.is_vision: bool = False
         self.is_peft: bool = False
 
+        # ── Live log capture ─────────────────────────────────────
+        # Thread-safe ring buffer of log lines forwarded from the
+        # worker subprocess. Powers the GET /api/export/logs/stream
+        # SSE endpoint that the export dialog consumes.
+        self._log_buffer: Deque[Dict[str, Any]] = deque(maxlen = _LOG_BUFFER_MAXLEN)
+        self._log_lock = threading.Lock()
+        # Monotonically increasing sequence number. Never reset across
+        # operations, so SSE clients can use it as a stable cursor even
+        # if clear_logs() is called mid-session.
+        self._log_seq: int = 0
+        # Snapshot of _log_seq captured at the start of the current run
+        # (updated by clear_logs()). The SSE endpoint defaults its
+        # cursor to this value so a client that connects AFTER the
+        # worker has already emitted its first lines still sees the
+        # full run. Every line appended during the current run has seq
+        # strictly greater than _run_start_seq, and every line from
+        # prior runs has seq less than or equal to it.
+        self._run_start_seq: int = 0
+        # True while an export operation (load/export/cleanup) is
+        # running. The SSE endpoint ends the stream 1 second after
+        # this flips back to False to drain any trailing log lines.
+        self._export_active: bool = False
+
         atexit.register(self._cleanup)
         logger.info("ExportOrchestrator initialized (subprocess mode)")
 
+    # ------------------------------------------------------------------
+    # Live log capture helpers
+    # ------------------------------------------------------------------
+
+    def _append_log(self, entry: Dict[str, Any]) -> None:
+        """Append a log line from the worker subprocess to the buffer.
+
+        Entries look like {"type": "log", "stream": "stdout"|"stderr",
+        "line": "...", "ts": ...}. Each is stamped with a monotonic
+        seq number before it lands in the buffer so SSE clients can
+        cursor through new lines.
+        """
+        line = entry.get("line")
+        if not line:
+            return
+        with self._log_lock:
+            self._log_seq += 1
+            self._log_buffer.append(
+                {
+                    "seq": self._log_seq,
+                    "stream": entry.get("stream", "stdout"),
+                    "line": line,
+                    "ts": entry.get("ts", time.time()),
+                }
+            )
+
+    def clear_logs(self) -> None:
+        """Drop any buffered log lines from a previous operation.
+
+        Called at the start of each export op so the UI shows only the
+        output of the current run. The seq counter is NOT reset, so an
+        SSE client that captured the cursor before clear_logs() will
+        still see new lines (with strictly greater seq numbers).
+
+        Also snapshots the current seq into ``_run_start_seq`` so the
+        SSE endpoint can anchor its default cursor at the start of
+        this run. Anything appended after this call has seq strictly
+        greater than the snapshot and is reachable via
+        ``get_logs_since(get_run_start_seq())``.
+        """
+        with self._log_lock:
+            self._log_buffer.clear()
+            self._run_start_seq = self._log_seq
+
+    def get_logs_since(self, cursor: int) -> Tuple[List[Dict[str, Any]], int]:
+        """Return log entries with seq > cursor, plus the new cursor."""
+        with self._log_lock:
+            new_entries = [entry for entry in self._log_buffer if entry["seq"] > cursor]
+            if new_entries:
+                return new_entries, new_entries[-1]["seq"]
+            return [], cursor
+
+    def get_current_log_seq(self) -> int:
+        """Return the current seq counter without reading any entries."""
+        with self._log_lock:
+            return self._log_seq
+
+    def get_run_start_seq(self) -> int:
+        """Return the seq value captured at the start of the current run.
+
+        The SSE endpoint uses this as the default cursor so a client
+        that connects AFTER the worker has already started emitting
+        output still sees every line from the current run.
+        """
+        with self._log_lock:
+            return self._run_start_seq
+
+    def is_export_active(self) -> bool:
+        """True while an export / load / cleanup command is running."""
+        return self._export_active
+
     # ------------------------------------------------------------------
     # Subprocess lifecycle
     # ------------------------------------------------------------------
@@ -179,8 +282,26 @@ class ExportOrchestrator:
                 error_msg = resp.get("error", "Unknown error")
                 raise RuntimeError(f"Subprocess error: {error_msg}")
 
+            if rtype == "log":
+                # Forwarded stdout/stderr line from the worker process.
+                self._append_log(resp)
+                continue
+
             if rtype == "status":
-                logger.info("Export subprocess status: %s", resp.get("message", ""))
+                message = resp.get("message", "")
+                logger.info("Export subprocess status: %s", message)
+                # Surface status messages in the live log panel too so
+                # users see high level progress (e.g. "Importing
+                # Unsloth...", "Loading checkpoint: ...") alongside
+                # subprocess output.
+                if message:
+                    self._append_log(
+                        {
+                            "stream": "status",
+                            "line": message,
+                            "ts": resp.get("ts", time.time()),
+                        }
+                    )
                 continue
 
             # Other response types during wait — skip
@@ -231,37 +352,47 @@ class ExportOrchestrator:
             "hf_token": hf_token,
         }
 
-        # Always kill existing subprocess and spawn fresh.
-        if self._ensure_subprocess_alive():
-            self._shutdown_subprocess()
-        elif self._proc is not None:
-            self._shutdown_subprocess(timeout = 2)
+        with self._lock:
+            # Start a fresh log buffer for this operation so the UI
+            # sees only the current run's output.
+            self.clear_logs()
+            self._export_active = True
+            try:
+                # Always kill existing subprocess and spawn fresh.
+                if self._ensure_subprocess_alive():
+                    self._shutdown_subprocess()
+                elif self._proc is not None:
+                    self._shutdown_subprocess(timeout = 2)
 
-        logger.info("Spawning fresh export subprocess for '%s'", checkpoint_path)
-        self._spawn_subprocess(sub_config)
+                logger.info(
+                    "Spawning fresh export subprocess for '%s'", checkpoint_path
+                )
+                self._spawn_subprocess(sub_config)
 
-        try:
-            resp = self._wait_response("loaded")
-        except RuntimeError as exc:
-            self._shutdown_subprocess(timeout = 5)
-            self.current_checkpoint = None
-            self.is_vision = False
-            self.is_peft = False
-            return False, str(exc)
+                try:
+                    resp = self._wait_response("loaded")
+                except RuntimeError as exc:
+                    self._shutdown_subprocess(timeout = 5)
+                    self.current_checkpoint = None
+                    self.is_vision = False
+                    self.is_peft = False
+                    return False, str(exc)
 
-        if resp.get("success"):
-            self.current_checkpoint = resp.get("checkpoint")
-            self.is_vision = resp.get("is_vision", False)
-            self.is_peft = resp.get("is_peft", False)
-            logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path)
-            return True, resp.get("message", "Loaded successfully")
-        else:
-            error = resp.get("message", "Failed to load checkpoint")
-            logger.error("Failed to load checkpoint: %s", error)
-            self.current_checkpoint = None
-            self.is_vision = False
-            self.is_peft = False
-            return False, error
+                if resp.get("success"):
+                    self.current_checkpoint = resp.get("checkpoint")
+                    self.is_vision = resp.get("is_vision", False)
+                    self.is_peft = resp.get("is_peft", False)
+                    logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path)
+                    return True, resp.get("message", "Loaded successfully")
+                else:
+                    error = resp.get("message", "Failed to load checkpoint")
+                    logger.error("Failed to load checkpoint: %s", error)
+                    self.current_checkpoint = None
+                    self.is_vision = False
+                    self.is_peft = False
+                    return False, error
+            finally:
+                self._export_active = False
 
     def export_merged_model(
         self,
@@ -271,7 +402,7 @@ class ExportOrchestrator:
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
         private: bool = False,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """Export merged PEFT model."""
         return self._run_export(
             "merged",
@@ -293,7 +424,7 @@ class ExportOrchestrator:
         hf_token: Optional[str] = None,
         private: bool = False,
         base_model_id: Optional[str] = None,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """Export base model (non-PEFT)."""
         return self._run_export(
             "base",
@@ -314,7 +445,7 @@ class ExportOrchestrator:
         push_to_hub: bool = False,
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """Export model in GGUF format."""
        return self._run_export(
             "gguf",
@@ -334,7 +465,7 @@ class ExportOrchestrator:
         repo_id: Optional[str] = None,
         hf_token: Optional[str] = None,
         private: bool = False,
-    ) -> Tuple[bool, str]:
+    ) -> Tuple[bool, str, Optional[str]]:
         """Export LoRA adapter only."""
         return self._run_export(
             "lora",
@@ -347,46 +478,74 @@ class ExportOrchestrator:
             },
         )
 
-    def _run_export(self, export_type: str, params: dict) -> Tuple[bool, str]:
-        """Send an export command to the subprocess and wait for result."""
-        if not self._ensure_subprocess_alive():
-            return False, "No export subprocess running. Load a checkpoint first."
+    def _run_export(
+        self, export_type: str, params: dict
+    ) -> Tuple[bool, str, Optional[str]]:
+        """Send an export command to the subprocess and wait for result.
 
-        cmd = {"type": "export", "export_type": export_type, **params}
+        Returns ``(success, message, output_path)``. ``output_path`` is the
+        resolved on-disk directory the worker actually wrote to (None when
+        the export only pushed to Hub or failed before any file was
+        written). Surfaced via the export route's ``details.output_path``
+        so the dialog's success screen can show the user where the model
+        landed.
+        """
+        with self._lock:
+            if not self._ensure_subprocess_alive():
+                return (
+                    False,
+                    "No export subprocess running. Load a checkpoint first.",
+                    None,
+                )
 
-        try:
-            self._send_cmd(cmd)
-            resp = self._wait_response(
-                f"export_{export_type}_done",
-                timeout = 3600,  # GGUF for 30B+ models can take 30+ min
-            )
-            return resp.get("success", False), resp.get("message", "")
-        except RuntimeError as exc:
-            return False, str(exc)
+            self.clear_logs()
+            self._export_active = True
+            try:
+                cmd = {"type": "export", "export_type": export_type, **params}
+                try:
+                    self._send_cmd(cmd)
+                    resp = self._wait_response(
+                        f"export_{export_type}_done",
+                        timeout = 3600,  # GGUF for 30B+ models can take 30+ min
+                    )
+                    return (
+                        resp.get("success", False),
+                        resp.get("message", ""),
+                        resp.get("output_path"),
+                    )
                except RuntimeError as exc:
+                    return False, str(exc), None
+            finally:
+                self._export_active = False
 
     def cleanup_memory(self) -> bool:
         """Cleanup export-related models from memory."""
-        if not self._ensure_subprocess_alive():
-            # No subprocess — just clear local state
-            self.current_checkpoint = None
-            self.is_vision = False
-            self.is_peft = False
-            return True
+        with self._lock:
+            if not self._ensure_subprocess_alive():
+                # No subprocess — just clear local state
+                self.current_checkpoint = None
+                self.is_vision = False
+                self.is_peft = False
+                return True
 
-        try:
-            self._send_cmd({"type": "cleanup"})
-            resp = self._wait_response("cleanup_done", timeout = 30)
-            success = resp.get("success", False)
-        except RuntimeError:
-            success = False
+            self._export_active = True
+            try:
+                try:
+                    self._send_cmd({"type": "cleanup"})
+                    resp = self._wait_response("cleanup_done", timeout = 30)
+                    success = resp.get("success", False)
+                except RuntimeError:
+                    success = False
 
-        # Shut down subprocess after cleanup — no model loaded
-        self._shutdown_subprocess()
+                # Shut down subprocess after cleanup — no model loaded
+                self._shutdown_subprocess()
 
-        self.current_checkpoint = None
-        self.is_vision = False
-        self.is_peft = False
-        return success
+                self.current_checkpoint = None
+                self.is_vision = False
+                self.is_peft = False
+                return success
+            finally:
+                self._export_active = False
 
     def scan_checkpoints(
         self, outputs_dir: str = str(outputs_root())
|
|||
|
|
@@ -17,10 +17,12 @@ Pattern follows core/inference/worker.py and core/training/worker.py.

 from __future__ import annotations

+import errno
 import structlog
 from loggers import get_logger
 import os
 import sys
+import threading
 import time
 import traceback
 from pathlib import Path
@@ -29,6 +31,154 @@ from typing import Any

 logger = get_logger(__name__)


+# Gate that controls whether captured stdout/stderr lines are forwarded
+# to the parent's resp_queue (and from there to the export-dialog SSE
+# stream). Closed by default so the noisy bootstrap phase -- transformers
+# venv activation, Unsloth/torch imports, base-model resolution, "Top
+# GGUF/hub models" lists, vision detection, weight loading bars -- is
+# suppressed in the UI. _handle_export() opens the gate at the start of
+# the actual export work and leaves it open; the orchestrator always
+# spawns a fresh subprocess for the next checkpoint load (see
+# orchestrator._spawn_subprocess), which resets this state.
+#
+# Lines dropped while the gate is closed are still echoed to the saved
+# original stdout/stderr fds, so the server console / log file keeps the
+# full output for debugging.
+_log_forward_gate = threading.Event()
+
+
+def _setup_log_capture(resp_queue: Any) -> None:
+    """Redirect fds 1 and 2 through pipes so every line printed by this
+    worker process and any child process it spawns is forwarded to the
+    parent process via resp_queue as {"type": "log", ...} messages.
+
+    Must be called BEFORE LogConfig.setup_logging and BEFORE any ML
+    imports; otherwise library handlers may capture the original stderr
+    reference and bypass the pipe.
+
+    Lines are also echoed back to the original stdout/stderr so the
+    server console keeps receiving the full subprocess output, even
+    while ``_log_forward_gate`` is closed.
+    """
+    try:
+        saved_out_fd = os.dup(1)
+        saved_err_fd = os.dup(2)
+    except OSError:
+        # dup failed (exotic platforms) - give up quietly; the export
+        # still works, just without live log streaming.
+        return
+
+    try:
+        r_out, w_out = os.pipe()
+        r_err, w_err = os.pipe()
+    except OSError:
+        os.close(saved_out_fd)
+        os.close(saved_err_fd)
+        return
+
+    try:
+        os.dup2(w_out, 1)
+        os.dup2(w_err, 2)
+    except OSError:
+        for fd in (saved_out_fd, saved_err_fd, r_out, w_out, r_err, w_err):
+            try:
+                os.close(fd)
+            except OSError:
+                pass
+        return
+
+    # Close the write ends we just dup2'd (fds 1 and 2 are the real
+    # write ends now).
+    os.close(w_out)
+    os.close(w_err)
+
+    # Replace Python's sys.stdout/sys.stderr with line-buffered writers
+    # bound to the (now-redirected) fds 1 and 2.
+    try:
+        sys.stdout = os.fdopen(1, "w", buffering = 1, encoding = "utf-8", errors = "replace")
+        sys.stderr = os.fdopen(2, "w", buffering = 1, encoding = "utf-8", errors = "replace")
+    except Exception:
+        pass
+
+    def _reader(read_fd: int, stream_name: str, echo_fd: int) -> None:
+        buf = bytearray()
+        while True:
+            try:
+                chunk = os.read(read_fd, 4096)
+            except OSError as exc:
+                if exc.errno == errno.EBADF:
+                    break
+                continue
+            if not chunk:
+                break
+            # Echo to the original fd so the server console still sees
+            # the full output.
+            try:
+                os.write(echo_fd, chunk)
+            except OSError:
+                pass
+            buf.extend(chunk)
+            # Split on \n OR \r so tqdm-style progress bars update.
+            while True:
+                nl = -1
+                for i, b in enumerate(buf):
+                    if b == 0x0A or b == 0x0D:
+                        nl = i
+                        break
+                if nl < 0:
+                    break
+                line = bytes(buf[:nl]).decode("utf-8", errors = "replace")
+                del buf[: nl + 1]
+                if not line:
+                    continue
+                if not _log_forward_gate.is_set():
+                    # Gate closed (bootstrap phase) -- already echoed to
+                    # the saved console fd above; drop the line so the
+                    # export dialog doesn't see import / vendoring noise.
+                    continue
+                try:
+                    resp_queue.put_nowait(
+                        {
+                            "type": "log",
+                            "stream": stream_name,
+                            "line": line,
+                            "ts": time.time(),
+                        }
+                    )
+                except Exception:
+                    # Queue put failed (full, closed, etc.) - drop the
+                    # line rather than crash the reader thread.
+                    pass
+        if buf and _log_forward_gate.is_set():
+            try:
+                resp_queue.put_nowait(
+                    {
+                        "type": "log",
+                        "stream": stream_name,
+                        "line": bytes(buf).decode("utf-8", errors = "replace"),
+                        "ts": time.time(),
+                    }
+                )
+            except Exception:
+                pass
+
+    t_out = threading.Thread(
+        target = _reader,
+        args = (r_out, "stdout", saved_out_fd),
+        daemon = True,
+        name = "export-log-stdout",
+    )
+    t_err = threading.Thread(
+        target = _reader,
+        args = (r_err, "stderr", saved_err_fd),
+        daemon = True,
+        name = "export-log-stderr",
+    )
+    t_out.start()
+    t_err.start()
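The forwarding decision the reader threads make can be reduced to a minimal, self-contained sketch (hypothetical names; it mimics only the gate logic, not the fd/pipe plumbing above):

```python
import threading

# Closed by default: lines produced before the gate opens are dropped
# from the forwarded stream (in the real worker they still reach the
# saved console fd).
gate = threading.Event()
forwarded = []

def forward(line: str) -> None:
    # Mirror of the reader thread's decision: only forward a line to
    # the parent queue once the export phase has opened the gate.
    if gate.is_set():
        forwarded.append({"type": "log", "stream": "stdout", "line": line})

forward("Importing torch...")   # bootstrap noise: dropped
gate.set()                      # the export handler opens the gate
forward("Merging weights...")   # export progress: forwarded
```

A `threading.Event` is a natural fit here: `set()` is idempotent and `is_set()` is cheap enough to call per line from the reader threads.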


def _activate_transformers_version(model_name: str) -> None:
    """Activate the correct transformers version BEFORE any ML imports."""
    # Ensure backend is on path for utils imports

@@ -117,9 +267,17 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
     export_type = cmd["export_type"]  # "merged", "base", "gguf", "lora"
     response_type = f"export_{export_type}_done"

+    # Open the log forwarding gate so the user sees the actual export
+    # progress (Unsloth merge bars, file copies, GGUF conversion, etc.)
+    # in the live log panel. The gate stays open for the rest of this
+    # subprocess's life; the orchestrator spawns a fresh subprocess for
+    # the next checkpoint load, which resets the gate to closed.
+    _log_forward_gate.set()
+
+    output_path: Any = None
     try:
         if export_type == "merged":
-            success, message = backend.export_merged_model(
+            success, message, output_path = backend.export_merged_model(
                 save_directory = cmd.get("save_directory", ""),
                 format_type = cmd.get("format_type", "16-bit (FP16)"),
                 push_to_hub = cmd.get("push_to_hub", False),
@@ -128,7 +286,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
                 private = cmd.get("private", False),
             )
         elif export_type == "base":
-            success, message = backend.export_base_model(
+            success, message, output_path = backend.export_base_model(
                 save_directory = cmd.get("save_directory", ""),
                 push_to_hub = cmd.get("push_to_hub", False),
                 repo_id = cmd.get("repo_id"),

@@ -137,7 +295,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
                 base_model_id = cmd.get("base_model_id"),
             )
         elif export_type == "gguf":
-            success, message = backend.export_gguf(
+            success, message, output_path = backend.export_gguf(
                 save_directory = cmd.get("save_directory", ""),
                 quantization_method = cmd.get("quantization_method", "Q4_K_M"),
                 push_to_hub = cmd.get("push_to_hub", False),

@@ -145,7 +303,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
                 hf_token = cmd.get("hf_token"),
             )
         elif export_type == "lora":
-            success, message = backend.export_lora_adapter(
+            success, message, output_path = backend.export_lora_adapter(
                 save_directory = cmd.get("save_directory", ""),
                 push_to_hub = cmd.get("push_to_hub", False),
                 repo_id = cmd.get("repo_id"),

@@ -161,6 +319,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
                 "type": response_type,
                 "success": success,
                 "message": message,
+                "output_path": output_path,
                 "ts": time.time(),
             },
         )

@@ -172,6 +331,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
                 "type": response_type,
                 "success": False,
                 "message": str(exc),
+                "output_path": None,
                 "stack": traceback.format_exc(limit = 20),
                 "ts": time.time(),
             },

@@ -217,10 +377,26 @@ def run_export_process(
     """
     import queue as _queue

+    # Install fd-level stdout/stderr capture FIRST so every subsequent
+    # print and every child process inherits the redirected fds. This
+    # is what powers the live export log stream in the UI.
+    _setup_log_capture(resp_queue)
+
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     os.environ["PYTHONWARNINGS"] = (
         "ignore"  # Suppress warnings at C-level before imports
     )
+    # Force unbuffered output from any child Python process (e.g. the
+    # GGUF converter) so its prints surface in the log stream as they
+    # happen rather than at the end.
+    os.environ["PYTHONUNBUFFERED"] = "1"
+    # tqdm defaults to a 10-second mininterval when stdout is not a tty
+    # (which it isn't here -- we redirected fds 1/2 to a pipe). That makes
+    # multi-step progress bars look frozen in the export log panel. Force
+    # frequent flushes so the user sees movement during merge / GGUF
+    # conversion. Has no effect on single-step bars (e.g. "Copying 1
+    # files"), which only emit start/end events regardless.
+    os.environ.setdefault("TQDM_MININTERVAL", "0.5")

     import warnings
     from loggers.config import LogConfig
@@ -5,9 +5,15 @@
 Export API routes: checkpoint discovery and model export operations.
 """

+import asyncio
 import json
+import sys
+import time
 from pathlib import Path
-from fastapi import APIRouter, Depends, HTTPException, Query
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

+from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi.responses import StreamingResponse
 import structlog
 from loggers import get_logger
@@ -97,7 +103,11 @@ async def load_checkpoint(
         logger.warning("Could not stop training: %s", e)

     backend = get_export_backend()
-    success, message = backend.load_checkpoint(
+    # load_checkpoint spawns and waits on a subprocess and can take
+    # minutes. Run it in a worker thread so the event loop stays
+    # free to serve the live log SSE stream concurrently.
+    success, message = await asyncio.to_thread(
+        backend.load_checkpoint,
         checkpoint_path = request.checkpoint_path,
         max_seq_length = request.max_seq_length,
         load_in_4bit = request.load_in_4bit,

@@ -129,7 +139,7 @@ async def cleanup_export_memory(
     """
     try:
         backend = get_export_backend()
-        success = backend.cleanup_memory()
+        success = await asyncio.to_thread(backend.cleanup_memory)

         if not success:
             raise HTTPException(
@@ -173,6 +183,17 @@ async def get_export_status(
     )


+def _export_details(output_path: Optional[str]) -> Optional[Dict[str, Any]]:
+    """Wrap the resolved on-disk export path into the details dict the
+    frontend reads to populate the Export Complete screen. Returns None
+    when the export had no local component (Hub-only push) so the
+    Pydantic field stays absent rather than ``{"output_path": null}``.
+    """
+    if not output_path:
+        return None
+    return {"output_path": output_path}
+
+
 @router.post("/export/merged", response_model = ExportOperationResponse)
 async def export_merged_model(
     request: ExportMergedModelRequest,

@@ -185,7 +206,8 @@ async def export_merged_model(
     """
     try:
         backend = get_export_backend()
-        success, message = backend.export_merged_model(
+        success, message, output_path = await asyncio.to_thread(
+            backend.export_merged_model,
             save_directory = request.save_directory,
             format_type = request.format_type,
             push_to_hub = request.push_to_hub,

@@ -197,7 +219,11 @@ async def export_merged_model(
         if not success:
             raise HTTPException(status_code = 400, detail = message)

-        return ExportOperationResponse(success = True, message = message)
+        return ExportOperationResponse(
+            success = True,
+            message = message,
+            details = _export_details(output_path),
+        )
     except HTTPException:
         raise
     except Exception as e:

@@ -220,7 +246,8 @@ async def export_base_model(
     """
     try:
         backend = get_export_backend()
-        success, message = backend.export_base_model(
+        success, message, output_path = await asyncio.to_thread(
+            backend.export_base_model,
             save_directory = request.save_directory,
             push_to_hub = request.push_to_hub,
             repo_id = request.repo_id,

@@ -232,7 +259,11 @@ async def export_base_model(
         if not success:
             raise HTTPException(status_code = 400, detail = message)

-        return ExportOperationResponse(success = True, message = message)
+        return ExportOperationResponse(
+            success = True,
+            message = message,
+            details = _export_details(output_path),
+        )
     except HTTPException:
         raise
     except Exception as e:

@@ -255,7 +286,8 @@ async def export_gguf(
     """
     try:
         backend = get_export_backend()
-        success, message = backend.export_gguf(
+        success, message, output_path = await asyncio.to_thread(
+            backend.export_gguf,
             save_directory = request.save_directory,
             quantization_method = request.quantization_method,
             push_to_hub = request.push_to_hub,

@@ -266,7 +298,11 @@ async def export_gguf(
         if not success:
             raise HTTPException(status_code = 400, detail = message)

-        return ExportOperationResponse(success = True, message = message)
+        return ExportOperationResponse(
+            success = True,
+            message = message,
+            details = _export_details(output_path),
+        )
     except HTTPException:
         raise
     except Exception as e:

@@ -289,7 +325,8 @@ async def export_lora_adapter(
     """
     try:
         backend = get_export_backend()
-        success, message = backend.export_lora_adapter(
+        success, message, output_path = await asyncio.to_thread(
+            backend.export_lora_adapter,
             save_directory = request.save_directory,
             push_to_hub = request.push_to_hub,
             repo_id = request.repo_id,

@@ -300,7 +337,11 @@ async def export_lora_adapter(
         if not success:
             raise HTTPException(status_code = 400, detail = message)

-        return ExportOperationResponse(success = True, message = message)
+        return ExportOperationResponse(
+            success = True,
+            message = message,
+            details = _export_details(output_path),
+        )
     except HTTPException:
         raise
     except Exception as e:
@@ -309,3 +350,155 @@ async def export_lora_adapter(
             status_code = 500,
             detail = f"Failed to export LoRA adapter: {str(e)}",
         )
+
+
+# ─────────────────────────────────────────────────────────────────────
+# Live export log stream (Server-Sent Events)
+# ─────────────────────────────────────────────────────────────────────
+#
+# The export worker subprocess redirects its stdout/stderr into a pipe
+# that a reader thread forwards to the orchestrator as log entries (see
+# core/export/worker.py::_setup_log_capture and
+# core/export/orchestrator.py::_append_log). This endpoint streams
+# those entries to the browser so the export dialog can show a live
+# terminal-style output panel while load_checkpoint / export_merged /
+# export_gguf / export_lora / export_base run.
+#
+# Shape follows the training progress SSE endpoint
+# (routes/training.py::stream_training_progress): each event carries
+# `id`, `event`, and `data` fields, the stream starts with a `retry:`
+# directive, and `Last-Event-ID` is honored on reconnect.
+
+
+def _format_sse(data: str, event: str, event_id: Optional[int] = None) -> str:
+    """Format a single SSE message with id/event/data fields."""
+    lines = []
+    if event_id is not None:
+        lines.append(f"id: {event_id}")
+    lines.append(f"event: {event}")
+    lines.append(f"data: {data}")
+    lines.append("")
+    lines.append("")
+    return "\n".join(lines)
+
+
+@router.get("/logs/stream")
+async def stream_export_logs(
+    request: Request,
+    since: Optional[int] = Query(
+        None,
+        description = "Return log entries with seq strictly greater than this cursor.",
+    ),
+    current_subject: str = Depends(get_current_subject),
+):
+    """
+    Stream live stdout/stderr output from the export worker subprocess
+    as Server-Sent Events.
+
+    Events:
+    - `log`      : a single log line (data: {"stream","line","ts"})
+    - `heartbeat`: periodic keepalive when no new lines are available
+    - `complete` : emitted once the export worker is idle and no new
+                   lines arrived for ~1 second. Clients should close.
+    - `error`    : unrecoverable server-side error
+
+    The `id:` field on each event is the log entry's monotonic seq
+    number so the browser can resume via `Last-Event-ID` on reconnect.
+    """
+    backend = get_export_backend()
+
+    # Determine the starting cursor. An explicit `since` wins, then the
+    # Last-Event-ID header on reconnect; otherwise start from the
+    # run-start snapshot captured by clear_logs() so the client sees
+    # every line emitted since the current run began -- even if the
+    # SSE connection opened after the POST that kicked off the export.
+    # Using get_current_log_seq() here would lose the early bootstrap
+    # lines that arrive in the gap between POST and SSE connect.
+    last_event_id = request.headers.get("last-event-id")
+    if since is None and last_event_id is not None:
+        try:
+            since = int(last_event_id)
+        except ValueError:
+            pass
+
+    if since is None:
+        cursor = backend.get_run_start_seq()
+    else:
+        cursor = max(0, int(since))
+
+    async def event_generator() -> AsyncGenerator[str, None]:
+        nonlocal cursor
+        # Tell the browser to reconnect after 3 seconds if the
+        # connection drops mid-export.
+        yield "retry: 3000\n\n"
+
+        last_yield = time.monotonic()
+        idle_since: Optional[float] = None
+        try:
+            while True:
+                if await request.is_disconnected():
+                    return
+
+                entries, new_cursor = backend.get_logs_since(cursor)
+                if entries:
+                    for entry in entries:
+                        payload = json.dumps(
+                            {
+                                "stream": entry.get("stream", "stdout"),
+                                "line": entry.get("line", ""),
+                                "ts": entry.get("ts"),
+                            }
+                        )
+                        yield _format_sse(
+                            payload,
+                            event = "log",
+                            event_id = int(entry.get("seq", 0)),
+                        )
+                    cursor = new_cursor
+                    last_yield = time.monotonic()
+                    idle_since = None
+                else:
+                    now = time.monotonic()
+                    if now - last_yield > 10.0:
+                        yield _format_sse("{}", event = "heartbeat")
+                        last_yield = now
+                    if not backend.is_export_active():
+                        # Give the reader thread a moment to drain any
+                        # trailing lines the worker process printed
+                        # just before signalling done.
+                        if idle_since is None:
+                            idle_since = now
+                        elif now - idle_since > 1.0:
+                            yield _format_sse(
+                                "{}",
+                                event = "complete",
+                                event_id = cursor,
+                            )
+                            return
+                    else:
+                        idle_since = None
+
+                await asyncio.sleep(0.1)
+        except asyncio.CancelledError:
+            # Client disconnected mid-yield. Don't re-raise; just end
+            # the generator cleanly so StreamingResponse finalizes.
+            return
+        except Exception as exc:
+            logger.error("Export log stream failed: %s", exc, exc_info = True)
+            try:
+                yield _format_sse(
+                    json.dumps({"error": str(exc)}),
+                    event = "error",
+                )
+            except Exception:
+                pass
+
+    return StreamingResponse(
+        event_generator(),
+        media_type = "text/event-stream",
+        headers = {
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+        },
+    )

studio/backend/tests/test_export_log_cursor.py (new file, 179 lines)

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

"""
Regression tests for the export log ring-buffer cursor semantics.

Context: the live export log SSE stream has a race where the frontend
opens the SSE connection AFTER the POST that starts the export. Any
lines the worker subprocess emits during the gap between POST and SSE
connect get buffered with seqs 1..k, and then the SSE default cursor
`get_current_log_seq()` returns k -- so lines 1..k are forever
unreachable to that client.

Fix: `clear_logs()` snapshots the pre-run seq into `_run_start_seq`
(exposed via `get_run_start_seq()`), and `routes/export.py` defaults
the SSE cursor to that snapshot instead of the current seq. Every line
appended during the current run has a seq strictly greater than the
snapshot, so the client sees the full run regardless of when it
connects.

These tests exercise the orchestrator-side contract only (no
subprocess, no FastAPI, no frontend). The routes-level integration
with get_run_start_seq() is a one-line edit covered by manual testing
and the frontend build.
"""

from __future__ import annotations

import sys
import types
from pathlib import Path

import pytest


# Backend root on sys.path so `from core.export.orchestrator import ...`
# and friends resolve without the studio app bootstrap.
_BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(_BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(_BACKEND_DIR))

# ExportOrchestrator imports structlog and a few heavy modules at the
# top of orchestrator.py. Stub the ones we don't need in these unit
# tests so the import succeeds on machines without the full studio
# venv.
_loggers_stub = types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)

# structlog is only used for a module-level import; a bare stub is
# enough because we never call into it in these tests.
sys.modules.setdefault("structlog", types.ModuleType("structlog"))

# utils.paths.outputs_root is only called inside scan_checkpoints, which
# we don't hit in these tests. Provide a stub module so the top-level
# import in orchestrator.py resolves.
_utils_pkg = types.ModuleType("utils")
_utils_pkg.__path__ = []  # mark as package
_utils_paths_stub = types.ModuleType("utils.paths")
_utils_paths_stub.outputs_root = lambda: Path("/tmp")
sys.modules.setdefault("utils", _utils_pkg)
sys.modules.setdefault("utils.paths", _utils_paths_stub)


@pytest.fixture
def orchestrator():
    """Fresh ExportOrchestrator with only the log-buffer state exercised."""
    from core.export.orchestrator import ExportOrchestrator

    return ExportOrchestrator()


def _append(orch, line: str, stream: str = "stdout") -> None:
    """Shortcut for simulating a worker log message."""
    orch._append_log({"type": "log", "stream": stream, "line": line, "ts": 0.0})


# ---------------------------------------------------------------------------
# clear_logs() semantics
# ---------------------------------------------------------------------------


def test_run_start_seq_is_zero_before_any_logs(orchestrator) -> None:
    """A brand-new orchestrator must report run_start_seq == 0 so a
    first SSE connection picks up every line from seq 1 onward."""
    assert orchestrator.get_run_start_seq() == 0


def test_clear_logs_snapshots_current_seq(orchestrator) -> None:
    """clear_logs() must capture _log_seq BEFORE clearing the buffer,
    so subsequent runs can anchor their SSE cursor at the snapshot."""
    _append(orchestrator, "old run line 1")
    _append(orchestrator, "old run line 2")
    _append(orchestrator, "old run line 3")
    assert orchestrator.get_current_log_seq() == 3

    orchestrator.clear_logs()

    assert orchestrator.get_run_start_seq() == 3
    assert orchestrator.get_current_log_seq() == 3  # seq counter preserved


# ---------------------------------------------------------------------------
# Race regression: SSE connects AFTER lines have been emitted
# ---------------------------------------------------------------------------


def test_sse_default_cursor_catches_all_current_run_lines(orchestrator) -> None:
    """Simulate the POST-then-SSE race: the worker starts emitting lines
    immediately after clear_logs(), and the SSE connects several lines
    later. Using get_run_start_seq() as the default cursor MUST return
    every line emitted since clear_logs() ran.

    Pre-fix, the SSE defaulted to get_current_log_seq() at connect
    time, which would return the last-seen seq and miss lines N+1..M.
    """
    # Previous run leaves some buffered lines.
    _append(orchestrator, "previous run line A")
    _append(orchestrator, "previous run line B")

    # New run starts: orchestrator clears the buffer and snapshots seq.
    orchestrator.clear_logs()
    run_start = orchestrator.get_run_start_seq()

    # Worker emits early lines BEFORE the SSE connects.
    _append(orchestrator, "Importing Unsloth...")
    _append(orchestrator, "Loading checkpoint: /foo/bar")
    _append(orchestrator, "Starting export...")

    # SSE connects now and asks "give me everything after the run
    # start cursor".
    entries, new_cursor = orchestrator.get_logs_since(run_start)

    # All three early lines must be present. Pre-fix this was [].
    lines = [e["line"] for e in entries]
    assert lines == [
        "Importing Unsloth...",
        "Loading checkpoint: /foo/bar",
        "Starting export...",
    ]
    assert new_cursor == entries[-1]["seq"]


def test_sse_default_cursor_excludes_previous_run(orchestrator) -> None:
    """After clear_logs(), lines from the PREVIOUS run must not leak
    into the new run's SSE stream. Pre-fix this worked correctly
    (clear_logs cleared the deque); the fix must preserve it.
    """
    _append(orchestrator, "previous run line 1")
    _append(orchestrator, "previous run line 2")
    _append(orchestrator, "previous run line 3")
    assert orchestrator.get_current_log_seq() == 3

    orchestrator.clear_logs()
    run_start = orchestrator.get_run_start_seq()

    _append(orchestrator, "new run line")

    entries, _ = orchestrator.get_logs_since(run_start)
    assert [e["line"] for e in entries] == ["new run line"]


def test_clear_logs_twice_advances_run_start(orchestrator) -> None:
    """Back-to-back clear_logs() calls (e.g. cleanup -> load ->
    export in the same dialog session) must each re-anchor run_start
    at the current seq, so successive runs each start with a fresh
    low-water mark."""
    _append(orchestrator, "run 1 line a")
    _append(orchestrator, "run 1 line b")

    orchestrator.clear_logs()
    assert orchestrator.get_run_start_seq() == 2

    _append(orchestrator, "run 2 line a")
    _append(orchestrator, "run 2 line b")
    _append(orchestrator, "run 2 line c")

    orchestrator.clear_logs()
    assert orchestrator.get_run_start_seq() == 5
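The contract these tests pin down can be sketched as a minimal ring buffer. This is a hypothetical model, not the orchestrator's actual implementation (which additionally takes a lock and tags each entry with `stream`/`ts`); names mirror the methods the tests call:

```python
from collections import deque

class LogBuffer:
    """Minimal model of the orchestrator's log ring-buffer contract."""

    def __init__(self, maxlen: int = 4000):
        self._buf = deque(maxlen = maxlen)  # ring buffer of log entries
        self._seq = 0                       # monotonically increasing, never reset
        self._run_start = 0                 # snapshot taken by clear_logs()

    def _append_log(self, msg: dict) -> None:
        self._seq += 1
        self._buf.append({**msg, "seq": self._seq})

    def clear_logs(self) -> None:
        # Snapshot BEFORE clearing: every line of the new run will have
        # a seq strictly greater than this low-water mark.
        self._run_start = self._seq
        self._buf.clear()

    def get_run_start_seq(self) -> int:
        return self._run_start

    def get_current_log_seq(self) -> int:
        return self._seq

    def get_logs_since(self, cursor: int):
        entries = [e for e in self._buf if e["seq"] > cursor]
        return entries, (entries[-1]["seq"] if entries else cursor)
```

Because `_seq` survives `clear_logs()`, a client holding a cursor from the previous run simply sees an empty result rather than replayed or renumbered lines.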

@@ -42,7 +42,13 @@ export interface CheckpointListResponse {
 export interface ExportOperationResponse {
   success: boolean;
   message: string;
-  details?: Record<string, unknown> | null;
+  /**
+   * Optional extras returned by the backend. The export endpoints set
+   * `details.output_path` to the resolved on-disk directory of the
+   * saved model when a local save was requested. Hub-only pushes leave
+   * `details` undefined.
+   */
+  details?: { output_path?: string | null } & Record<string, unknown>;
 }

 export async function fetchCheckpoints(): Promise<CheckpointListResponse> {
@@ -131,3 +137,172 @@ export async function cleanupExport(): Promise<ExportOperationResponse> {
   const response = await authFetch("/api/export/cleanup", { method: "POST" });
   return parseJson<ExportOperationResponse>(response);
 }
+
+// ─────────────────────────────────────────────────────────────────────
+// Live export log stream (Server-Sent Events)
+// ─────────────────────────────────────────────────────────────────────
+
+export type ExportLogStream = "stdout" | "stderr" | "status";
+
+export interface ExportLogEntry {
+  stream: ExportLogStream;
+  line: string;
+  ts: number | null;
+}
+
+export type ExportLogEventName = "log" | "heartbeat" | "complete" | "error";
+
+export interface ExportLogEvent {
+  event: ExportLogEventName;
+  id: number | null;
+  /** Present on `log` events. */
+  entry?: ExportLogEntry;
+  /** Present on `error` events. */
+  error?: string;
+}
+
+interface ParsedSseMessage {
+  event: string;
+  id: number | null;
+  data: string;
+}
+
+function parseSseMessage(raw: string): ParsedSseMessage | null {
+  const lines = raw.split(/\r?\n/);
+  let event = "message";
+  let id: number | null = null;
+  const dataLines: string[] = [];
+
+  for (const line of lines) {
+    if (!line) continue;
+    if (line.startsWith("event:")) {
+      event = line.slice(6).trim();
+      continue;
+    }
+    if (line.startsWith("id:")) {
+      const value = Number(line.slice(3).trim());
+      id = Number.isFinite(value) ? value : null;
+      continue;
+    }
+    if (line.startsWith("data:")) {
+      dataLines.push(line.slice(5).trimStart());
+      continue;
+    }
+    // Comment lines (":heartbeat" etc.) are ignored per the SSE spec.
+  }
+
+  if (dataLines.length === 0) return null;
+  return { event, id, data: dataLines.join("\n") };
+}
+
+function isAbortError(error: unknown): boolean {
+  return error instanceof DOMException && error.name === "AbortError";
+}
+
+export async function streamExportLogs(options: {
+  signal: AbortSignal;
+  since?: number | null;
+  onOpen?: () => void;
+  onEvent: (event: ExportLogEvent) => void;
+}): Promise<void> {
+  const headers = new Headers();
+  if (typeof options.since === "number") {
+    headers.set("Last-Event-ID", String(options.since));
+  }
+
+  const url =
+    typeof options.since === "number"
? `/api/export/logs/stream?since=${options.since}`
|
||||
: "/api/export/logs/stream";
|
||||
|
||||
const response = await authFetch(url, {
|
||||
method: "GET",
|
||||
headers,
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await readError(response));
|
||||
}
|
||||
if (!response.body) {
|
||||
throw new Error("Export log stream unavailable");
|
||||
}
|
||||
|
||||
options.onOpen?.();
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) return;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let separatorIndex = buffer.search(/\r?\n\r?\n/);
|
||||
while (separatorIndex >= 0) {
|
||||
const rawEvent = buffer.slice(0, separatorIndex);
|
||||
const separatorLength = buffer[separatorIndex] === "\r" ? 4 : 2;
|
||||
buffer = buffer.slice(separatorIndex + separatorLength);
|
||||
|
||||
if (rawEvent.startsWith("retry:") || rawEvent.startsWith(":")) {
|
||||
separatorIndex = buffer.search(/\r?\n\r?\n/);
|
||||
continue;
|
||||
}
|
||||
|
||||
const parsed = parseSseMessage(rawEvent);
|
||||
if (!parsed) {
|
||||
separatorIndex = buffer.search(/\r?\n\r?\n/);
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
if (parsed.event === "log") {
|
||||
const payload = JSON.parse(parsed.data) as {
|
||||
stream?: ExportLogStream;
|
||||
line?: string;
|
||||
ts?: number | null;
|
||||
};
|
||||
options.onEvent({
|
||||
event: "log",
|
||||
id: parsed.id,
|
||||
entry: {
|
||||
stream: payload.stream ?? "stdout",
|
||||
line: payload.line ?? "",
|
||||
ts: payload.ts ?? null,
|
||||
},
|
||||
});
|
||||
} else if (parsed.event === "heartbeat") {
|
||||
options.onEvent({ event: "heartbeat", id: parsed.id });
|
||||
} else if (parsed.event === "complete") {
|
||||
options.onEvent({ event: "complete", id: parsed.id });
|
||||
return;
|
||||
} else if (parsed.event === "error") {
|
||||
let errorMessage = "Export log stream error";
|
||||
try {
|
||||
const payload = JSON.parse(parsed.data) as { error?: string };
|
||||
if (payload.error) errorMessage = payload.error;
|
||||
} catch {
|
||||
// fall through with default message
|
||||
}
|
||||
options.onEvent({
|
||||
event: "error",
|
||||
id: parsed.id,
|
||||
error: errorMessage,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
if (isAbortError(err)) return;
|
||||
// Ignore malformed events, keep reading.
|
||||
}
|
||||
|
||||
separatorIndex = buffer.search(/\r?\n\r?\n/);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (isAbortError(err)) return;
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
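For reference, the `event:` / `id:` / `data:` field rules that `parseSseMessage` implements can be restated in a small Python sketch (a hedged re-statement of the same parsing rules, not code from this PR): unknown field lines and `:` comment lines are skipped, multiple `data:` lines are joined with newlines, and a block with no data yields nothing.

```python
def parse_sse_message(raw: str):
    """Parse one SSE message block into (event, id, data).

    Mirrors the TypeScript parseSseMessage: defaults to event
    "message", a non-numeric id becomes None, and a block with no
    data: lines returns None so the caller can skip it.
    """
    event, msg_id, data_lines = "message", None, []
    for line in raw.splitlines():
        if not line:
            continue
        if line.startswith("event:"):
            event = line[6:].strip()
        elif line.startswith("id:"):
            try:
                msg_id = int(line[3:].strip())
            except ValueError:
                msg_id = None
        elif line.startswith("data:"):
            # Per the SSE spec, a single leading space after the
            # colon is stripped from the field value.
            data_lines.append(line[5:].lstrip())
        # comment lines like ":heartbeat" fall through, ignored
    if not data_lines:
        return None
    return event, msg_id, "\n".join(data_lines)
```

For example, `parse_sse_message("event: log\nid: 7\ndata: hello")` returns `("log", 7, "hello")`, while a bare `":heartbeat"` comment block returns `None`.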
@@ -21,9 +21,199 @@ import { Switch } from "@/components/ui/switch";
import { AlertCircleIcon, ArrowRight01Icon, CheckmarkCircle02Icon, Key01Icon } from "@hugeicons/core-free-icons";
import { HugeiconsIcon } from "@hugeicons/react";
import { AnimatePresence, motion } from "motion/react";
import { useEffect, useRef, useState } from "react";
import { streamExportLogs, type ExportLogEntry } from "../api/export-api";
import { collapseAnim } from "../anim";
import { EXPORT_METHODS, type ExportMethod } from "../constants";

// Max log lines kept in the dialog's local state. Matches the backend
// ring buffer's maxlen so the UI shows the full scrollback captured
// server side.
const MAX_LOG_LINES = 4000;

interface UseExportLogsResult {
  lines: ExportLogEntry[];
  connected: boolean;
  error: string | null;
}

/**
 * Subscribe to the live export log SSE stream while `exporting` is
 * true, and accumulate lines in local state. Lines from a previous
 * action are cleared:
 *
 * - when a new export starts (`exporting` flips to true), and
 * - when the user switches export method, dialog opens fresh, or
 *   the dialog closes — so re-opening into a different action's
 *   screen doesn't show the prior screen's saved output.
 */
function useExportLogs(
  exporting: boolean,
  exportMethod: ExportMethod | null,
  open: boolean,
): UseExportLogsResult {
  const [lines, setLines] = useState<ExportLogEntry[]>([]);
  const [connected, setConnected] = useState(false);
  const [error, setError] = useState<string | null>(null);

  // Reset log state whenever the user moves to a different screen --
  // either by switching export method or by reopening the dialog -- so
  // each (open × method) tuple shows only its own run history. The
  // streaming effect below additionally clears on new export start.
  useEffect(() => {
    setLines([]);
    setError(null);
    setConnected(false);
  }, [exportMethod, open]);

  useEffect(() => {
    if (!exporting) return;

    setLines([]);
    setError(null);

    const abortCtrl = new AbortController();
    let cancelled = false;
    // Track the highest seq we've observed on a `log` event so we can
    // resume the stream via `since=` / `Last-Event-ID` after a drop.
    // The backend's SSE `id:` field carries this as ExportLogEvent.id.
    let lastSeq: number | null = null;
    // Exponential backoff with jitter, capped. Reset on every
    // successful connection so flaky networks don't accumulate delay.
    let backoffMs = 500;
    const MAX_BACKOFF_MS = 5000;
    // Flipped by a terminal event (explicit `complete` from the
    // backend or a non-transient error we choose not to retry). Stops
    // the outer reconnect loop even if `exporting` is still true.
    let terminated = false;

    const run = async () => {
      while (!cancelled && !terminated) {
        try {
          await streamExportLogs({
            signal: abortCtrl.signal,
            since: lastSeq,
            onOpen: () => {
              if (cancelled) return;
              setConnected(true);
              // Reset backoff on every successful connect so later
              // drops don't inherit accumulated delay from earlier ones.
              backoffMs = 500;
            },
            onEvent: (event) => {
              if (cancelled) return;
              if (event.event === "log" && event.entry) {
                if (typeof event.id === "number") {
                  lastSeq = event.id;
                }
                const entry = event.entry;
                setLines((prev) => {
                  const next = prev.length >= MAX_LOG_LINES
                    ? prev.slice(prev.length - MAX_LOG_LINES + 1)
                    : prev.slice();
                  next.push(entry);
                  return next;
                });
              } else if (event.event === "complete") {
                // Backend signalled the run is fully drained -- stop
                // trying to reconnect even though `exporting` may not
                // have flipped false yet on this tick.
                terminated = true;
              } else if (event.event === "error" && event.error) {
                setError(event.error);
              }
            },
          });
        } catch (err: unknown) {
          if (cancelled) return;
          if (err instanceof DOMException && err.name === "AbortError") return;
          setError(err instanceof Error ? err.message : String(err));
          // Fall through to the backoff path below; a fetch-level
          // failure is retryable the same way a clean EOF is.
        }

        setConnected(false);
        if (cancelled || terminated) return;

        // Exponential backoff with jitter before reconnecting. The
        // backend's ring buffer plus Last-Event-ID resume means we
        // don't lose lines across the retry as long as the reconnect
        // happens within the buffer's lifetime (~4000 lines).
        const delay = backoffMs + Math.floor(Math.random() * 250);
        backoffMs = Math.min(backoffMs * 2, MAX_BACKOFF_MS);
        try {
          await new Promise<void>((resolve, reject) => {
            if (abortCtrl.signal.aborted) {
              reject(new DOMException("Aborted", "AbortError"));
              return;
            }
            const timeoutId = window.setTimeout(resolve, delay);
            abortCtrl.signal.addEventListener(
              "abort",
              () => {
                window.clearTimeout(timeoutId);
                reject(new DOMException("Aborted", "AbortError"));
              },
              { once: true },
            );
          });
        } catch {
          return;
        }
      }
    };

    // run()'s own try/catch handles every failure path we care about;
    // swallow anything that somehow escapes so React's dev overlay
    // doesn't flag an unhandled rejection on dialog close.
    void run().catch(() => {});

    return () => {
      cancelled = true;
      abortCtrl.abort();
      setConnected(false);
    };
  }, [exporting]);

  return { lines, connected, error };
}

/**
 * Tick every second while `exporting` is true and report elapsed
 * seconds. Powers the "Working… 27s" badge in the log header so the
 * panel doesn't look frozen during long single-step phases (cache
 * file copy, GGUF conversion) when no new lines are arriving.
 */
function useElapsedSeconds(exporting: boolean): number {
  const [elapsed, setElapsed] = useState(0);
  useEffect(() => {
    if (!exporting) {
      setElapsed(0);
      return;
    }
    const startedAt = Date.now();
    setElapsed(0);
    const id = window.setInterval(() => {
      setElapsed(Math.floor((Date.now() - startedAt) / 1000));
    }, 1000);
    return () => window.clearInterval(id);
  }, [exporting]);
  return elapsed;
}

function formatElapsed(seconds: number): string {
  if (seconds < 60) return `${seconds}s`;
  const m = Math.floor(seconds / 60);
  const s = seconds % 60;
  return `${m}m ${s.toString().padStart(2, "0")}s`;
}

function formatLogLine(entry: ExportLogEntry): string {
  // Strip trailing carriage returns that tqdm-style progress leaves
  // in the stream so the scrollback doesn't render funky boxes.
  return entry.line.replace(/\r+$/g, "");
}

type Destination = "local" | "hub";

interface ExportDialogProps {

@@ -49,6 +239,12 @@ interface ExportDialogProps {
  exporting: boolean;
  exportError: string | null;
  exportSuccess: boolean;
  /**
   * Resolved on-disk realpath of the most recent successful export.
   * Surfaced on the Export Complete screen so users can find their
   * model. Null when the export only pushed to the Hub.
   */
  exportOutputPath: string | null;
}

export function ExportDialog({

@@ -74,7 +270,38 @@ export function ExportDialog({
  exporting,
  exportError,
  exportSuccess,
  exportOutputPath,
}: ExportDialogProps) {
  // Live log capture is only meaningful for export methods that run
  // a slow subprocess operation with interesting stdout: merged and
  // gguf. LoRA adapter export is a fast disk write and would just
  // show a blank panel, so we hide it there.
  const showLogPanel =
    exportMethod === "merged" || exportMethod === "gguf";

  const { lines: logLines, connected: logConnected, error: logError } =
    useExportLogs(exporting && showLogPanel, exportMethod, open);
  const elapsedSeconds = useElapsedSeconds(exporting && showLogPanel);

  const logScrollRef = useRef<HTMLDivElement | null>(null);
  // Auto-scroll to bottom whenever a new line arrives, unless the
  // user has scrolled up to read earlier output.
  const [followTail, setFollowTail] = useState(true);

  useEffect(() => {
    if (!followTail) return;
    const el = logScrollRef.current;
    if (el) el.scrollTop = el.scrollHeight;
  }, [logLines, followTail]);

  const handleLogScroll = () => {
    const el = logScrollRef.current;
    if (!el) return;
    const nearBottom =
      el.scrollHeight - el.scrollTop - el.clientHeight < 24;
    setFollowTail(nearBottom);
  };

  return (
    <Dialog
      open={open}

@@ -83,20 +310,36 @@ export function ExportDialog({
        onOpenChange(v);
      }}
    >
      <DialogContent className="sm:max-w-lg" onInteractOutside={(e) => { if (exporting) e.preventDefault(); }}>
      <DialogContent
        className={showLogPanel ? "sm:max-w-2xl" : "sm:max-w-lg"}
        onInteractOutside={(e) => { if (exporting) e.preventDefault(); }}
      >
        {exportSuccess ? (
          <>
            <div className="flex flex-col items-center gap-3 py-6">
              <div className="flex size-12 items-center justify-center rounded-full bg-emerald-500/10">
                <HugeiconsIcon icon={CheckmarkCircle02Icon} className="size-6 text-emerald-500" />
              </div>
              <div className="text-center">
              <div className="flex flex-col items-center gap-2 text-center">
                <h3 className="text-lg font-semibold">Export Complete</h3>
                <p className="mt-1 text-sm text-muted-foreground">
                <p className="text-sm text-muted-foreground">
                  {destination === "hub"
                    ? "Model successfully pushed to Hugging Face Hub."
                    : "Model saved locally."}
                </p>
                {exportOutputPath ? (
                  <div className="mt-1 flex w-full max-w-md flex-col items-stretch gap-1 rounded-lg border border-border/40 bg-muted/40 px-3 py-2 text-left">
                    <span className="text-[11px] font-medium uppercase tracking-wide text-muted-foreground">
                      Saved to
                    </span>
                    <code
                      className="select-all break-all font-mono text-[12px] text-foreground"
                      title={exportOutputPath}
                    >
                      {exportOutputPath}
                    </code>
                  </div>
                ) : null}
              </div>
            </div>
            <DialogFooter>

@@ -256,6 +499,78 @@ export function ExportDialog({
            </div> */}
          </div>

          {/* Live export output panel */}
          <AnimatePresence>
            {showLogPanel && (exporting || logLines.length > 0) && (
              <motion.div {...collapseAnim} className="overflow-hidden">
                <div className="flex flex-col gap-1.5 pt-1">
                  <div className="flex items-center justify-between">
                    <label className="text-xs font-medium text-muted-foreground">
                      Export output
                    </label>
                    <div className="flex items-center gap-2 text-[11px] text-muted-foreground/80">
                      <span
                        className={
                          logConnected
                            ? "inline-block size-1.5 rounded-full bg-emerald-500"
                            : "inline-block size-1.5 rounded-full bg-muted-foreground/40"
                        }
                      />
                      <span>
                        {logConnected
                          ? "streaming"
                          : exporting
                            ? "connecting..."
                            : "idle"}
                      </span>
                      {exporting && elapsedSeconds > 0 ? (
                        <span className="tabular-nums text-muted-foreground/70">
                          · {formatElapsed(elapsedSeconds)}
                        </span>
                      ) : null}
                    </div>
                  </div>
                  <div
                    ref={logScrollRef}
                    onScroll={handleLogScroll}
                    className="h-56 w-full overflow-auto rounded-lg border border-border/40 bg-black/85 p-3 font-mono text-[11px] leading-[1.45] text-emerald-200/90"
                  >
                    {logLines.length === 0 ? (
                      <div className="flex h-full items-center justify-center text-muted-foreground/70">
                        <span className="flex items-center gap-2">
                          <Spinner className="size-3" />
                          Waiting for worker output...
                        </span>
                      </div>
                    ) : (
                      <pre className="whitespace-pre-wrap break-words">
                        {logLines.map((entry, idx) => (
                          <div
                            key={idx}
                            className={
                              entry.stream === "stderr"
                                ? "text-rose-300/90"
                                : entry.stream === "status"
                                  ? "text-sky-300/90"
                                  : ""
                            }
                          >
                            {formatLogLine(entry)}
                          </div>
                        ))}
                      </pre>
                    )}
                  </div>
                  {logError && (
                    <p className="text-[11px] text-destructive/80">
                      Log stream: {logError}
                    </p>
                  )}
                </div>
              </motion.div>
            )}
          </AnimatePresence>

          <DialogFooter>
            <Button
              variant="outline"
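The reconnect delay schedule `useExportLogs` uses (500 ms doubling up to a 5 s cap, plus up to 250 ms of uniform jitter, reset to the base on every successful connect) can be sketched outside React. This is a hypothetical helper written only to make the schedule concrete, using the same constants as the hook:

```python
import random


def backoff_delays(attempts: int, base_ms: int = 500,
                   cap_ms: int = 5000, jitter_ms: int = 250):
    """Yield the delay (in ms) before each reconnect attempt.

    Exponential growth from base_ms, capped at cap_ms, with uniform
    jitter added so many clients dropped at once don't all reconnect
    in lockstep. A successful connect would restart the generator,
    which is how the hook resets backoffMs to 500.
    """
    backoff = base_ms
    for _ in range(attempts):
        yield backoff + random.randrange(jitter_ms)
        backoff = min(backoff * 2, cap_ms)
```

Before jitter, the first five delays are 500, 1000, 2000, 4000 and 5000 ms; the cap keeps every later retry at roughly five seconds, well inside the lifetime of the backend's 4000-line ring buffer for typical export output rates.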
@@ -121,6 +121,10 @@ export function ExportPage() {
  const [exporting, setExporting] = useState(false);
  const [exportError, setExportError] = useState<string | null>(null);
  const [exportSuccess, setExportSuccess] = useState(false);
  // Resolved on-disk path of the most recent successful export, surfaced
  // on the Export Complete screen so the user can find their model
  // without digging through the server log. Null for Hub-only pushes.
  const [exportOutputPath, setExportOutputPath] = useState<string | null>(null);

  const hfComboboxAnchorRef = useRef<HTMLDivElement>(null);
  const localComboboxAnchorRef = useRef<HTMLDivElement>(null);

@@ -405,6 +409,7 @@ export function ExportPage() {
      setExporting(true);
      setExportError(null);
      setExportSuccess(false);
      setExportOutputPath(null);

      // For GGUF, use a flat folder like "exports/gemma-3-4b-it-finetune-gguf"
      // For other formats, nest under training-run/checkpoint

@@ -433,18 +438,24 @@ export function ExportPage() {
        });
      }

      // 2. Run export based on method
      // 2. Run export based on method. Capture the resolved output_path
      // (when the backend wrote a local copy) so the success screen can
      // show the user the realpath of their saved model. For multi-quant
      // GGUF runs, the directory is the same for every quant so we just
      // keep the last response.
      let lastOutputPath: string | null = null;
      if (exportMethod === "merged") {
        if (isAdapter) {
          await exportMerged({
          const resp = await exportMerged({
            save_directory: saveDir,
            push_to_hub: pushToHub,
            repo_id: repoId,
            hf_token: token,
            private: privateRepo,
          });
          lastOutputPath = resp.details?.output_path ?? null;
        } else {
          await exportBase({
          const resp = await exportBase({
            save_directory: saveDir,
            push_to_hub: pushToHub,
            repo_id: repoId,

@@ -452,27 +463,31 @@ export function ExportPage() {
            private: privateRepo,
            base_model_id: selectedModelData?.base_model,
          });
          lastOutputPath = resp.details?.output_path ?? null;
        }
      } else if (exportMethod === "gguf") {
        for (const quant of quantLevels) {
          await exportGGUF({
          const resp = await exportGGUF({
            save_directory: saveDir,
            quantization_method: quant,
            push_to_hub: pushToHub,
            repo_id: repoId,
            hf_token: token,
          });
          lastOutputPath = resp.details?.output_path ?? lastOutputPath;
        }
      } else if (exportMethod === "lora") {
        await exportLoRA({
        const resp = await exportLoRA({
          save_directory: saveDir,
          push_to_hub: pushToHub,
          repo_id: repoId,
          hf_token: token,
          private: privateRepo,
        });
        lastOutputPath = resp.details?.output_path ?? null;
      }

      setExportOutputPath(lastOutputPath);
      setExportSuccess(true);
    } catch (err) {
      setExportError(

@@ -1080,6 +1095,7 @@ export function ExportPage() {
        exporting={exporting}
        exportError={exportError}
        exportSuccess={exportSuccess}
        exportOutputPath={exportOutputPath}
      />
    </div>
  );
tests/test_cli_export_unpacking.py (Normal file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

"""
Regression tests for unsloth_cli.commands.export.

Context: the studio export dialog live-logs work changed
ExportOrchestrator.export_{merged_model,base_model,gguf,lora_adapter}
to return (success, message, output_path) instead of (success, message)
so the frontend can show the on-disk realpath on the success screen.
The CLI at unsloth_cli/commands/export.py still unpacks two values,
so every `unsloth export --format ...` crashes with:

    ValueError: too many values to unpack (expected 2)

These tests pin the CLI to the 3-tuple contract by invoking it against
a fake ExportBackend and asserting exit_code == 0 for each --format.
No real ML imports; the fake is installed via sys.modules injection so
the CLI's deferred `from studio.backend.core.export import ExportBackend`
binds to it.
"""

from __future__ import annotations

import sys
import types
from pathlib import Path

import pytest
import typer
from typer.testing import CliRunner


# ---------------------------------------------------------------------------
# Fake ExportBackend
# ---------------------------------------------------------------------------


class _FakeExportBackend:
    """Stand-in for studio.backend.core.export.ExportBackend.

    All export_* methods return the new 3-tuple contract. load_checkpoint
    keeps its 2-tuple shape (unchanged by the live-logs work).
    """

    def __init__(self) -> None:
        self.loaded: str | None = None

    def load_checkpoint(self, **kwargs):
        self.loaded = kwargs.get("checkpoint_path")
        return True, f"Loaded {self.loaded}"

    def scan_checkpoints(self, **kwargs):
        return []

    def export_merged_model(self, **kwargs):
        return True, "merged ok", str(Path(kwargs["save_directory"]).resolve())

    def export_base_model(self, **kwargs):
        return True, "base ok", str(Path(kwargs["save_directory"]).resolve())

    def export_gguf(self, **kwargs):
        return True, "gguf ok", str(Path(kwargs["save_directory"]).resolve())

    def export_lora_adapter(self, **kwargs):
        return True, "lora ok", str(Path(kwargs["save_directory"]).resolve())


def _install_fake_studio_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """Inject fake studio.backend.core.export into sys.modules.

    The CLI imports ExportBackend lazily inside the command function, so
    patching sys.modules before invoking the command is sufficient to
    steer the `from studio.backend.core.export import ExportBackend`
    statement at the fake. Parent packages (studio, studio.backend,
    studio.backend.core) are stubbed too so Python's import machinery
    doesn't try to resolve the real (structlog-dependent) tree.
    """
    for name in ("studio", "studio.backend", "studio.backend.core"):
        monkeypatch.setitem(sys.modules, name, types.ModuleType(name))

    fake_mod = types.ModuleType("studio.backend.core.export")
    fake_mod.ExportBackend = _FakeExportBackend
    monkeypatch.setitem(sys.modules, "studio.backend.core.export", fake_mod)

    # Drop any cached import of the CLI module so the deferred import
    # inside export() re-resolves against our fake module rather than a
    # previously cached real one.
    monkeypatch.delitem(sys.modules, "unsloth_cli.commands.export", raising = False)


@pytest.fixture
def cli_app(monkeypatch: pytest.MonkeyPatch) -> typer.Typer:
    """Typer app wrapping unsloth_cli.commands.export.export."""
    _install_fake_studio_backend(monkeypatch)
    from unsloth_cli.commands import export as export_cmd

    app = typer.Typer()
    app.command("export")(export_cmd.export)

    # Typer flattens a single-command app into that command, which would
    # make argv[0] ("export") look like an extra positional argument to
    # the test invocation. Register a harmless second command so Typer
    # keeps "export" as a real subcommand and the tests drive the
    # intended code path.
    @app.command("noop")
    def _noop() -> None:  # pragma: no cover - only exists to pin routing
        pass

    return app


@pytest.fixture
def runner() -> CliRunner:
    return CliRunner()


# ---------------------------------------------------------------------------
# The actual regression tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "format_flag,quant_flag",
    [
        ("merged-16bit", None),
        ("merged-4bit", None),
        ("gguf", "q4_k_m"),
        ("lora", None),
    ],
)
def test_cli_export_unpacks_three_tuple(
    cli_app: typer.Typer,
    runner: CliRunner,
    tmp_path: Path,
    format_flag: str,
    quant_flag: str | None,
) -> None:
    """Each --format path must unpack (success, message, output_path)
    without raising ValueError. Pre-fix, every parametrized case fails
    with 'too many values to unpack (expected 2)'.
    """
    ckpt = tmp_path / "ckpt"
    ckpt.mkdir()
    out = tmp_path / "out"

    cli_args = ["export", str(ckpt), str(out), "--format", format_flag]
    if quant_flag is not None:
        cli_args += ["--quantization", quant_flag]

    result = runner.invoke(cli_app, cli_args)

    assert result.exit_code == 0, (
        f"CLI exited with code {result.exit_code} for --format {format_flag}.\n"
        f"Output:\n{result.output}\n"
        f"Exception: {result.exception!r}"
    )
    # Sanity: the success message from the fake backend should reach stdout.
    expected_prefix = format_flag.split("-")[0]
    assert f"{expected_prefix} ok" in result.output
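The failure mode this test file pins down is plain Python tuple unpacking, which is easy to show in isolation (illustrative snippet, not code from the repo):

```python
def unpack_two(result):
    # Old CLI contract: the backend returned a (success, message) pair.
    success, message = result
    return success, message

# The orchestrator now returns three values; the illustrative path is
# made up for the example.
new_result = (True, "merged ok", "/exports/model")

try:
    unpack_two(new_result)
    err = None
except ValueError as exc:
    err = str(exc)  # CPython: "too many values to unpack (expected 2)"

# The fix is simply to unpack all three values:
success, message, output_path = new_result
```

Because the old left-hand side expects exactly two names, every 3-tuple return raises `ValueError` before any export logic runs, which is why all four `--format` paths crashed at once.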
@@ -90,8 +90,9 @@ def export(
    typer.echo(message)

    typer.echo(f"Exporting as {format}...")
    output_path: Optional[str] = None
    if format == "merged-16bit":
        success, message = backend.export_merged_model(
        success, message, output_path = backend.export_merged_model(
            save_directory = str(output_dir),
            format_type = "16-bit (FP16)",
            push_to_hub = push_to_hub,

@@ -100,7 +101,7 @@ def export(
            private = private,
        )
    elif format == "merged-4bit":
        success, message = backend.export_merged_model(
        success, message, output_path = backend.export_merged_model(
            save_directory = str(output_dir),
            format_type = "4-bit (FP4)",
            push_to_hub = push_to_hub,

@@ -109,7 +110,7 @@ def export(
            private = private,
        )
    elif format == "gguf":
        success, message = backend.export_gguf(
        success, message, output_path = backend.export_gguf(
            save_directory = str(output_dir),
            quantization_method = quantization.upper(),
            push_to_hub = push_to_hub,

@@ -117,7 +118,7 @@ def export(
            hf_token = hf_token,
        )
    elif format == "lora":
        success, message = backend.export_lora_adapter(
        success, message, output_path = backend.export_lora_adapter(
            save_directory = str(output_dir),
            push_to_hub = push_to_hub,
            repo_id = repo_id,

@@ -130,3 +131,5 @@ def export(
        raise typer.Exit(code = 1)

    typer.echo(message)
    if output_path:
        typer.echo(f"Saved to: {output_path}")