Compare commits

...

292 commits
b8475 ... main

Author SHA1 Message Date
Roland Tannous
21e9a91a57
Studio: forward standard OpenAI tools / tool_choice on /v1/responses (Codex compat) (#5122)
* Studio: forward standard OpenAI tools / tool_choice on /v1/responses

Mirrors the /v1/chat/completions client-side tool pass-through from #5099
so clients (OpenAI Codex CLI, OpenAI Python SDK, ...) that target the
Responses API receive structured function_call output items instead of
plain text with tool-call tokens leaking into content.

- ResponsesRequest: type tools/tool_choice properly, add parallel_tool_calls;
  accept function_call and function_call_output input items for multi-turn
- Translate flat Responses tool / tool_choice shape to the nested Chat
  Completions shape before forwarding to llama-server
- _normalise_responses_input: map function_call_output -> role="tool",
  function_call -> assistant tool_calls (preserving call_id)
- Non-streaming: map returned tool_calls -> top-level function_call
  output items keyed by call_id
- Streaming: emit response.output_item.added (function_call),
  response.function_call_arguments.delta/.done, and response.output_item.done
  per tool call while keeping the text message at output_index 0
- Pytest coverage: tools/tool_choice translation, multi-turn input mapping,
  non-streaming tool_calls mapping, response round-trip

* Studio: merge system messages and close inner stream on /v1/responses

Fixes two issues surfacing when OpenAI Codex CLI drives /v1/responses
against a GGUF with a strict chat template (gpt-oss harmony, Qwen3, ...).

1. "System message must be at the beginning" upstream errors
   Codex sends `instructions` AND a `role:"developer"` message in `input`,
   producing two separate system-role messages. Strict templates raise
   when a second system message exists or when one appears after a user
   turn. _normalise_responses_input now hoists all instructions / system /
   developer content into a single merged system message at the top of
   the Chat Completions message list.

2. "async generator ignored GeneratorExit" / "Attempted to exit cancel
   scope in a different task"
   _responses_stream consumed the inner chat-completions body_iterator
   without an explicit aclose() in a finally block. On client disconnect
   (Codex frequently cancels mid-stream), Python 3.13 finalized the inner
   async generator on a different task, tripping anyio's cancel-scope
   check. Mirrored the same try/finally + aclose pattern used by the
   /v1/messages, /v1/chat/completions, and /v1/completions passthroughs.

Tests: hoisting of instructions + developer, developer mid-conversation,
multiple system messages in input, no-system passthrough.

* Studio: accept Codex multi-turn shapes and fix cross-task stream close on /v1/responses

Two issues observed driving /v1/responses from OpenAI Codex CLI against a
GGUF backend.

1. 422 on every turn after the first
   Codex replays prior assistant turns with
   `content:[{"type":"output_text","text":...,"annotations":[],"logprobs":[]}]`
   and carries forward `reasoning` items (o-series / gpt-5) between turns.
   Our `ResponsesContentPart` union only accepted input_text / input_image,
   and `ResponsesInputItem` only message / function_call / function_call_output,
   so Pydantic failed the whole list and FastAPI returned
   `"Input should be a valid string"` against the `str` branch of the
   outer union.

   - Add `ResponsesOutputTextPart` for assistant-replay content.
   - Add `ResponsesUnknownContentPart` and `ResponsesUnknownInputItem`
     as permissive catch-alls (drop during normalisation).
   - Wire an explicit `Discriminator` so dispatch is deterministic and
     the fallthrough reaches the catch-all instead of misreporting via
     the outer `Union[str, list[...]]`.
   - `_normalise_responses_input` now accepts output_text parts, flattens
     single-part assistant text to a plain string (keeps legacy chat
     templates happy), and silently drops reasoning / unknown items.

2. "async generator ignored GeneratorExit" / cross-task cancel scope
   `_responses_stream` awaited `openai_chat_completions` in the parent
   route-handler task, which opens the httpx client for the inner
   passthrough on *that* task. The outer `StreamingResponse` then iterates
   in a child task, so the asyncgen GC finalises the inner httpcore byte
   stream on the child task, tripping anyio's "Attempted to exit cancel
   scope in a different task". Move the `await` inside `event_generator`
   so the httpx lifecycle stays within the single streaming child task,
   and surface any HTTPException as a `response.failed` SSE frame.

Tests: assistant output_text replay, reasoning-item tolerance, unknown
content-part tolerance, end-to-end Codex-shape payload (developer + user +
reasoning + function_call + function_call_output + assistant output_text +
user), and single-part assistant flattening to plain string.

* Studio: call llama-server directly from streaming /v1/responses

The previous fix (running the inner await inside event_generator) was not
enough. Wrapping the existing `openai_chat_completions` pass-through still
stacks two async generators: when the outer generator is closed, the
innermost `HTTP11ConnectionByteStream.__aiter__` in httpcore doesn't
receive GeneratorExit before Python's asyncgen GC finalises it in a
sibling task, tripping "Attempted to exit cancel scope in a different
task" and "async generator ignored GeneratorExit" — the same Python 3.13
+ httpcore 1.0.x interaction already seen in PRs #4956, #4981, #5099.

Cure both pass-throughs had: a single same-task httpx lifecycle with
explicit `aiter_lines().aclose()` BEFORE `resp.aclose()` / `client.aclose()`
in the generator's finally block.

Apply it at the Responses layer by dropping the wrapper entirely for GGUF:
open httpx, consume `resp.aiter_lines()`, parse `chat.completion.chunk`,
emit Responses SSE events, close everything in finally — all in the
single StreamingResponse child task. Non-GGUF streaming is rejected with
a 400 (wrapping the transformers backend would re-introduce the
double-layer pattern and isn't a Codex-compatible path today anyway).

Also surfaces upstream httpx.RequestError / non-200 as a
`response.failed` SSE frame rather than a dropped stream now that the
request is dispatched after SSE headers have gone out.

* Studio: silence benign httpcore asyncgen GC warnings on Python 3.13

The streaming pass-throughs (/v1/chat/completions, /v1/messages,
/v1/responses, /v1/completions) all use the proven #4981 / #5099 pattern
— single-task httpx lifecycle with explicit aiter_lines().aclose() ahead
of resp.aclose() / client.aclose() in the generator's finally block.
That handles our own iterators correctly.

The residual noise ("async generator ignored GeneratorExit" /
"Attempted to exit cancel scope in a different task") comes from an
innermost HTTP11ConnectionByteStream.__aiter__ that httpcore creates
internally inside its pool. We hold no reference to it, so we cannot
aclose it ourselves. Python 3.13's asyncgen GC hook finalises it on the
finaliser task, its aclose path enters an anyio CancelScope shield, and
Python flags the cross-task exit. The response has already been
delivered with a 200 by then — it is purely log noise, not a functional
failure. Same interaction seen in modelcontextprotocol/python-sdk #831,
agno #3556, chainlit #2361, langchain-mcp-adapters #254.

Install a targeted sys.unraisablehook that swallows this specific tuple
— RuntimeError mentioning "cancel scope" or "GeneratorExit" plus an
object repr referencing HTTP11ConnectionByteStream — and defers to the
default hook for every other unraisable. Idempotent; guarded by a
sentinel attribute so repeated imports don't stack filters.
2026-04-21 13:17:20 +04:00
Lee Jackson
c20959dbf4
Studio: Improve chat composition, fix scroll behaviour, and refine sidebar UX (#5089)
* Chatbox, scroll, and menu fixes

- Fixed chatbox auto-expand height for multi-line text on the compare page
- Fixed chatbox UI to be consistent across compare and new chat
- Fixed scrolling being enabled on pages with no content, which also triggered the scroll-to-bottom button
- Fixed scroll-to-bottom button to only appear after scrolling up a reasonable amount instead of instantly
- Added shutdown studio button to the menu for easier access
- Fixed pop-up menu width to match the user button width

(cherry picked from commit cd4e390dfa84fe311fae79a781b96cc0ef5970a9)

* fix: correct compare scroll viewport and clean up chat composer UI polish

* Dark theme refactor and sidebar/chat UI refinements

- Complete refactoring of dark theme
- Replaced square rounded-corner user profile image with a circular bordered one
- Replaced user profile icon with 'U' initial and renamed label from 'Studio' to 'User'
- Chat bubbles now have a pointy top-right edge
- Sidebar menu tab line color selection is now consistent across all menus
- Tab-selection color animation now also applies to recent chats
- Removed 'Compare' menu autoselect when a compare chat conversation is selected
- Fixed UI consistency in Compare to match New Chat
- Removed sidebar animation and tab line, replaced with rounded selection for consistency
- Further adjustments to sidebar UI
- Further adjustments to compare chat UI

* Fixed sidebar collapse/expand for recent chats and recent runs not being clickable

* Chatbox, scroll, and menu fixes

- Fixed chatbox auto-expand height for multi-line text on the compare page
- Fixed chatbox UI to be consistent across compare and new chat
- Fixed scrolling being enabled on pages with no content, which also triggered the scroll-to-bottom button
- Fixed scroll-to-bottom button to only appear after scrolling up a reasonable amount instead of instantly
- Added shutdown studio button to the menu for easier access
- Fixed pop-up menu width to match the user button width

* Sidebar, fonts, and chat UI refinements

- Replaced logo PNG with real font text for 'unsloth' and 'BETA' label
- Added Hellix font and applied it across menus and UI elements
- Lighter scrollbar in the sidebar compared to other areas of the app
- Adjusted chat font and chat bubble styling
- Adjusted app menu design to stay consistent with the sidebar
- Adjusted text style for 'New Chat' and repositioned content/chatbox
- Adjusted model selector and top area UI
- Fixed footer text from 'LLM's' to 'LLMs'
- Fixed active selection border color incorrectly appearing on page refresh and during general navigation
- Logo now defaults to 'New Chat' when clicked

* Sidebar, model selector, and mobile UI fixes

- Further adjustments to sidebar UI and logo
- Changed right bar icon
- Model selector adjustments
- Collapsed sidebar now matches the content area background
- Adjusted Hellix font spacing across pages
- Fixed sidebar icon overlap on mobile screens

* Adjust sidebar icons

* Adjust sidebar icons

* Fixed compare chat UI and scrolling issues

* Fixed inference settings icon behavior and context info positioning

- Fixed top right inference settings icon to move into sidepanel during expand/collapse, matching left sidebar behavior
- Adjusted context information element positioning

* Fix: textarea overflow in system prompt editor

* Code block redesign, font, and chat bubble adjustments

- Redesigned code block colors and theme
- Changed code block font to Fira Code
- Fixed scrollbar disappearing when expanding/collapsing tool calls in chats
- Adjusted chat bubble background color

* Fix chat bubble background color in dark theme

* fix: restore textarea auto-sizing and scope prompt editor sizing

* fix: add explicit textarea field sizing for prompt editor overflow

* fix: generate chat nonce on click instead of render

* fix: respect training lock on logo navigation

* Refactor compare page dual chat scrolling behavior

* Revert "Refactor compare page dual chat scrolling behavior"

This reverts commit d056ec09f2.

---------

Co-authored-by: sneakr <hauzin@hotmail.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-21 02:20:45 +04:00
Konstantin Azizov
0a5c61ffcc
fix: prefer mainstream clipboard copy over deprecated one (#5109)
Fixes #5097

Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-20 23:18:18 +04:00
Lee Jackson
d3215ce113
Studio: Show LoRA live logs and update GGUF quant options (#5058)
* export: update GGUF quant list and ordering

* gguf: add Q2_K_L quantize flags for output and embeddings

* export: add live console logs for LoRA export flow

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: stream q2_k_l quantize logs and include subprocess error details

* fix: route Q2_K_L preset to q2_k ftype with q8_0 output+embeddings

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-20 23:14:49 +04:00
Lee Jackson
9c8a079d97
Studio: Local profile customization in settings and sync sidebar identity (#5088)
* studio: add local profile customization in settings

* studio: add local profile settings and sync sidebar identity

* fix: adjust profile card margin

* fix: move helper modules to utils and use single-letter avatar fallback

* fix: keep profile icon visible on sidebar collapse

* fix: sidebar account trigger labeling and profile reset prefs
2026-04-20 22:28:02 +04:00
Roland Tannous
9954781d30
fix(studio/chat): cancel in-flight run when trashing a thread from sidebar (#5067)
Trashing a thread mid-stream used to delete the Dexie rows while the
model kept generating, because the sidebar has no access to the
@assistant-ui aui context. Expose per-thread cancelRun() through the
chat runtime store and call it from deleteChatItem so trash behaves
like Stop → Trash. Covers compare pairs by cancelling each paired
thread.

Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
2026-04-20 21:06:59 +04:00
Michael Han
b24f3f61b8
Update README.md 2026-04-20 00:37:40 -07:00
Michael Han
f5eec8a6f2
Qwen3.6 and ReadMe revamp.md 2026-04-19 23:16:36 -07:00
Roland Tannous
ac2daf8b7a
Studio: forward standard OpenAI tools / tool_choice to llama-server (#5099)
* fix(studio): forward OpenAI tools/tool_choice to llama-server (#4999)

Studio's /v1/chat/completions silently stripped standard OpenAI `tools`
and `tool_choice` fields, so clients using standard function calling
(opencode, Claude Code, Cursor, Continue, ...) never got structured
tool_calls back. Adds a client-side pass-through path mirroring the
existing Anthropic /v1/messages flow: when `tools` is present without
Studio's `enable_tools` shorthand, the request is forwarded to
llama-server verbatim so the client sees native id, finish_reason
("tool_calls"), delta.tool_calls, and accurate usage tokens.

Also wires Anthropic tool_choice forwarding: /v1/messages previously
accepted tool_choice on the request model but silently dropped it with
a warning. Translate the four Anthropic shapes to OpenAI format and
forward them so agentic clients can actually enforce tool use.

- ChatCompletionRequest: add tools, tool_choice, stop; extra="allow"
- ChatMessage: accept role="tool", optional tool_call_id / tool_calls /
  name; content is now optional (assistant with only tool_calls)
- routes/inference.py: _openai_passthrough_stream /
  _openai_passthrough_non_streaming helpers, routing branch in
  openai_chat_completions, vision+tools via content-parts injection
- _build_passthrough_payload: tool_choice parameter (default "auto")
- anthropic_compat: anthropic_tool_choice_to_openai() translator
- tests/test_openai_tool_passthrough.py: Pydantic + translator unit tests
- tests/test_studio_api.py: 5 new E2E tests (non-stream, stream,
  multi-turn, OpenAI SDK, Anthropic tool_choice=any regression)

* fix(studio): surface httpx transport errors from OpenAI passthrough

When the managed llama-server subprocess crashes mid-request, the
async pass-through helpers in routes/inference.py used to return a
bare 500 (non-streaming) or an "An internal error occurred" SSE chunk
(streaming) because _friendly_error only recognized the sync path's
"Lost connection to llama-server" substring -- httpx transport
failures (ConnectError / ReadError / RemoteProtocolError /
ReadTimeout) stringify differently and fell through to the generic
case.

- _friendly_error: map any httpx.RequestError subclass to the same
  "Lost connection to the model server" message the sync chat path
  emits. Placed before the substring heuristics so the streaming path
  automatically picks it up via its existing except Exception catch.
- _openai_passthrough_non_streaming: wrap the httpx.AsyncClient.post
  in a try/except httpx.RequestError and re-raise as HTTPException
  502 with the friendly detail.
- tests/test_openai_tool_passthrough.py: new TestFriendlyErrorHttpx
  class pinning the mapping for ConnectError, ReadError,
  RemoteProtocolError, ReadTimeout, and confirming non-httpx paths
  (context-size heuristic, generic fallback) are unchanged.

* fix(studio): close aiter_bytes/aiter_lines explicitly in passthroughs

The httpcore asyncgen cleanup fix in 5cedd9a5 is incomplete on Python
3.13 + httpcore 1.0.x: it switched to manual client/response lifecycle
but still used anonymous `async for raw_line in resp.aiter_lines():`
patterns in all three streaming paths. Python's async for does NOT
auto-close the iterator on break/return, so the aiter_lines /
aiter_bytes async generator remains alive, reachable only from the
surrounding coroutine frame. Once `_stream()` returns the frame is
GC'd and the orphaned asyncgen is finalized on a LATER GC pass in a
DIFFERENT asyncio task, where httpcore's
HTTP11ConnectionByteStream.aclose() enters anyio.CancelScope.__exit__
with a mismatched task and prints "Exception ignored in: <async
generator>" / "async generator ignored GeneratorExit" / "Attempted
to exit cancel scope in a different task" to the server log.

User observed this on /v1/messages after successful (status 200)
requests, with the traceback pointing at HTTP11ConnectionByteStream
.__aiter__ / .aclose inside httpcore.

Fix: save resp.aiter_lines() / resp.aiter_bytes() as a variable and
explicitly `await iter.aclose()` in the finally block BEFORE
resp.aclose() / client.aclose(). This closes the asyncgen inside the
current task's event loop, so the internal httpcore byte stream is
cleaned up before Python's asyncgen GC hook has anything orphaned to
finalize. Each aclose is wrapped in try/except Exception so nested
anyio cleanup noise can't bubble out.

Applied to all three streaming passthrough paths:
- _anthropic_passthrough_stream (/v1/messages client-side tool path)
- _openai_passthrough_stream (/v1/chat/completions client-side tool
  path, new in this PR)
- openai_completions (/v1/completions bytes proxy from PR #4956)

* fix(studio): default ChatCompletionRequest.stream to false per OpenAI spec

OpenAI's /v1/chat/completions spec defaults `stream` to false, so
clients that omit the field (naive curl, minimal integrations) expect
a single JSON response back. Studio was defaulting to true, silently
switching those clients into SSE and breaking any parser that didn't
also handle streaming. ResponsesRequest and AnthropicMessagesRequest
already default to false correctly; only ChatCompletionRequest was
wrong.

Studio's own frontend always sets `stream` explicitly on every
chat-adapter / chat-api / runtime-provider call site, so the flip has
no UI impact. SDK users (OpenAI Python/JS SDK, opencode, Claude Code,
Cursor, Continue) also always pass `stream` explicitly, so they're
unaffected. The only clients feeling the change are raw-curl users
who were relying on the wrong default -- those get the correct OpenAI
behavior now.

Added a regression test pinning the default so it can't silently
flip back.

* fix(studio): reject images in OpenAI tool passthrough for text-only GGUFs

The new tool passthrough branch runs before _extract_content_parts,
skipping the existing not is_vision guard. Requests combining tools
with an image on a text-only tool-capable GGUF were forwarded to
llama-server, producing opaque upstream errors instead of the
pre-existing clear 400. Restore the guard inline at the dispatch
point, checking both legacy image_base64 and inline image_url parts.

* fix(studio): require tool_call_id on role=tool chat messages

Enforce the OpenAI spec rule that role="tool" messages must carry a
tool_call_id. Without it, upstream backends cannot associate a tool
result with the assistant's prior tool_calls entry and the request
fails in non-obvious ways through the passthrough path. Reject at the
request boundary with a 422 instead.

* fix(studio): harden OpenAI tool passthrough validation and error surfacing

Three related fixes called out by the PR review:

1. Preserve upstream status codes in the streaming passthrough. The
   httpx request is now dispatched before the StreamingResponse is
   constructed. Non-200 upstream responses and httpx RequestError
   transport failures raise HTTPException with the real status
   instead of being buried inside a 200 SSE error frame, so OpenAI
   SDK clients see APIError/BadRequestError/... as expected.

2. Require non-empty content on user/system/tool messages. Per the
   OpenAI spec, content may only be omitted on assistant messages
   that carry tool_calls; enforce that at the request boundary so
   malformed messages never reach the passthrough path.

3. Role-constrain tool-call metadata. tool_calls is only valid on
   role=assistant, tool_call_id and name only on role=tool. Without
   this, a user/system message with tool_calls would flip the
   passthrough branch on and be forwarded to llama-server, surfacing
   as an opaque upstream error.

* fix(studio): normalize image mode and passthrough JSON verbatim

Two Gemini-code-assist review findings on PR #5099:

1. Unconditionally convert decoded images to RGB before PNG encoding.
   The prior code only handled RGBA, letting CMYK/I/F images crash
   at img.save(format="PNG") and surface as opaque 400s. Applied to
   both the passthrough helper and the non-passthrough GGUF path
   that originally carried this pattern, keeping the two sites in
   sync.

2. Return the upstream JSON body as raw bytes via Response rather
   than parse-then-re-serialize with JSONResponse. Matches the
   passthrough helper's "verbatim" contract and drops a redundant
   round-trip.

---------

Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-18 12:53:23 +04:00
Manan Shah
7d0d2f256c
Add qwen3.6 script (#5084)
* unsloth gemma4 support files

* some fixes

* Fixing cache.empty() calls (#4813)

* Fixing cache.empty() calls

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Manan Shah <mananshah@Manans-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix/gemma4 mlx (#4816)

* Fixing cache.empty() calls

* fixing for mlx versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Manan Shah <mananshah@Manans-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* removed bidirectional check for 31b (#4839)

Co-authored-by: Manan17 <shahmanan170602@gmail.coml>

* Add Gemma 4 26B MoE support (MLX) (#4844)

* removed bidirectional check for 31b

* Change gemma4_text for moe

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Manan Shah <mananshah@Manans-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix(gemma4): cast RoPE offset to int before mx.arange() (#4901)

* fix(gemma4): cast RoPE offset to int before mx.arange()

* fix(gemma4): use zero-based arange + offset to avoid CPU-GPU sync

* qwen3.6 patches for multi-turn chat

* qwen3.6 script

* removing unnecessary scripts

* displaying errors for not installed packages

---------

Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Manan Shah <mananshah@Manans-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manan17 <shahmanan170602@gmail.coml>
Co-authored-by: Théophile Lafargue <138336683+eauchs@users.noreply.github.com>
2026-04-17 01:21:30 -07:00
Daniel Han
d20b306755 Versioning 2026-04-16 12:06:10 -07:00
Daniel Han
0b57884120
Add Qwen3.6 inference defaults for Studio (#5065)
* Add Qwen3.6 inference defaults for Studio

Add qwen3.6 family entry to inference_defaults.json with the
recommended sampling parameters from Qwen's documentation:
temperature=0.7, top_p=0.8, top_k=20, min_p=0.0,
presence_penalty=1.5, repetition_penalty=1.0.

Without this, Qwen3.6 models fall through to the generic qwen3
pattern which uses different defaults (temperature=0.6,
top_p=0.95, no presence_penalty).

* Add Qwen3.6-35B-A3B-GGUF to default model lists

* Add Qwen3.5/3.6 presence_penalty to thinking toggle and small-model disable logic

- Thinking toggle (on-load + button click) now sets presencePenalty: 1.5 for
  Qwen3.5 and Qwen3.6 models (both thinking-ON and thinking-OFF states)
- Small-model thinking-disable check (<9B defaults to no-thinking) extended
  from Qwen3.5-only to also cover Qwen3.6, in all 3 locations:
  frontend on-load, frontend refresh, backend llama_cpp.py
2026-04-16 11:42:42 -07:00
Daniel Han
d56f980452
fix: multi-GPU inference crash for bnb 4-bit/8-bit models (#5068)
* fix: multi-GPU inference crash for bnb 4-bit/8-bit models

When load_in_4bit or load_in_8bit is used with device_map="sequential"
and max_memory constraints that place weights across multiple GPUs (or
entirely on a non-default GPU like cuda:1), the bitsandbytes loading
path in transformers never calls dispatch_model. No AlignDevicesHook is
installed, and the first forward/generate call crashes with:

  RuntimeError: Expected all tensors to be on the same device

This adds _attach_bnb_multidevice_hooks() which is called after
from_pretrained returns. It infers a device map from actual parameter
placements and calls dispatch_model(force_hooks=True) to install the
missing hooks. The function is a complete no-op for the common
single-GPU cuda:0 case.

Call sites: FastBaseModel.from_pretrained (vision.py) and
FastLlamaModel.from_pretrained (llama.py).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: align with PR #5053 final review improvements

- Add hook call to the bnb quantized loading branch in llama.py (the
  primary load_in_4bit path), not just the non-fast-inference fallback
- Expand bnb detection: also check model.is_loaded_in_4bit,
  model.is_loaded_in_8bit, model.quantization_method
- Pass explicit main_device and skip_keys to dispatch_model
- Use logger.info instead of print for the success message
- Use kwargs.get("load_in_8bit", False) at llama.py call sites

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-16 11:35:02 -07:00
Lee Jackson
ee86530e55
chore: switch helper and no-cache fallback to Gemma (#5066) 2026-04-16 22:27:30 +04:00
Wasim Yousef Said
bc9ddb3af6
Fix onboarding followups (#5064)
* Fix onboarding followups

* Rename sidebar studio to train
2026-04-16 10:11:35 -07:00
Wasim Yousef Said
7ef65bd2e5
Chat first onboarding (#5063)
* auth: default to chat

* settings: relaunch onboarding

* onboarding: return to launch page

* studio: stop auto guided tour

* ui: soften global radius

* cleanup: rename onboarding exit prop

* fix onboarding redirect safety

* Show real Unsloth version in settings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-16 09:58:10 -07:00
हिमांशु
f4422b0a62
change torchcodec version to 0.10.0 in extra-no-deps (#5043)
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-16 19:50:57 +04:00
Wasim Yousef Said
b01e9af124
feat(studio): replace navbar with collapsible sidebar (#4936)
* feat(studio): replace navbar navigation with collapsible sidebar

Add an app-wide sidebar with hover-expand and pin-to-dock behavior.
Navigation items (Studio, Recipes, Export, Chat) move from the center
pill navbar to the sidebar. Chat threads and recipes render as
collapsible sub-lists. Navbar simplified to logo + update + close.

- Extend SidebarProvider with pinned/hovered state model
- New AppSidebar with animated active indicator, sloth profile menu,
  theme toggle, guided tour, back/forward navigation
- Chat page refactored to URL-driven view state via search params
- Extract reusable hooks for chat thread and recipe sidebar data
- Guard startViewTransition for browser compatibility
- Wrap chat deletions in Dexie transaction for data integrity

* feat(studio): move logo to sidebar and make navbar overlay

- Sidebar is now full-height with logo in SidebarHeader
- Collapsed sidebar shows sticker.png, expanded shows full logo
- Navbar is absolute-positioned overlay (no layout space)
- Main content extends to top, aligning with navbar controls

* feat(studio): full-height sidebar with recents, edge-to-edge nav buttons

- Sidebar outside max-w-7xl, pinned to left edge
- Remove sidebar rounding, menu buttons rounded-md
- Nav buttons flush to sidebar edges with no left rounding
- Replace collapsible recipes/chat with flat nav items
- Add Recents section with chat history (1 item when not on chat, full on chat)
- New Chat as first nav item with PencilEdit02Icon
- Cursor pointer on all sidebar buttons
- Navbar temporarily hidden for screenshots

* fix(studio): fix chat scroll, action bar hover, collapsible recents

- Fix sticky composer by removing `relative` override on viewport footer
- Action bar buttons only show on hover (autohide=always)
- Remove floating border/shadow from action bar
- Add scroll space above composer for last message actions
- Back/forward buttons use router history (stay in-app)
- Recents section collapsible with chevron on chat route
- Set html/body/#root height for proper h-full chain

* fix(studio): address review feedback, clean up unused code

- Unhide navbar (was left hidden from screenshot)
- Remove unused imports: SidebarMenuSub*, BubbleChatIcon, ColumnInsertIcon
- Remove unused vars: recipeItems, activeRecipeId, canCompare, recipesOpen
- Include compare query id in active sidebar selection
- Use store type for contextUsage instead of inline type
- Simplify noop in sidebar.tsx
- Remove empty className prop

* feat(studio): add mobile sidebar, recent runs section, and misc UX fixes

* feat(studio): scaffold settings feature module with dialog store

* feat(studio): add tri-state theme store for settings

* feat(chat): add clear-all-chats and export-chat-history utils

* feat(studio): add settings dialog shell with tab rail

* feat(studio): add appearance tab with theme and sidebar pin

* feat(studio): add settings general tab with hf token, auto-title, reset prefs

* feat(studio): add settings chat tab with export and clear

* feat(studio): add api keys tab with list and revoke flow

* feat(studio): add create-key form and reveal dialog

* feat(studio): add usage examples panel to api keys tab

* feat(studio): add settings about tab with update and shutdown

* feat(studio): add settings dropdown item and cmd-comma shortcut

* feat(studio): remove legacy api-keys route and chat-sheet preference rows

* fix(studio): settings dialog a11y + polish pass

* feat(studio): inline api key reveal card replacing nested dialog

* fix(studio): hide revoked keys from settings list

* refactor(studio): strip navbar and hoist training unload guard

* feat(studio): explicit sidebar toggle, remove hover-open and pin icons

* fix(studio): use SidebarRight01Icon for collapsed sidebar open toggle

* fix(studio): address code review findings for settings dialog

* feat(studio): collapsible navigate group with standalone new-chat and compare

* fix(studio): chat-only standalone actions, use ColumnInsertIcon for compare

* fix(studio): sidebar new-chat/compare state reset and icon-mode collapsible

* feat(studio): add compact logo assets for sidebar header

* Fixed sidebar design

* fix(studio): sidebar delete icon hover contrast and sizing

* feat(studio): route-gate sidebar recents (chats off /studio, runs on /studio)

* feat(studio): add chat search store

* feat(studio): add chat search index hook with snapshot-on-open

* feat(studio): add chat search command dialog with global shortcut

* feat(studio): wire chat search into sidebar

* fix(studio): trim hf token on save, add show/hide toggle, commit on close

* revert(studio): restore original sidebar/border colors, brighten sidebar

* feat(studio): forward overlayClassName through CommandDialog

* fix(studio): wrap search dialog in Command context, redesign as flat 635px card

* fix(studio): reserve right padding on recent items so delete icon stops overlapping title

* fix(studio): skip hf token unmount-commit during reset-prefs reload

* chore(studio): drop unused icon import and unreachable runs navigate branch

* fix(studio): chat search index filters archived before limit, batches message query, picks up reasoning text

* fix(studio): keep CommandEmpty in tree so empty state renders correctly

* fix(studio): cap system prompt and chat template textareas so they scroll instead of growing

* fix(studio): attach chat-compare tour anchor to sidebar compare button

* fix(studio): persist system theme explicitly so next-themes does not clobber on reload

* fix(studio): auto-switch to history tab when selecting a recent run from sidebar

* UI overhaul: chatbox, scrollbar, sidebar, and compare view

UI Changes:
- Redesigned the Compare UI with general cleanup
- Redesigned the Chatbox UI
- Reduced the width of the user chat bubble for improved readability
- Narrowed the user chat box across the content page
- Adjusted thinking-box text color to be slightly darker
- Removed faded text effect from chat messages
- Removed faded text effect from the thinking box
- Added a small LLM chat safety note at the bottom of the chatbox
- Restyled the scrollbar

Layout & Behavior:
- Reworked the scrollbar to span the full height of the page (no top/bottom padding) and remain persistently visible when content is scrollable, rather than only on hover
- Reworked the Configuration sidebar to span full height — removed rounded corners and borders, with the scrollbar adjusted to match the full top-to-bottom layout
- Adjusted the top menu and bottom chatbox content areas to work correctly with the new full-page scroll behavior
- Made chat content match the chatbox width, with content sliding slightly behind the chatbox when scrolling
- Aligned chat text width with the chatbox for visual consistency, including how far the text extends behind the chatbox

Fixes:
- Fixed the chatbox not auto-expanding when typing multi-line input while bottom-positioned during an active chat (previously only worked before a chat had started)
- Fixed positioning and design of the user chat hover menu buttons to match the assistant chat box — now displayed below the chat bubble instead of on the left side

* Fix user message layout in thread component

* swap code icon

* fix compare layout

* fix compare pane flex

* Sidebar improvements and fixes

- Added scrolling support to the sidebar so menus and recent chats no longer get hidden
- Recent chats are now always visible in the sidebar, not hidden when in Studio, Recipes, or Export
- Recent chat is now deselected when selecting other navigations
- Fixed sidebar glitch where browser resize could make the sidebar and expand button disappear completely
- Fixed glitch where the open-sidebar hover tooltip appeared above the logo when clicking expand sidebar
- Reduced sidebar width on mobile to around 2/3 of the screen (was too wide)
- Made the close-sidebar hover tooltip consistent with the rest of the design
- Removed sidebar collapse/expand animation
- Small adjustment to chat width

* Fix route scrolling, polling, and theme sync issues

* Fix Studio page scrolling

---------

Co-authored-by: sneakr <hauzin@hotmail.com>
2026-04-16 08:46:16 -07:00
Daniel Han
05ec0f110b
Studio: Ollama support, recommended folders, Custom Folders UX polish (#5050)
* Studio: Ollama support, recommended folders, Custom Folders UX polish

Backend:
- Add _scan_ollama_dir that reads manifests/registry.ollama.ai/library/*
  and creates .gguf symlinks under <ollama_dir>/.studio_links/ pointing
  at the content-addressable blobs, so detect_gguf_model and llama-server
  -m work unchanged for Ollama models
- Filter entries under .studio_links from the generic models/hf/lmstudio
  scanners to avoid duplicate rows and leaked internal paths in the UI
- New GET /api/models/recommended-folders endpoint returning LM Studio
  and Ollama model directories that currently exist on the machine
  (OLLAMA_MODELS env var + standard paths, ~/.lmstudio/models, legacy
  LM Studio cache), used by the Custom Folders quick-add chips
- detect_gguf_model now uses os.path.abspath instead of Path.resolve so
  the readable symlink name is preserved as display_name (e.g.
  qwen2.5-0.5b-Q4_K_M.gguf instead of sha256-abc...)
- llama-server failure with a path under .studio_links or .cache/ollama
  surfaces a friendlier message ("Some Ollama models do not work with
  llama.cpp. Try a different model, or use this model directly through
  Ollama instead.") instead of the generic validation error

Frontend:
- ListLabel supports an optional leading icon and collapse toggle; used
  for Downloaded (download icon), Custom Folders (folder icon), and
  Recommended (star icon)
- Custom Folders header gets folder icon on the left, and +, search,
  and chevron buttons on the right; chevron uses ml-auto so it aligns
  with the Downloaded and Recommended chevrons
- New recommended folder chips render below the registered scan folders
  when there are unregistered well-known paths; one click adds them as
  a scan folder
- Custom folder rows that are direct .gguf files (Ollama symlinks) load
  immediately via onSelect instead of opening the GGUF variant expander
  (which is for repos containing multiple quants, not single files)
- When loading a direct .gguf file path, send max_seq_length = 0 so the
  backend uses the model's native context instead of the 4096 chat
  default (qwen2.5:0.5b now loads at 32768 instead of 4096)
- New listRecommendedFolders() helper on the chat API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: log silent exceptions and support read-only Ollama dirs

Replace silent except blocks in _scan_ollama_dir and the
recommended-folders endpoint with narrower exception types plus debug
or warning logs, so failures are diagnosable without hiding signal.

Add _ollama_links_dir helper that falls back to a per-ollama-dir hashed
namespace under Studio's own cache (~/.unsloth/studio/cache/ollama_links)
when the Ollama models directory is read-only. Common for system installs
at /usr/share/ollama/.ollama/models and /var/lib/ollama/.ollama/models
where the Studio process has read but not write access. Previously the
scanner returned an empty list in that case and Ollama models would
silently not appear.

The fallback preserves the .gguf suffix on symlink names so
detect_gguf_model keeps recognising them. The prior "raw sha256 blob
path" fallback would have missed the suffix check and failed to load.

* Address review: detect mmproj next to symlink target for vision GGUFs

Codex P1 on model_config.py:1012: when detect_gguf_model returns the
symlink path (to preserve readable display names), detect_mmproj_file
searched the symlink's parent directory instead of the target's. For
vision GGUFs surfaced via Ollama's .studio_links/ -- where the weight
file is symlinked but any mmproj sidecar lives next to the real blob
-- mmproj was no longer detected, so the model was misclassified as
text-only and llama-server would start without --mmproj.

detect_mmproj_file now adds the resolved target's parent to the scan
order when path is a symlink. Direct (non-symlink) .gguf paths are
unchanged, so LM Studio and HF cache layouts keep working exactly as
before. Verified with a fake layout reproducing the bug plus a
regression check on a non-symlink LM Studio model.

* Address review: support all Ollama namespaces and vision projector layers

- Iterate over all directories under registry.ollama.ai/ instead of
  hardcoding the "library" namespace. Custom namespaces like
  "mradermacher/llama3" now get scanned and include the namespace
  prefix in display names, model IDs, and symlink names to avoid
  collisions.

- Create companion -mmproj.gguf symlinks for Ollama vision models
  that have an "application/vnd.ollama.image.projector" layer, so
  detect_mmproj_file can find the projector alongside the model.

- Extract symlink creation into _make_symlink helper to reduce
  duplication between model and projector paths.

* Address review: move imports to top level and add scan limit

- Move hashlib and json imports to the top of the file (PEP 8).
- Remove inline `import json as _json` and `import hashlib` from
  function bodies, use the top-level imports directly.
- Add `limit` parameter to `_scan_ollama_dir()` with early exit
  when the threshold is reached.
- Pass `_MAX_MODELS_PER_FOLDER` into the scanner so it stops
  traversing once enough models are found.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: Windows fallback, all registry hosts, collision safety

_make_link (formerly _make_symlink):
- Falls back to os.link() hardlink when symlink_to() fails (Windows
  without Developer Mode), then to shutil.copy2 as last resort
- Uses atomic os.replace via tmp file to avoid race window where the
  .gguf path is missing during rescan

Scanner now handles all Ollama registry layouts:
- Uses rglob over manifests/ instead of hardcoding registry.ollama.ai
- Discovers hf.co/org/repo:tag and any other host, not just library/
- Filenames include a stable sha1 hash of the manifest path to prevent
  collisions between models that normalize to the same stem

Per-model subdirectories under .studio_links/:
- Each model's links live in their own hash-keyed subdirectory
- detect_mmproj_file only sees the projector for that specific model,
  not siblings from other Ollama models

Friendly Ollama error detection:
- Now also matches ollama_links/ (the read-only fallback cache path)
  and model_identifier starting with "ollama/"

Recommended folders:
- Added os.access(R_OK | X_OK) check so unreadable system directories
  like /var/lib/ollama/.ollama/models are not advertised as chips

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: filter ollama_links from generic scanners

The generic scanners (models_dir, hf_cache, lmstudio) already filter
out .studio_links to avoid duplicate Ollama entries, but missed the
ollama_links fallback cache directory used for read-only Ollama
installs. Add it to the filter.

* Address review: idempotent link creation and path-component filter

_make_link:
- Skip recreation when a valid link/copy already exists (samefile or
  matching size check). Prevents blocking the model-list API with
  multi-GB copies on repeated scans.
- Use uuid4 instead of os.getpid() for tmp file names to avoid race
  conditions from concurrent scans.
- Log cleanup errors instead of silently swallowing them.

Path filter:
- Use os.sep-bounded checks instead of bare substring match to avoid
  false positives on paths like "my.studio_links.backup/model.gguf".

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: drop copy fallback, targeted glob, robust path filter

_make_link:
- Drop shutil.copy2 fallback -- copying multi-GB GGUFs inside a sync
  API request would block the backend. Log a warning and skip the
  model when both symlink and hardlink fail.

Scanner:
- Replace rglob("*") with targeted glob patterns (*/*/* and */*/*/*)
  to avoid traversing unrelated subdirectories in large custom folders.

Path filter:
- Use Path.parts membership check instead of os.sep substring matching
  for robustness across platforms.

Scan limit:
- Skip _scan_ollama_dir when _generic already fills the per-folder cap.

* Address review: sha256, top-level uuid import, Path.absolute()

- Switch hashlib.sha1 to hashlib.sha256 for path hashing consistency.
- Move uuid import to the top of the file instead of inside _make_link.
- Replace os.path.abspath with Path.absolute() in detect_gguf_model
  to match the pathlib style used throughout the codebase.

* Address review: fix stale comments (sha1, rglob, copy fallback)

Update three docstrings/comments that still referenced the old
implementation after recent changes:
- sha1 comment now says "not a security boundary" (no hash name)
- "rglob" -> "targeted glob patterns"
- "file copies as a last resort" -> removed (copy fallback was dropped)

* Address review: fix stale links, support all manifest depths, scope error

_make_link:
- Drop size-based idempotency shortcut that kept stale links after
  ollama pull updates a tag to a same-sized blob. Only samefile()
  is used now -- if the link doesn't point at the exact same inode,
  it gets replaced.

Scanner:
- Revert targeted glob back to rglob so deeper OCI-style repo names
  (5+ path segments) are not silently skipped.

Ollama error:
- Only show "Some Ollama models do not work with llama.cpp" when the
  server output contains GGUF compatibility hints (key not found,
  unknown architecture, failed to load). Unrelated failures like
  OOM or missing binaries now show the generic error instead of
  being misdiagnosed.

---------

Co-authored-by: Daniel Han <info@unsloth.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
2026-04-16 08:24:08 -07:00
Daniel Han
ff23ce40b4
Fix review findings for chat-template repair (#5049) (#5056)
* Fix review findings for PR #49

1. Sandbox fallback Jinja env in _VariantTokenizerProxy.apply_chat_template
   (use SandboxedEnvironment, matching _derive_assistant_prefix_by_render)
2. Unwrap benign outer-If guards in _template_ends_with_toplevel_for so
   templates like {% if messages %}{% for ... %}{% endfor %}{% endif %}
   are still repairable (preserves Qwen3-Guard rejection via else-branch
   and add_generation_prompt-name checks)
3. Preserve raw name_or_path in _VariantTokenizerProxy._source_path so
   local-path detection works for dict/list variant tokenizers
4. Context-aware strict-mode messages: omit "will still load" and
   "Set UNSLOTH_STRICT_CHAT_TEMPLATE=1" when already raising

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-16 08:02:05 -07:00
Daniel Han
b42e3a120d
Remove legacy venv Scripts entry from User PATH on upgrade (#5060)
Older installers persisted the venv Scripts directory directly in the
User PATH registry. The shim approach from #4961 no longer writes that
entry, but on upgrade the old one survived and python.exe / pip.exe
from the unsloth venv continued winning resolution in every new shell.

Before creating the shim, read the current User PATH, filter out any
entry matching $VenvDir\Scripts (using the same symmetric raw+expanded
comparison as Add-ToUserPath), and write back if changed. No-op on
fresh installs where the legacy entry was never written.

Confirmed on a real Windows machine: `where.exe python` was returning
the venv interpreter first even after the shim PR merged.
2026-04-16 07:36:59 -07:00
Daniel Han
5b8643969e Revert "Remove legacy venv Scripts entry from User PATH on upgrade"
This reverts commit cae4a74297.
2026-04-16 14:20:43 +00:00
Daniel Han
cae4a74297 Remove legacy venv Scripts entry from User PATH on upgrade
Older installers persisted the venv Scripts directory directly in the
User PATH registry. The shim approach (added in this PR) no longer writes
that entry, but it also did not remove the old one. On upgrade, the
legacy entry survived and python.exe / pip.exe from the unsloth venv
continued winning resolution in every new shell, which is exactly the
hijack the shim was designed to prevent.

Before creating the shim, read the current User PATH, filter out any
entry matching $VenvDir\Scripts (using the same symmetric raw+expanded
comparison as Add-ToUserPath), and write back if changed. This runs
once per install and is a no-op on fresh installs where the legacy
entry was never written.
2026-04-16 14:19:04 +00:00
Datta Nimmaturi
6764cb9b90
Restrict flash attn to <=256 head dim. Consolidate attn impl checks (#5051)
* Restrict flash attn to <=256 head dim. Consolidate attn impl checks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Consolidate the changes into single function

* safeguard for dict instead of object

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-16 09:00:17 -05:00
Daniel Han
c5be8b1cd2
Chat-template repair: warn-by-default, AST classification, dict support (#5049)
* Chat-template repair: warn-by-default, AST classification, dict support

Follow-up hardening on top of PR #4426 (which fixed the #4150
RuntimeError for ChatML LoRA reloads).

Behavior changes:

- Warn-by-default instead of RuntimeError. When fix_chat_template cannot
  repair a broken template, emit a warning and return the original.
  Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to restore the pre-warn hard fail.
  Fixes the UX where a missing `{% if add_generation_prompt %}` block on
  a saved LoRA (typical after LlamaFactory / Axolotl re-serialize) would
  block model loading entirely.

- Local path vs HF hub distinguished in the warning message. For local
  paths the message points at the likely downstream tool; for HF IDs it
  points at the upstream model maintainers. Previously both said "file a
  bug report to the maintainers of <path>" even when <path> was the
  user's own saves/ directory.

- Dict / list chat_template now handled. Hermes-3 ships with
  {default, tool_use} and the previous code crashed with
  AttributeError: 'dict' object has no attribute 'find' when entering
  _fix_chat_template with a dict. Each variant is now fixed
  independently; structure is preserved.

Internals:

- _find_end_position now matches all four Jinja whitespace-control
  variants ({% %}, {%- %}, {% -%}, {%- -%}) and returns the rightmost
  endfor/endif so multi-for templates aren't locked onto the first loop.
  Previously {%- endfor -%} (both-side dash, used by Qwen3-Guard) was
  silently bypassed.

- _has_add_generation_prompt_block uses Jinja AST via
  jinja2.nodes.If/Name walks instead of substring matching, so
  templates that hide the block behind comments or dash-style variants
  are classified correctly.

- _template_ends_with_toplevel_for gates the GH#4150 ChatML repair on
  the AST: only fires when the last structural top-level node is a For
  (standard ChatML shape), ignoring trailing pure-whitespace output
  nodes. Templates wrapped in an outer If (Qwen3-Guard) are now
  explicitly skipped at the _fix_chat_template level as well, not just
  at load_correct_tokenizer's name-based exemption.

- _validate_patched_template renders the patched template with and
  without add_generation_prompt and confirms the patched output
  responds to the flag by appending (not replacing) content. If
  validation fails, the patch is discarded and we fall through to the
  warn path.

Verified with an expanded regression suite in tests/:
- test_fix_chat_template_pr4426.py: 42/42 template-matrix cells
- test_load_correct_tokenizer_pr4426.py: 5/5 tokenizer loads
- test_chat_template_followups.py: 10/10 new follow-up tests
- test_mistral_pr4426.py: 5 Mistral variants byte-identical
- test_qwen_pr4426.py: 14 Qwen variants byte-identical
  (Qwen1.5, Qwen2, Qwen2.5-Instruct/Coder/Math/VL, Qwen3,
  Qwen3-Coder, QwQ, Qwen3-Guard-Gen)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard _validate_patched_template against read-only chat_template

If tokenizer.chat_template is a property or otherwise read-only, the
validation helper would crash with AttributeError when trying to
temporarily set the patched template. Catch the assignment failure and
return False (skip validation), and best-effort restore in the finally
block.

* Replace regex separator inference with render-diff; broaden repair to non-ChatML templates

The previous `_infer_assistant_separator` was a four-tier regex heuristic that
only worked on ChatML-shaped templates and forced a hard `<|im_start|>` /
`<|im_end|>` presence gate on Case 2 repair. This meant a Llama-3, Gemma, or
Phi-3 template stripped of its generation-prompt block by a downstream tool
(LlamaFactory, Axolotl, etc.) would still warn-and-return even though the
structural shape is identical to the ChatML case the PR already handles.

This replaces the regex with `_derive_assistant_prefix_by_render`: render the
template with two dialogs that differ only in assistant content, then
`os.path.commonprefix` on the tails captures the exact assistant-turn prefix
the template emits. The template itself is ground truth, so non-ChatML shapes
work as long as the assistant block is a literal the template emits once per
message.

Three guards keep the derivation safe:
  A. both assistant renders extend the base render (no reordering);
  B. the divergence point is exactly the content-insertion site (sentinel
     follows the common prefix);
  C. a user-role cross-check: if a render with a user sentinel also emits
     the same prefix, role has no effect on output and we reject. A render
     failure on [user, user] (e.g. Gemma's `raise_exception` alternation
     check) is evidence that role matters; we accept.

Sentinels differ at character 0 so `commonprefix` cannot absorb them, and
trailing whitespace/comments after the last `{% endfor %}` are stripped
before probing (they would appear in base but not after the appended
assistant turn and break Guard A).

`_fix_chat_template` and `_repair_string_template` now thread an
`is_sharegpt` kwarg; `_fix_chat_template` retries once with
`is_sharegpt=True` if the first probe returns None (dual-probe fallback
for dict/list callers).

The ChatML `<|im_start|>` / `<|im_end|>` hard gate in Case 2 is dropped.
`_infer_assistant_separator` is deleted.

Verified via:
  - tests/test_fix_chat_template_pr4426.py: 51/51 cells (new Llama-3,
    Gemma, Phi-3 broken-template rows all repair FIX-OK)
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5
  - tests/test_chat_template_followups.py: 18/18 (T11-T18 cover
    non-ChatML repair + probe failure modes)
  - tests/test_mistral_pr4426.py: 5/5 byte-identical
  - tests/test_qwen_pr4426.py: 14/14 byte-identical (Qwen3-Guard AST
    gate still rejects)
  - tests/hermes3_lora_pr4426.py reload: patched template ends with
    `<|im_start|>assistant\n`, inference returns sensible output.
  - temp/sim/battery.py: 79/79 followup; vs baseline: 0 regressions,
    9 improvements.
  - Spot-check probe on real stripped tokenizers (Hermes-3, Phi-4,
    Llama-3.2-1B, Gemma-3-1B): all derive the expected prefix.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address reviewer findings: variant routing, positive-gate detection, comment-safe end scan

Resolves three reviewer findings on PR #5049 (`fix/chat-template-followups`):

Finding #1 [10/10]: dict/list variants now route through
`_fix_chat_template_for_tokenizer` via a new `_VariantTokenizerProxy`
adapter. Previously the dict/list branches called `_fix_chat_template`
directly, silently bypassing the warn/strict (`UNSLOTH_STRICT_CHAT_TEMPLATE`)
contract, the `no == yes` diagnostic, broken-existing-block detection,
and `_validate_patched_template` guard. The proxy swaps
`base.chat_template` to the variant string before each
`apply_chat_template` call so tokenizer globals (`bos_token`, custom
filters, `raise_exception`) remain available; if the base is read-only
it falls back to isolated Jinja rendering.

Finding #2 [1/10]: `_has_add_generation_prompt_block` now requires the
`If` body to contain at least one `Output` node (a new
`_if_body_emits_content` helper walks descendants). This distinguishes a
real generation-prompt block from a header guard like
`{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}`
(body contains only `Assign`) which references the name but emits
nothing. Also dropped a now-redundant `"add_generation_prompt" not in
scrubbed` guard in `_fix_chat_template` Case 2 so header-guarded
templates still get repaired.

Finding #4 [1/10]: `_find_end_position` now replaces Jinja comments with
equal-length whitespace before scanning for `{% endfor %}` / `{% endif %}`
tokens. This prevents a trailing comment containing those tokens from
being picked as the real end tag. Positions in the padded string map 1:1
to positions in the original template.

Tests:
  - tests/test_chat_template_followups.py: 21/21 (T19 strict-mode
    dict variant, T20 header-guard repair, T21 comment-endfor trap
    added; T4/T5 stubs updated with a working apply_chat_template
    that routes through Jinja).
  - tests/test_fix_chat_template_pr4426.py: 51/51 cells unchanged.
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5.
  - tests/test_mistral_pr4426.py: 5/5 byte-identical.
  - tests/test_qwen_pr4426.py: 14/14 byte-identical.
  - temp/sim/battery.py: 79/79 followup; 0 regressions vs baseline.
  - Phase 3 Hermes-3 broken-LoRA reload: inference still returns
    `'The answer to the equation 2+2 is 4.'`.
  - Spot-checks on Hermes-3 / Phi-4 / Llama-3.2-1B / Gemma-3-1B real
    stripped templates: probe still derives the expected prefix.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Tighten comments in chat-template helpers

Pure comment minimization across `_find_end_position`,
`_has_add_generation_prompt_block`, `_if_body_emits_content`,
`_derive_assistant_prefix_by_render`, `_fix_chat_template` Case 2,
and `_VariantTokenizerProxy`. No behavior change; same intent,
fewer lines. All 21 follow-up tests and the 51-cell Phase 1 matrix
still pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Sandbox probe, fix is_sharegpt validator mismatch, reject negated gates

Three real bugs from the 10-agent Opus review:

1. Probe now uses `jinja2.sandbox.SandboxedEnvironment` instead of bare
   `jinja2.Environment`. The probe renders at model-load time (before
   the user calls `apply_chat_template`), so it was a new eager
   code-execution surface that the base HF tokenizer loading does not
   have. SandboxedEnvironment blocks attribute-chain exploits at
   negligible cost.

2. `_repair_string_template` now tries validation with both
   `is_sharegpt=False` and `is_sharegpt=True`. Previously, when
   `_fix_chat_template` internally fell back to the other schema via
   its dual-probe, the outer validation still used the caller's
   original `is_sharegpt` -- rendering with the wrong message keys and
   spuriously dropping a valid repair.

3. `_has_add_generation_prompt_block` now skips `If` nodes whose test
   is a `Not` expression. A negated gate like
   `{% if not add_generation_prompt %}{{ x }}{% endif %}` fires when
   agp=False, so its emitting body is not a generation block -- but the
   old code counted any Name reference regardless of polarity.

Cleanup: removed unused `self._label`, added `\r` escape in
generation-block literal, switched variant labels to `!r` formatting,
removed redundant `import os as _os`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix jinja2.sandbox import and sandbox proxy fallback

Two critical findings from the 20-reviewer pass:

1. [20/20] The proxy read-only fallback used bare `jinja2.Environment`,
   not sandboxed. All 20 reviewers independently reproduced marker-file
   creation via `cycler.__init__.__globals__['os'].system(...)` during
   `fix_chat_template()`. Fixed: fallback now uses
   `from jinja2.sandbox import SandboxedEnvironment`.

2. [14/20] The render-diff probe did `import jinja2` then referenced
   `jinja2.sandbox.SandboxedEnvironment`. `jinja2.sandbox` is a
   submodule that is NOT auto-imported by `import jinja2` on Jinja 3.1.6.
   This caused `AttributeError` (swallowed by `except Exception`),
   making the entire Case 2 repair path silently return None in a clean
   process. The 6 reviewers who saw it work had `jinja2.sandbox`
   pre-imported by an earlier module in their process. Fixed: both the
   probe and the proxy fallback now use
   `from jinja2.sandbox import SandboxedEnvironment`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-16 05:52:33 -07:00
Daniel Han
6e87bade25 Trim verbose comments in PATH helpers
Reduce inline comments from ~160 lines to ~25 across both files.
Keep one-line summaries of the "why"; drop multi-paragraph rationale
blocks that repeated information already captured in commit messages
and PR discussion.
2026-04-16 12:01:01 +00:00
Etherll
ec32ce2e82
fix: use direct registry API for PATH writes instead of SetEnvironmentVariable (#4961)
* fix: replacing SetEnvironmentVariable with direct registry API

* apply reviews

* Use CreateSubKey for HKCU\Environment

* Store PATH backup under HKCU\Software\Unsloth

* Fix $backupKey registry handle leak in PATH backup block

Wrap $backupKey operations in try/finally so the handle is closed even
if GetValue or SetValue throws. The Add-ToUserPath helper already uses
this pattern for its registry key -- the backup block was the only
place missing it.

* Isolate WM_SETTINGCHANGE broadcast from PATH write error handling

Wrap the broadcast dummy-variable calls in their own try/catch so a
broadcast failure does not mask a successful registry PATH write.
Previously, if SetEnvironmentVariable threw after SetValue already
committed the new PATH, Add-ToUserPath would return $false and the
caller would skip Refresh-SessionPath.

* PATH helper polish: venv precedence, quoted entries, raw/expanded dedup

Three small follow-ups surfaced by a 10-reviewer pass against the rebased
PR head. None fix a regression vs main; each strictly improves the new
helpers.

Refresh-SessionPath / Refresh-Environment:
- Move $env:Path to the front of the merge so an activated venv keeps
  precedence over machine/user PATH after a refresh. Pre-PR dropped
  process-only entries entirely; post-PR kept them but at the back.
- Dedup on both raw and expanded forms so %USERPROFILE%\foo and the
  already-expanded C:\Users\me\foo do not both survive.

Add-ToUserPath:
- Trim whitespace and surrounding double-quotes from each compared entry
  so quoted PATH entries like "C:\Program Files\CMake\bin" deduplicate
  against an unquoted directory of the same path.

* Back up User PATH inside Add-ToUserPath, before first mutation

Previously only studio/setup.ps1 took a one-time PATH backup, at script
top (line ~547). install.ps1 (the irm | iex entry point) had no backup,
so users who installed via that path had no recovery surface if anything
clobbered their PATH. The PR description's "one-time backup before any
modifications" promise only held for the studio installer flow.

Move the backup into Add-ToUserPath itself: just before the first actual
SetValue mutation, write the pristine raw PATH to
HKCU\Software\Unsloth\PathBackup if no backup already exists. This:

- Covers both entry points (install.ps1 and studio/setup.ps1).
- Captures the TRUE pristine PATH even when install.ps1 runs first and
  studio/setup.ps1 runs afterwards (the script-top backup in setup.ps1
  would otherwise see an already-modified PATH).
- Is idempotent: once a backup exists, subsequent calls preserve it.
- Skips when nothing would mutate (dedup match) or PATH is empty.

The script-top backup in studio/setup.ps1 is kept for defense in depth.

* Refresh PATH: venv-aware merge order

Reconcile two competing concerns about Refresh-SessionPath /
Refresh-Environment surfaced by separate review rounds:

  - venv at the back -> activated venv loses precedence to system Python
  - process at the front -> stale shims (old node, old python, etc.)
    still on $env:Path can beat a freshly installed tool

New merge order:
  1. Activated venv Scripts dir, only if $env:VIRTUAL_ENV is set
  2. Machine PATH freshly read from registry
  3. User PATH freshly read from registry
  4. Current $env:Path as fallback

This way an explicitly-activated venv keeps priority while a tool the
script just installed wins over any stale entry that was already on
the inherited shell PATH. When no venv is active, fresh registry
entries take precedence as expected.

* Append to User PATH by default, close $envKey in finally

Add-ToUserPath gains a -Position Append|Prepend parameter defaulting to
Append so installing unsloth no longer prepends the bundled venv Scripts
directory ahead of the user's existing python / pip on new shells. The
four current call sites (install.ps1 launcher, studio/setup.ps1 CMake,
nvcc, Python user Scripts) all take the Append default because each one
that needs in-session precedence already does an inline $env:Path prepend
independently. This matches rustup / cargo / nvm / pyenv / uv behavior.

Also wrap the script-top $envKey.GetValue in a try/finally so the
registry handle is released even if the read throws. Matches the pattern
already used for $backupKey five lines below.

* Prepend cmake, nvcc, Python Scripts; keep venv Scripts appended

The previous commit switched Add-ToUserPath to append by default so that
installing unsloth would not silently hijack the user's system python /
pip. That was correct for the venv Scripts dir (which contains python.exe
and pip.exe alongside unsloth.exe), but wrong for the three studio/setup
call sites. Those persist cmake, the driver-compatible nvcc, and the
Python user Scripts dir for future shells, and in all three cases an
older tool already earlier in the user PATH would keep winning after the
install finished. The nvcc case is especially load-bearing: setup selects
a driver-compatible CUDA toolkit, then llama.cpp builds against whatever
wins PATH resolution, so a stale older nvcc produces broken builds.

Pass -Position 'Prepend' explicitly at the three setup.ps1 call sites
(cmake at line 754, nvcc bin at line 1025, Python user Scripts at line
1191). None of those directories holds python.exe, so prepending them
does not re-introduce the original hijack problem. Leave the install.ps1
venv Scripts call on the default Append with a comment explaining why.

* Symmetric dedup, Prepend reorders duplicates, unsloth shim dir

Address three separate findings surfaced by review:

1. Dedup asymmetry (Gemini high-priority): the existing dedup expanded
   registry entries via ExpandEnvironmentVariables but did NOT expand the
   new directory. Passing "%USERPROFILE%\foo" when "C:\Users\me\foo" was
   already in PATH produced a duplicate. Expand both sides so the check
   is symmetric.

2. -Position Prepend no-op on existing duplicates: the dedup loop
   returned $false as soon as it saw a match, regardless of position.
   That left a late-position duplicate in place instead of moving it to
   the front, so "prepend the newly selected cmake/nvcc" did not always
   beat an older copy earlier in PATH. Partition entries into kept and
   dropped lists, then reinsert a single copy at the requested position.
   Append still returns $false on any match so user-curated orderings
   are not reshuffled. Prepend also returns $false when the only copy
   is already at position 0 so we preserve the user's casing.

3. Stop adding the venv Scripts dir to User PATH entirely. That dir
   holds python.exe and pip.exe alongside unsloth.exe, so neither
   Prepend nor Append worked: prepend hijacked the user's system python
   and pip, append made the freshly-installed unsloth.exe lose to any
   older unsloth.exe earlier on PATH. Replace the Scripts-dir PATH add
   with a dedicated shim directory that contains only unsloth.cmd, and
   prepend that dir. The shim calls the venv's unsloth.exe by absolute
   path so future pip upgrades inside the venv propagate automatically.

* Shim via hardlink, Append user Scripts, drop venv sysconfig fallback

Three follow-ups to the c0ab1ab shim commit, targeting concerns raised in
the second 20-reviewer pass:

1. Shim uses unsloth.exe (hardlink, copy fallback) instead of unsloth.cmd.
   The batch-file approach had three distinct regressions:
   - cmd.exe expanded %...% sequences inside user arguments, so prompts
     like "What does 50% mean?" got mangled before reaching the CLI
   - Git Bash / MSYS2 / POSIX-style shells on Windows do not resolve
     bare-name lookups to .cmd files, so `unsloth` stopped working there
   - Set-Content -Encoding ASCII replaced non-ASCII profile characters
     with '?', so installs under C:\Users\Jörg\... wrote a broken shim
   A hardlink (fallback: copy) of unsloth.exe is a native Windows
   executable with no shell indirection. PATHEXT picks .exe before .cmd
   in cmd.exe and PowerShell, Git Bash honors .exe natively, subprocess
   callers hit it directly, and a hardlink stays in sync with the venv
   on pip upgrades because both names point at the same inode.

2. studio/setup.ps1 Python user Scripts dir is added with default Append
   instead of -Position Prepend. That directory holds every pip-installed
   user console script (pip, pytest, huggingface-cli, and so on), not
   just unsloth, so reordering it silently changed resolution order for
   unrelated tools. The new install.ps1 shim at PATH position 0 already
   guarantees `unsloth` resolves to the freshly installed copy, so the
   Python user Scripts entry only needs to be present, not at the front.

3. The sysconfig lookup in studio/setup.ps1 no longer falls back to
   sysconfig.get_path('scripts') when the nt_user scheme dir does not
   exist. When setup.ps1 is invoked from an activated venv (a flow the
   linked issue actually hits) that fallback returns the venv's Scripts
   directory, which would then be added to the persisted User PATH and
   re-introduce the python / pip hijack the shim dir is meant to avoid.
   Stick strictly to the nt_user scheme; skip the block if it does not
   exist on disk.

* Do not crash installer when unsloth.exe shim is locked

The shim update sequence at install.ps1:1095 did a bare Remove-Item /
New-Item HardLink / Copy-Item. Under the script's $ErrorActionPreference
a locked target (most commonly 'unsloth studio' still running while the
user re-invokes the installer) turns the Remove-Item failure into a
terminating error that aborts the install with no actionable message.

The existing shim is perfectly usable in that state, so there is no
reason to abort. Wrap the whole remove/link/copy sequence in a try/catch
that logs the probable cause (Studio still running), points at the fix
(close Studio and re-run), and lets the installer finish with the old
launcher still serving the command.

Also only emit the "added unsloth launcher to PATH" step line when the
launcher was actually (re)created AND the PATH entry was newly added --
previously the message fired even when the shim refresh silently failed,
which was confusing.

* Guard shim PATH entry on existence, use NullString for broadcast delete

Two follow-ups surfaced by the latest review pass:

1. Do not add the shim directory to User PATH when the launcher was not
   actually created. Antivirus blocking unsloth.exe, a disk-full volume,
   or restrictive filesystem permissions can make both the hardlink and
   the copy fallback fail on a fresh install. In that case the existing
   sequence would report "added unsloth launcher to PATH" warnings but
   still prepend the empty $ShimDir to User PATH -- the user sees an
   install that claims success but then cannot resolve `unsloth` in a
   new shell. Gate Add-ToUserPath on Test-Path $ShimExe so the PATH
   entry is only persisted when the launcher is really there.

2. Pass [NullString]::Value instead of $null to the broadcast-delete
   call in Add-ToUserPath. On PowerShell 7.5 and later (running on .NET
   9), a bare $null going into [Environment]::SetEnvironmentVariable
   can be coerced to an empty string rather than a true .NET null,
   which sets the dummy UnslothPathRefresh_XXXXXXXX variable to "" in
   HKCU\Environment instead of deleting it. The leaked variable is
   visible in System Properties and accumulates one entry per install
   run. [NullString]::Value is a PowerShell-specific sentinel that
   crosses the interop boundary as a real null and works on both PS 5.1
   and PS 7.x. See PowerShell/PowerShell#24637 for the underlying issue.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
2026-04-16 04:49:51 -07:00
Imgyu Kim
14ab6fbfae
BUG: fix _fix_chat_template for ChatML templates missing add_generation_prompt (#4426)
Fixes #4150.

Pre-PR, `_fix_chat_template` only patched templates where a trailing `{{ ... }}` expression followed the last `{% endfor %}`. ChatML templates (Hermes, Magnum, Phi-4, etc.) that end cleanly at `{% endfor %}` with no generation-prompt block were left unchanged, so the outer `fix_chat_template` raised:

```
RuntimeError: Unsloth: The tokenizer `...` does not have a
{% if add_generation_prompt %} for generation purposes.
```

This commonly shows up when a downstream tool (LlamaFactory, Axolotl) re-serializes the tokenizer during LoRA save and strips the generation-prompt block.

This PR adds a second branch to `_fix_chat_template` that fires when:

- the content after the last `{% endfor %}` is empty modulo Jinja `{# ... #}` comments,
- the scrubbed template contains `<|im_start|>` and `<|im_end|>`,
- and the scrubbed template does not already mention `add_generation_prompt`.

The assistant-turn separator is inferred from the template itself (preferring an explicit `'<|im_start|>assistant<sep>'` literal, then the unique `message['role'] + '<sep>'` from role concatenations, then `<|im_sep|>` for Phi-4-mini mixed-separator templates, then `\n`), so Phi-4-style templates are not silently corrupted with the wrong separator.

Verified against the existing chat-template corpus:

- Hermes-3, Magnum-v2, Phi-4-mini, Phi-4 multi-sep, ChatML with trailing whitespace, ChatML with trailing Jinja comment, dot-access `message.role`, split-literal `'<|im_start|>assistant'`: all repaired with the correct assistant prefix.
- Already-fixed ChatML templates: idempotent NOP.
- Trap templates with `<|im_start|>` only inside a Jinja comment: correctly not rewritten.
- Llama-3, Gemma-3, Qwen2.5 (non-ChatML): byte-identical.
- Mistral family (5 models including Mistral-Nemo, Mistral-Small-24B, Mixtral): byte-identical, protected both by the structural guard (no ChatML tokens) and the existing name-based exemption in `load_correct_tokenizer`.
- Qwen family (14 models including Qwen2.5, Qwen3, Qwen3-Coder, QwQ, VL, Math, Qwen3-Guard): byte-identical.

End-to-end reproduction: Hermes-3 LoRA SFT, save with stripped chat_template, reload. Pre-PR code path raises the RuntimeError above. Post-PR reload loads cleanly, patches the template at load time, and `apply_chat_template(add_generation_prompt=True)` produces the correct `<|im_start|>assistant\n` prefix.
2026-04-16 00:21:29 -07:00
DoubleMathew
a4d4dfe4ac
fix Gemma4 flash attn disable (#5045)
* fix pass attn implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 17:50:48 -05:00
Daniel Han
3869fbe1cc
Bump installer minimum to 2026.4.5 (#5041) 2026-04-15 08:23:41 -07:00
Daniel Han
cdb3e752ec Update _utils.py 2026-04-15 08:06:43 -07:00
Daniel Han
ba387e2c8f Update pyproject.toml 2026-04-15 08:06:30 -07:00
Daniel Han
f0d03655e8
Studio: add folder browser modal for Custom Folders (#5035)
* Studio: add folder browser modal for Custom Folders

The Custom Folders row in the model picker currently only accepts a
typed path. On a remote-served Studio (Colab, shared workstation) that
means the user has to guess or paste the exact server-side absolute
path. A native browser folder picker can't solve this: HTML
`<input type="file" webkitdirectory>` hides the absolute path for
security, and the File System Access API (Chrome/Edge only) returns
handles rather than strings, neither of which the server can act on.

This PR adds a small in-app directory browser that lists paths on the
server and hands the chosen string back to the existing
`POST /api/models/scan-folders` flow.

## Backend

* New endpoint `GET /api/models/browse-folders`:
  * `path` query param (expands `~`, accepts relative or absolute; empty
    defaults to the user's home directory).
  * `show_hidden` boolean to include dotfiles/dotdirs.
  * Returns `{current, parent, entries[], suggestions[]}`. `parent` is
    null at the filesystem root.
  * Immediate subdirectories only (no recursion); files are never
    returned.
  * `entries[].has_models` is a cheap hint: the directory looks like it
    holds models if it is named `models--*` (HF hub cache layout) or
    one of the first 64 children is a .gguf/.safetensors/config.json/
    adapter_config.json or another `models--*` subfolder.
  * Sort order: model-bearing dirs, then plain, then hidden; case-
    insensitive alphabetical within each bucket.
  * Suggestions auto-populate from HOME, the HF cache root, and any
    already-registered scan folders, deduplicated.
  * Error surface: 404 for missing path, 400 for non-directory, 403 on
    permission errors. Auth-required like the other models routes.

* New Pydantic schemas `BrowseEntry` and `BrowseFoldersResponse` in
  `studio/backend/models/models.py`.

## Frontend

* New `FolderBrowser` component
  (`studio/frontend/src/components/assistant-ui/model-selector/folder-browser.tsx`)
  using the existing `Dialog` primitive. Features:
  * Clickable breadcrumb with a `..` row for parent navigation.
  * Quick-pick chips for the server-provided suggestions.
  * `Show hidden` checkbox.
  * In-flight fetch cancellation via AbortController so rapid
    navigation doesn't flash stale results.
  * Badges model-bearing directories inline.

* `chat-api.ts` gains `browseFolders(path?, showHidden?)` and matching
  types.

* `pickers.tsx` adds a folder-magnifier icon next to the existing `Add`
  button. Opening the browser seeds it with whatever the user has
  already typed; confirming fills the text input, leaving the existing
  validation and save flow unchanged.

## What it does NOT change

* The existing text-input flow still works; the browser is additive.
* No new permissions or escalation; the endpoint reads only directories
  the server process is already allowed to read.
* No model scanning or filesystem mutation happens from the browser
  itself -- it just returns basenames for render.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: cap folder-browser entries and expose truncated flag

Pointing the folder browser at a huge directory (``/usr/lib``,
``/proc``, or a synthetic tree with thousands of subfolders) previously
walked the whole listing and stat-probed every child via
``_looks_like_model_dir``. That is both a DoS shape for the server
process and a large-payload surprise for the client.

Introduce a hard cap of 2000 subdirectory entries and a
``truncated: bool`` field on the response. The frontend renders a small
hint below the list when it fires, prompting the user to narrow the
path. Below-cap directories are unchanged.

Verified end-to-end against the live backend with a synthetic tree of
2050 directories: response lands at 2000 entries, ``truncated=true``,
listing finishes in sub-second time (versus tens of seconds if we were
stat-storming).

* Studio: suggest LM Studio / Ollama dirs + 2-level model probe

Three improvements to the folder-browser, driven by actually dropping
an LM Studio-style install (publisher/model/weights.gguf) into the
sandbox and walking the UX:

## 1. Quick-pick chips for other local-LLM tools

`well_known_model_dirs()` (new) returns paths commonly used by
adjacent tools. Only paths that exist are returned so the UI never
shows dead chips.

* LM Studio current + legacy roots + user-configured
  `downloadsFolder` from its `settings.json` (reuses the existing
  `lmstudio_model_dirs()` helper).
* Ollama: `$OLLAMA_MODELS` env override, then `~/.ollama/models`,
  `/usr/share/ollama/.ollama/models`, and `/var/lib/ollama/.ollama/models`
  (the systemd-service install path surfaced in the upstream "where is
  everything?" issue).
* Generic user-choice locations: `~/models`, `~/Models`.

Dedup is stable across all sources.

## 2. Two-level model-bearing probe

LM Studio and Ollama both use `root/publisher/model/weights.gguf`.
The previous `has_models` heuristic only probed one level, so the
publisher dir (whose immediate children are model dirs, not weight
files) was always marked as non-model-bearing. Pulled the direct-
signal logic into `_has_direct_model_signal` and added a grandchild
probe so the classic layout is now recognised.

Still O(PROBE^2) worst-case, still returns immediately for
`models--*` names (HF cache layout) and for any direct weight file.

## 3. model_files_here hint on response body

A leaf model dir (just GGUFs, no subdirs) previously rendered as
`(empty directory)` in the modal, confusing users into thinking the
folder wasn't scannable. Added a `model_files_here` count on the
response (capped at 200) and a small hint row in the modal: `N model
files in this folder. Click "Use this folder" to scan it.`

## Verification

Simulated an LM Studio install by downloading the real 84 MB
`unsloth/SmolLM2-135M-Instruct-Q2_K.gguf` into
`~/.lmstudio/models/unsloth/SmolLM2-135M-Instruct-GGUF/`. Confirmed
end-to-end:

* Home listing suggests `~/.lmstudio/models` as a chip.
* Browsing `~/.lmstudio/models` flags `unsloth` (publisher) as
  `has_models=true` via the 2-level probe.
* Browsing the publisher flags `SmolLM2-135M-Instruct-GGUF` (model
  dir) as `has_models=true`.
* Browsing the model dir returns empty entries but
  `model_files_here=1`, and the frontend renders a hint telling the
  user it is a valid target.

* Studio: one-click scan-folder add + prominent remove + plain search icon

Three small Custom Folders UX fixes after real-use walkthrough:

* **One-click add from the folder browser**. Confirming `Use this
  folder` now submits the path directly to
  `POST /api/models/scan-folders` instead of just populating the text
  input. `handleAddFolder` takes an optional explicit path so the
  submit lands in the same tick as `setFolderInput`, avoiding a
  state-flush race. The typed-path + `Add` button flow is unchanged.

* **Prominent remove X on scan folders**. The per-folder delete
  button was `text-muted-foreground/40` and hidden entirely on
  desktop until hovered (`md:opacity-0 md:group-hover:opacity-100`).
  Dropped the hover-only cloak, bumped color to `text-foreground/70`,
  added a red hover/focus background, and sized the icon up from
  `size-2.5` to `size-3`. Always visible on every viewport.

* **Plain search icon for the Browse button**. `FolderSearchIcon`
  replaced with `Search01Icon` so it reads as a simple "find a
  folder" action alongside the existing `Add01Icon`.

* Studio: align Custom Folders + and X buttons on the same right edge

The Custom Folders header used `px-2.5` with a `p-0.5` icon button,
while each folder row used `px-3` with a `p-1` button. That put the
X icon 4px further from the right edge than the +. Normalised both
rows to `px-2.5` with `p-1` so the two icons share a column.

* Studio: empty-state button opens the folder browser directly

The first-run empty state for Custom Folders was a text link reading
"+ Add a folder to scan for local models" whose click toggled the
text input. That's the wrong default: a user hitting the empty state
usually doesn't know what absolute path to type, which is exactly
what the folder browser is for.

* Reword to "Browse for a models folder" with a search-icon
  affordance so the label matches what the click does.
* Click opens the folder browser modal directly. The typed-path +
  Add button flow is still available via the + icon in the
  section header, so users who know their path keep that option.
* Slightly bump the muted foreground opacity (70 -> hover:foreground)
  so the button reads as a primary empty-state action rather than a
  throwaway hint.

* Studio: Custom Folders header gets a dedicated search + add button pair

The Custom Folders section header had a single toggle button that
flipped between + and X. That put the folder-browser entry point
behind the separate empty-state link. Cleaner layout: two buttons in
the header, search first, then add.

* Search icon (left) opens the folder browser modal directly.
* Plus icon (right) toggles the text-path input (unchanged).
* The first-run empty-state link is removed -- the two header icons
  cover both flows on every state.

Both buttons share the same padding / icon size so they line up with
each other and with the per-folder remove X.

* Studio: sandbox folder browser + bound caps + UX recoveries

PR review fixes for the Custom Folders folder browser. Closes the
high-severity CodeQL path-traversal alert and addresses the codex /
gemini P2 findings.

Backend (studio/backend/routes/models.py):

* New _build_browse_allowlist + _is_path_inside_allowlist sandbox.
  browse_folders now refuses any target that doesn't resolve under
  HOME, HF cache, Studio dirs, registered scan folders, or the
  well-known third-party model dirs. realpath() is used so symlink
  traversal cannot escape the sandbox. Also gates the parent crumb
  so the up-row hides instead of 403'ing.
* _BROWSE_ENTRY_CAP now bounds *visited* iterdir entries, not
  *appended* entries. Dirs full of files (or hidden subdirs when
  show_hidden is False) used to defeat the cap.
* _count_model_files gets the same visited-count fix.
* PermissionError no longer swallowed silently inside the
  enumeration / counter loops -- now logged at debug.

Frontend (folder-browser.tsx, pickers.tsx, chat-api.ts):

* splitBreadcrumb stops mangling literal backslashes inside POSIX
  filenames; only Windows-style absolute paths trigger separator
  normalization. The Windows drive crumb value is now C:/ (drive
  root) instead of C: (drive-relative CWD-on-C).
* browseFolders accepts and forwards an AbortSignal so cancelled
  navigations actually cancel the in-flight backend enumeration.
* On initial-path fetch error, FolderBrowser now falls back to HOME
  instead of leaving the modal as an empty dead end.
* When the auto-add path (one-click "Use this folder") fails, the
  failure now surfaces via toast in addition to the inline
  paragraph (which is hidden when the typed-input panel is closed).

* Studio: rebuild browse target from trusted root for CodeQL clean dataflow

CodeQL's py/path-injection rule kept flagging the post-validation
filesystem operations because the sandbox check lived inside a
helper function (_is_path_inside_allowlist) and CodeQL only does
intra-procedural taint tracking by default. The user-derived
``target`` was still flowing into ``target.exists`` /
``target.is_dir`` / ``target.iterdir``.

The fix: after resolving the user-supplied ``candidate_path``,
locate the matching trusted root from the allowlist and rebuild
``target`` by appending each individually-validated segment to
that trusted root. Each segment is rejected if it isn't a single
safe path component (no separators, no ``..``, no empty/dot).
The downstream filesystem ops now operate on a Path constructed
entirely from ``allowed_roots`` (trusted) plus those validated
segments, so CodeQL's dataflow no longer sees a tainted source.

Behavior is unchanged for all valid inputs -- only the
construction of ``target`` is restructured. Live + unit tests
all pass (58 selected, 7 deselected for Playwright env).

* Studio: walk browse paths from trusted roots for CodeQL

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@h100-8-cheapest.us-east5-a.c.unsloth.internal>
2026-04-15 08:04:33 -07:00
Roland Tannous
800ddc95f8
Re-apply #4939: updated models template mappers (#4950)
* Reapply "updated models template mappers. added lfm2.5vl450m to transformers 5…" (#4945)

This reverts commit 33503ea248.

* Add missing gemma-4-31B-it bnb-4bit mapper entry and LFM2.5 upstream namespace for PR #4950

- Add unsloth/gemma-4-31B-it-unsloth-bnb-4bit to __INT_TO_FLOAT_MAPPER so
  the int-to-float resolution works for this model (already listed in
  TEMPLATE_TO_MODEL_MAPPER but had no mapper entry).
- Add LiquidAI/LFM2.5-1.2B-Instruct to lfm-2.5 TEMPLATE_TO_MODEL_MAPPER
  entry so the canonical upstream namespace is mapped consistently with lfm-2.

* Add missing gemma-4-31B-it bnb-4bit Ollama mapping and lfm-2.5 chat template alias

- Add unsloth/gemma-4-31B-it-unsloth-bnb-4bit to OLLAMA_TEMPLATE_TO_MODEL_MAPPER
  so Ollama export works for this model (E2B-it and E4B-it bnb-4bit variants were
  already present, 31B-it was inconsistently omitted)
- Register CHAT_TEMPLATES["lfm-2.5"] as alias of the lfm-2 template to prevent
  KeyError when Studio resolves LFM2.5 models through MODEL_TO_TEMPLATE_MAPPER

* Add missing LFM2 bnb-4bit INT_TO_FLOAT_MAPPER entry

unsloth/LFM2-1.2B-unsloth-bnb-4bit is referenced in model_mappings.py
but had no mapper.py entry, so model resolution would fail when users
load that variant with load_in_4bit=False or when the float name is
used with load_in_4bit=True.

* Fix review findings for PR #16

1. ollama_template_mappers.py: Restore dropped Gemma-4 base model IDs
   (E2B, E4B, 31B, 26B-A4B) and add missing google/ upstream IDs to
   the gemma4 Ollama mapper for consistency with other gemma entries.

2. mapper.py: Remove self-mapping non-bnb-4bit entries from
   __INT_TO_FLOAT_MAPPER that were polluting FLOAT_TO_INT_MAPPER with
   lowercase 16-bit names, causing load_in_4bit=True to return bad
   model names. Add direct MAP_TO_UNSLOTH_16bit entries to preserve
   the google->unsloth 16-bit redirects.

3. mapper.py: Add LFM2.5 MAP_TO_UNSLOTH_16bit redirect so
   LiquidAI/LFM2.5-1.2B-Instruct resolves to its unsloth mirror.

* Add review tests for PR #4950

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove top-level test files

These test_*.py files were added at the repo root rather than under tests/.
Removing them from this PR; the production mapper changes remain.

* Add gemma-4-26B-A4B-it mapping

Adds unsloth/gemma-4-26B-A4B-it to __INT_TO_FLOAT_MAPPER as a 2-tuple so
google/gemma-4-26B-A4B-it routes to unsloth/gemma-4-26B-A4B-it across
INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, and MAP_TO_UNSLOTH_16bit.

The 26B-A4B (MoE) model has no bnb-4bit variant, so the key uses the
plain unsloth name rather than the -unsloth-bnb-4bit suffix.

Removes the now-redundant standalone _add_with_lower call for the -it
variant; the 16bit mapping is registered via the dict loop.

* Add unsloth-bnb-4bit mappings for gemma-4 base (non-it) models

Adds E2B, E4B, 31B base unsloth-bnb-4bit entries to __INT_TO_FLOAT_MAPPER.
The 26B-A4B (MoE) base has no bnb-4bit variant on HF, so it stays on the
standalone _add_with_lower line for the 16bit-only routing.

Removes the redundant _add_with_lower lines for E2B, E4B, 31B base since
the dict loop now registers the same google->unsloth route through the
2-tuple entries, plus full FLOAT_TO_INT and INT_TO_FLOAT coverage.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 07:52:12 -07:00
Avaya Aggarwal
7c5464ad71
feat: Add cactus QAT scheme support (#4679)
* feat: Add cactus QAT scheme support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test(qat): add tests for cactus QAT scheme and fix missing import

* Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter

- Drop the broken `from torchao.dtypes import MappingType` import. `MappingType`
  lives in `torchao.quantization` (and `torchao.quantization.quant_primitives`);
  it is not exported from `torchao.dtypes` in any supported torchao release
  (verified on 0.14, 0.16, 0.17). The previous code raised `ImportError` on
  every cactus call and was masked as a misleading 'torchao not found' error.
- Since `IntxWeightOnlyConfig` already defaults `mapping_type` to
  `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove the
  import. Behavior is unchanged.
- Introduce a named `group_size = 32` constant (matches the int4 / fp8-int4
  pattern in the surrounding branches) and add a `% group_size == 0`
  divisibility guard to the filter. `PerGroup(32)` requires
  `in_features % 32 == 0` at `quantize_()` time, otherwise torchao raises
  `ValueError: in_features (N) % group_size (32) must be == 0`. The old
  `in_features >= 32` filter would admit non-aligned widths (e.g. 33, 48, 65,
  127) and crash `_prepare_model_for_qat` for those shapes.

* Warn when cactus QAT skips non-divisible Linear layers

Multiple reviewers flagged that the divisibility guard added in the
previous commit can silently leave Linear layers in full precision when
their in_features is not a multiple of 32. For currently supported
Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is
already a multiple of 32/64/128 so this never triggers, but surfacing
the coverage gap is cheap and avoids users assuming 100% QAT coverage
when they bring a custom model with unusual shapes.

Emit a UserWarning listing up to the first 8 skipped layers whenever
the cactus filter excludes any Linear due to the modulo guard. This
keeps the lenient silent-skip behavior (consistent with int4 /
fp8-int4), but stops making it silent.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-15 07:40:03 -07:00
Avaya Aggarwal
f18e9dddf0
feat: Add support for OLMo-3 model (#4678)
* feat: Add support for OLMo-3 model in mapping and tests

* Update unsloth/models/mapper.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update tests/test_get_model_name.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix casing, add Think variants, and align version gate for OLMo-3 PR 4678

Mapper: switch slugs from OLMo-3 to canonical Olmo-3 mixed case, drop the
non-existent unsloth/Olmo-3-7B-Instruct-bnb-4bit dead alias, and add the
already-published Olmo-3-7B-Think and Olmo-3-32B-Think Unsloth mirrors.

Loader: change the olmo3 transformers version gate from Version("4.57.0")
to Version("4.57.0.dev0") so nightly/source builds that already contain
olmo3 are not blocked, matching the OLMo-2, Gemma 3 and Cohere patterns.

* Use canonical Olmo-3 casing and cover Think variants in OLMo-3 tests

Mirrors the mapper.py fixes on pr-4678-code: HuggingFace canonical slugs
for the OLMo-3 family use mixed-case Olmo-3 (not OLMo-3 like OLMo-2), and
Unsloth already hosts Olmo-3-7B-Think and Olmo-3-32B-Think mirrors, so
the resolution matrix now covers all three published Olmo-3 families.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-15 07:39:11 -07:00
Daniel Han
c3cd890357
Studio: refresh Downloaded GGUF list and recurse into variant subdirs (#5032)
* Studio: refresh Downloaded GGUF list and recurse into variant subdirs

Two fixes for the model picker's "Downloaded" section.

Frontend (`pickers.tsx`):
* `HubModelPicker`'s mount effect short-circuited the cached-gguf and
  cached-models refetch whenever the module-level cache already had
  entries (`if (alreadyCached) return;`). After downloading a new repo
  in the same session, reopening the picker rendered the stale cache
  and the new repo never appeared in "Downloaded" until a full page
  reload. The early return is removed so the lists are always refreshed
  on mount; the module cache still drives the initial render so there
  is no spinner flash when we already had data.

Backend (`utils/models/model_config.py`):
* `list_local_gguf_variants` and `_find_local_gguf_by_variant` used a
  non-recursive `Path.glob("*.gguf")`. Some HF GGUF repos (e.g.
  `unsloth/gemma-4-26B-A4B-it-GGUF`) place the largest quants under a
  variant-named subdirectory such as `BF16/...gguf`, which the
  top-level glob missed. Both helpers now use `rglob` and the variant
  filename is stored as a path relative to the scan root so the
  locator can still find the file.

The flat-layout case (variants directly in the snapshot root) is
unchanged: verified against `unsloth/gemma-4-E2B-it-GGUF` which still
returns its UD-Q4_K_XL variant correctly.

* Studio: emit posix-style relative filenames for local GGUF subdirs

`list_local_gguf_variants` was doing `str(f.relative_to(p))`, which on
Windows produces backslash-separated paths like `BF16\foo.gguf`. The
remote `list_gguf_variants` (HF API path) always returns forward-slash
filenames such as `BF16/foo.gguf`, so the two would diverge on Windows.

Switch to `.as_posix()` so the local and remote variant filenames stay
identical across Linux, macOS, and Windows. Verified by simulating with
`PureWindowsPath` in the test suite.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: detect mmproj at snapshot root for nested-variant layouts

When _find_local_gguf_by_variant returns a weight file inside a
quant-named subdir (e.g. snapshot/BF16/foo.gguf), detect_mmproj_file
was scanning only the immediate parent and missing the mmproj file
sitting at the snapshot root. The model was then loaded without
--mmproj, silently breaking vision support for repos that ship
nested variants.

detect_mmproj_file now takes an optional search_root and walks up
from the weight file to that root, in order, so the mmproj at the
snapshot root is picked up. Sibling quant subdirs are not scanned,
so an unrelated variant's mmproj does not leak in.

Also apply the suggested micro-optimization on relative_to in
list_local_gguf_variants -- only build the posix path when storing
the first file for a quant.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 07:34:42 -07:00
Daniel Han
156f3fc4b0
Gate trl disable_gradient_checkpointing patch warning on UNSLOTH_ENABLE_LOGGING (#5038)
The "Patched trl.models.utils.disable_gradient_checkpointing with a no-op"
warning fires once on every Unsloth import, including from notebooks where
the user did not opt into verbose logging. It is a routine integration
patch, not an anomaly the user needs to know about. Gate it on
UNSLOTH_ENABLE_LOGGING=1 like other diagnostic notices.
2026-04-15 07:33:48 -07:00
jonahsamost
777e1bd0ac
fix (#4887) 2026-04-15 07:21:03 -07:00
Daniel Han
1a4ca5eca8
Fix grad-accum accepts_loss_kwargs detection for vision wrappers (#5036)
* Fix grad-accum model_accepts_loss_kwargs detection for vision wrappers

Replace the source-string rewrite of Trainer.__init__ with an instance-level
accepts_loss_kwargs shadow applied on the loaded model. Covers:

  1. Unsloth-compiled forward -> True, so HF Trainer does not double-scale
     on top of unsloth_fixed_cross_entropy's num_items_in_batch division.
  2. Stock forward on a conditional-generation wrapper (Gemma3n, Gemma3
     pre-4.57, Qwen-VL family, etc.) where the outer class has no
     accepts_loss_kwargs but the inner .model declares False -> False.
     This is the case that reproduces issue #4982 under trust_remote_code
     or UNSLOTH_COMPILE_DISABLE, where the previous fix's outer-attr
     check walked past the inner model and fell through to signature
     inspection.
  3. Text LMs without any explicit accepts_loss_kwargs -> leave HF default.

The previous .replace()-based patch silently no-ops on transformers 4.48
through 4.52 (variable named model, not unwrapped_model) and is fragile
against any upstream reformat. The new helper walks the PEFT / HF wrapper
chain, finds the first class that declares accepts_loss_kwargs on its own
class dict (type(m).__dict__, not hasattr, to avoid PEFT __getattr__
forwarding), and setattr-shadows that value at every wrapper level so
HF Trainer's hasattr(unwrapped_model, ...) check picks it up at whichever
level accelerate.unwrap_model returns.

Also adds an unconditional post-init clamp of
accelerator.gradient_accumulation_steps = 1 to work around the
transformers 5.0 through 5.5 GradientAccumulationPlugin regression that
makes accelerator.backward divide loss by GA on top of training_step's
own /GA division. Fixed upstream in 5.6.0.dev0; no-op on 4.x and 5.6+.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Trim comments

* Address review: cover PEFT-after-load and custom compile location

Two review findings from 3/20 reviewers:

1. [3 of 20 reviewers] apply_accepts_loss_kwargs_fix was called from the
   loaders before get_peft_model wraps the base model, so on transformers
   4.48-4.52 (which does hasattr on the outer model) the instance shadow
   on the base model was lost after PEFT wrapping. Fix: also call it from
   the wrapped Trainer.__init__ so it runs on whatever model the user
   actually hands to Trainer, which is always the final wrapped form.

2. [1 of 20 reviewers] _forward_is_unsloth_compiled hard-coded the
   substrings "unsloth_compiled" / "unsloth_cache" in the co_filename
   check, which misclassifies compiled forwards when
   UNSLOTH_COMPILE_LOCATION is set to a custom directory. Fix: new
   _unsloth_compile_cache_leaves helper that reads the env var and
   matches the basename against path components, honoring both the
   default and any user override.

Verified locally:
- PEFT-after-load simulation: HF's hasattr(peft, "accepts_loss_kwargs")
  now returns True after our init wrapper runs, and value resolves to
  False on Gemma3n-style inner wrappers.
- Custom UNSLOTH_COMPILE_LOCATION simulation: compiled detection returns
  True for /tmp/my_custom_cache/compiled.py when the env var is set.
- End-to-end Gemma-3 270m + LoRA SFT unchanged: loss 4.9626, grad-norm
  matches prior run, all 4 wrapper levels now carry the shadowed attr.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 06:59:36 -07:00
Daniel Han
1ccfd2e0a5
fix(rocm): tighten gfx regex to ignore generic ISA lines (#5033)
* fix(rocm): tighten gfx regex to ignore generic ISA lines

ROCm 6.1+ rocminfo emits generic ISA names such as
"amdgcn-amd-amdhsa--gfx11-generic" and "amdgcn-amd-amdhsa--gfx9-4-generic"
alongside the real GPU name. The previous `gfx[1-9]` regex used in
`_has_rocm_gpu` matched both, so a host with only a generic ISA entry
would be reported as having a usable AMD GPU.

Tighten the pattern to `gfx[1-9][0-9a-z]{2,3}` so only real gfx ids
match. This covers every documented target from GFX6 (gfx600) through
GFX12 (gfx1201), including letter-suffixed ids like gfx90a (MI250 /
MI250X) and gfx90c. Documented generic ISA names always have 1 or 2
digits before the dash and no longer match.

Applied to both `studio/install_python_stack.py` and
`studio/install_llama_prebuilt.py` so the two detection paths agree.

Co-authored-by: Martin Hoyer <mhoyer@redhat.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Martin Hoyer <mhoyer@redhat.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 05:24:41 -07:00
Daniel Han
b7a8ff2833
Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027) (#5034)
* Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027)

FastLanguageModel.from_pretrained(..., num_labels=N) crashed with
"NotImplementedError: normal_kernel_cuda not implemented for 'Byte'" on
pre-quantized bnb 4-bit checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit)
when running on transformers 5.x.

Two pieces were needed to close this out:

1. unsloth_zoo PR: add "score", "classifier", "qa_outputs" to
   SKIP_QUANTIZATION_MODULES so replace_with_bnb_linear leaves task
   heads in the compute dtype.

2. This commit: for pre-quantized checkpoints, transformers reads
   llm_int8_skip_modules from the quantization_config baked into
   config.json and ignores the runtime BitsAndBytesConfig we pass via
   kwargs. Unsloth must merge its skip list into
   model_config.quantization_config.llm_int8_skip_modules before the
   from_pretrained call, or the checkpoint's frozen list
   (e.g. ["lm_head", "multi_modal_projector", "merger",
   "modality_projection"]) wins and the `score` head gets converted to
   Linear4bit with uint8 storage, then _init_weights calls normal_ on
   uint8 and crashes.

Also add a defensive post-load cast on the task head to guard against
any residual path that ends up with a non-floating head dtype.

Verified on transformers 4.57.6 and 5.5.0 with:
- unsloth/Qwen3-4B-bnb-4bit + num_labels=3
- unsloth/Qwen3-4B (non-bnb repo, load_in_4bit=True)
- unsloth/Llama-3.2-1B-Instruct + num_labels=3
- unsloth/ModernBERT-large classifier head (bert_classification notebook)
- Regression: causal LM path unchanged, backbone still 4-bit
- 3-step SFT on num_labels=3 confirms gradient flow and weight updates
  on score.weight

Fixes unslothai/unsloth#5027

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 05:16:33 -07:00
David Solanas Sanz
1fcb2502cf
fix: prevent offline freeze by fixing stats retry and forwarding local_files_only (#5016)
Fixes #2393.

- `_utils.py`: `has_internet()` now respects `HF_HUB_OFFLINE` with truthy variant parsing in addition to `TRANSFORMERS_OFFLINE`.
- `_utils.py`: replace uncontrolled `except Exception: stats_check()` retry (which had no time limit and could freeze on Kaggle offline mode) with a logged skip.
- `loader.py`: forward `local_files_only` from kwargs into all `AutoConfig.from_pretrained` and `PeftConfig.from_pretrained` probes in `FastLanguageModel.from_pretrained` and `FastModel.from_pretrained`, including the PEFT base-model reload paths.
2026-04-15 04:51:31 -07:00
Lee Jackson
f9ef639dde
Studio: support GGUF variant selection for non-suffixed repos (#5023)
* fix: support GGUF variant selection for non-suffixed repos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: harden GGUF detection across cached models and picker flows

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore: use shared GGUF picker helper for search rows

* fix: avoid mixed cache duplication and preserve GGUF fallback detection

* fix: unify GGUF cache matching and merge picker hints

* fix: normalize local GGUF matching across picker and model config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: robust cached-gguf classification + hint-aware click routing

- _repo_gguf_size_bytes: treat size_on_disk=None as 0 and dedupe fallback
  by commit_hash so partial/interrupted downloads don't TypeError out of
  sum() and wipe the entire cached list.
- list_cached_gguf / list_cached_models: narrow per-repo try/except so
  one malformed repo no longer poisons the whole response.
- handleModelClick: route through isKnownGgufRepo instead of the
  suffix-only isGgufRepo, so non-suffixed GGUF repos still open the
  variant expander from every call site.
- Replace the modelIsGgufById/resultIsGgufById Maps with Sets of known
  GGUF ids to stop conflating "no hint" with "known not-GGUF".
- Make HfModelResult.isGguf required (it is always set in makeMapModel).
- Add regression tests for the None size case, mixed-repo inclusion in
  cached-gguf, and per-repo error isolation.

* fix: exclude mmproj from GGUF classification and case-normalize hint lookups

- _repo_gguf_size_bytes now filters mmproj vision-adapter files so
  safetensors+mmproj.gguf repos stay on the cached-models path and
  non-GGUF rows no longer show zero pickable variants. A vision-capable
  GGUF repo (main weight + mmproj adapter) still classifies as GGUF and
  reports the main weight size.
- modelGgufIds / resultGgufIds now key on lowercased ids and
  isKnownGgufRepo lowercases its lookup, so store and HF-search ids
  that differ only by casing still match the same GGUF hint.
- New regression tests: mmproj-only repo excluded from cached-gguf,
  same repo included in cached-models, vision-capable repo still
  classified as GGUF with correct size.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-15 15:32:01 +04:00
Roland Tannous
13928b5f0e
Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var (#5024)
* Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var

When set, UNSLOTH_PYTORCH_MIRROR overrides the default
https://download.pytorch.org/whl base URL in all four install scripts
(install.sh, install.ps1, studio/setup.ps1, studio/install_python_stack.py).
When unset or empty, the official URL is used. This lets users behind
corporate proxies or in regions with poor connectivity to pytorch.org
point at a local mirror without patching scripts.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add pytest for UNSLOTH_PYTORCH_MIRROR in install_python_stack.py

Tests that _PYTORCH_WHL_BASE picks up the env var when set, falls back
to the official URL when unset or empty, and preserves the value as-is
(including trailing slashes).

* Remove stale test assertions for missing install.sh messages

* Fix GPU mocking in test_get_torch_index_url.sh

Extract _has_usable_nvidia_gpu and _has_amd_rocm_gpu alongside
get_torch_index_url so the GPU-presence checks work in tests.
Add -L flag handling to mock nvidia-smi so it passes the GPU listing
check. All 26 tests now pass on CPU-only machines.

* Strip trailing slash from UNSLOTH_PYTORCH_MIRROR to avoid double-slash URLs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-15 11:39:11 +04:00
Datta Nimmaturi
826c98f3c0
[moe][gemma4] Target MoE for gemma4 (#4913)
* Target MoE for gemma4

* refactor attention impl determine

* Revert "refactor attention impl determine"

This reverts commit 888fca08110a9a74278dc1ebc14d0da043bbd11d.

* Remove attention policy changes from gemma4 MoE fix
2026-04-14 16:53:07 -05:00
Daniel Han
5aa8c15246
Studio: hard-stop at n_ctx with a 'Context limit reached' toast (#5021)
* Studio: hard-stop at n_ctx with a dedicated 'Context limit reached' toast

llama-server's default behavior when the KV cache fills is to silently
drop the oldest non-``n_keep`` tokens and keep generating. The UI has
no way to tell the user that earlier turns were evicted -- they just
see degraded continuity and a confusing ``5,361 / 4,096`` on the
context usage bar.

Launch llama-server with ``--no-context-shift`` so it returns a clean
error once the request would exceed ``n_ctx``. In the chat adapter,
catch the error, identify it as a context-limit error via
``isContextLimitError()``, and surface a dedicated toast that names
the exact control to adjust: the ``Context Length`` field in the chat
Settings panel.

Also add a lightweight tooltip hint on ``ContextUsageBar`` when usage
crosses 85%, so users see the "raise Context Length in Settings"
suggestion before they hit the hard stop.

Tests:

  * ``test_llama_cpp_no_context_shift.py`` pins the ``--no-context-shift``
    flag in the static launch-command template, and pins it inside the
    unconditional ``cmd = [ ... ]`` block so a future refactor can't
    hide it behind a branch.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Shorten --no-context-shift comment to 1 line

* Match backend _friendly_error rewrite in isContextLimitError

Codex review on PR caught that ``backend/routes/inference.py::_friendly_error``
rewrites the raw llama-server text
  "request (X tokens) exceeds the available context size (Y tokens)"
into
  "Message too long: X tokens exceeds the Y-token context window. ..."
on the main streaming GGUF path. The heuristic only looked for
"context size" / "exceeds the available context" / "context shift",
none of which survive the rewrite, so the new "Context limit reached"
toast would never fire for the most common case. Add matches for
"message too long" and "context window" so both wordings hit.

Also addresses Gemini feedback on the launch-flag test:
  * Use ``inspect.getsource(LlamaCppBackend.load_model)`` instead of
    reading ``__file__`` directly; scopes the assertions to the
    function that actually launches llama-server.
  * Replace the hardcoded ``"            ]"`` indent search with a
    line-at-a-time scan for a line that is just ``]``, so the test
    survives reformatting.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 10:58:20 -07:00
Daniel Han
5861a7ce15
Studio: split model-load progress label across two rows (#5020)
* Studio: split model-load progress label across two rows

The chat flow and training overlay both compose a progress label like
"112.6 of 122.3 GB • 331.0 MB/s • 30s left" and render it next to the
percent badge in a single flex row. Once the rate + ETA part shows up,
the label outgrows the row width and wraps mid-phrase, orphaning the
percent ("19 left %") onto a second ragged line.

Fix in model-load-status.tsx: split the label on the first " • " into
a primary (size) chunk that stays on row 1 with the percent, and a
secondary (rate/ETA) chunk that renders on its own muted row below.
Labels without a bullet (e.g. "22.8 GB downloaded") collapse cleanly
to one row. The inline-status variant keeps only the primary and
surfaces the full label via the tooltip.

Also extracts the rate/ETA math out of useTransferStats into a pure
``transfer-stats.ts`` module (appendSample + computeTransferStats) so
it can be reasoned about and tested without React. The hook is now a
thin wrapper that feeds sample history through the pure functions.

Backend: adds two companion test files for load_progress():

  * test_llama_cpp_load_progress_matrix.py (21 tests) -- platform
    matrix (Linux /proc, macOS/Windows absence), VmRSS parsing
    variants (tab/space/missing/malformed), filesystem edges (HF-cache
    symlinks, broken symlinks, nonexistent paths, relative paths),
    shard aggregation (partial multi-shard, two series in same dir,
    mmproj-* exclusion, single-file), lifecycle races, concurrent
    sampling (10 threads x 50 iters against real /proc), fraction
    bounds.
  * test_llama_cpp_load_progress_live.py (5 tests) -- no-mock live
    integration: real subprocess allocating 100 MB to match VmRSS,
    real ready phase, real dead-pid degradation, real shard
    aggregation, repeated polling. Skipped on non-Linux.

Both complement the existing test_llama_cpp_load_progress.py.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Hoist splitProgressLabel out of JSX IIFE (review feedback)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 10:58:16 -07:00
Eda Z
5b8dbdc3c2
Fix bitsandbytes ROCm install by using pip instead of uv (#4966)
* Fix bitsandbytes ROCm install by using pip instead of uv

* Also use pip for PyPI fallback path in _install_bnb_rocm

The original fix correctly switched the pre-release wheel install from
uv to pip, but left the PyPI fallback path on uv. If uv breaks bnb
on ROCm, the fallback would hit the same issue. Move pip bootstrap
before the branch so both paths use pip consistently.

* Harden pip bootstrap: try ensurepip first, warn on failure

- Try ensurepip --upgrade before falling back to uv pip install pip.
  ensurepip works offline and does not need PyPI, making the bootstrap
  robust when the network or index is unavailable.
- If both ensurepip and uv fail, emit a visible warning instead of
  silently swallowing the error (which previously led to a cryptic
  "No module named pip" downstream).
- Use run_maybe_quiet so --verbose users see bootstrap output.
- Update comment to document the actual root cause: uv rejects the
  wheel because filename version and metadata version disagree.

* Add --isolated to pip install calls in _install_bnb_rocm

uv pip install ignores pip.conf and PIP_* env vars, but python -m pip
reads them. Without --isolated, users with PIP_INDEX_URL pointing to a
private mirror that does not carry bitsandbytes would see the PyPI
fallback fail where it previously worked under uv. --isolated restores
parity with the old uv behavior.

* Drop --isolated from PyPI fallback in _install_bnb_rocm

--isolated suppresses PIP_INDEX_URL, PIP_EXTRA_INDEX_URL, and pip.conf.
This is correct for the pre-release path (hardcoded GitHub URL, no index
consulted), but breaks the PyPI fallback for users in corporate or
air-gapped environments whose only route to bitsandbytes is a private
mirror configured via those mechanisms. Keep --isolated on the direct-URL
pre-release install; drop it from the index-dependent fallback.

* Drop --isolated from pre-release pip install, fix warning wording

--isolated suppresses pip.conf cert/proxy/CA settings in addition to
index config. For the direct GitHub URL, index config is irrelevant but
cert/proxy settings matter in corporate SSL-inspection environments.
Without this fix, users with pip.conf-based CA bundles get a TLS error
on the pre-release download and silently fall back to the broken PyPI
version -- the exact outcome the PR is trying to prevent.

Also fix the fallback warning: "unreachable" is too specific since the
pre-release install can fail for reasons other than network reachability.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-14 10:23:40 -07:00
pre-commit-ci[bot]
a0b9d14081
[pre-commit.ci] pre-commit autoupdate (#5004)
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.9 → v0.15.10](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.9...v0.15.10)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 09:49:18 -07:00
Daniel Han
bb14ab144a
Studio: live model-load progress + rate/ETA on download and load (#5017)
* Studio: live model-load progress + rate/ETA on download and load

Two UX fixes for the opaque multi-minute wait between clicking Load
and being able to chat, visible most clearly on large MoE GGUFs like
MiniMax-M2.7 (131 GB of weights on a 97 GB GPU):

1. **Model-load phase is now observable.** The existing chat flow
   transitions the toast to "Starting model..." as soon as the
   download hits 100%, then shows a spinner with no other feedback
   until llama-server reports healthy. For a 130 GB model that spinner
   freezes for five-plus minutes while the kernel pages shards into
   the page cache. A new `GET /api/inference/load-progress` endpoint
   samples `/proc/<pid>/status VmRSS` on the llama-server subprocess
   against the sum of shard file sizes on disk, so the UI can render
   a real bar plus rate / ETA during that window.

2. **Rate and ETA on downloads and loads.** Both the chat toast and
   the training-start overlay used to show a static pair of numbers
   (for example "15.4 of 140.8 GB"). A rolling 15-second window over
   the existing byte-series now surfaces "85.3 MB/s, 24m 23s left"
   beside that pair. The estimator is shared between the download
   and load phases so the numbers don't reset when the phase flips.

Also fixes a pre-existing assignment bug uncovered while wiring this
up: `load_model` was storing the caller's `gguf_path` kwarg into
`self._gguf_path`, which is `None` on the HF-download code path. The
resolved on-disk path (`model_path`) is what llama-server actually
mmaps; downstream consumers need that. No existing reader used
`_gguf_path`, so this is a correctness fix for the new endpoint.

- Backend: `LlamaCppBackend.load_progress()`, `GET /api/inference/load-progress`, `LoadProgressResponse` Pydantic model.
- Frontend: `useTransferStats` hook, `formatRate` / `formatEta` helpers, `getLoadProgress` client, rewired chat toast and `DownloadRow` in the training overlay.
- Tests: `studio/backend/tests/test_llama_cpp_load_progress.py` covers empty states, mmap phase, ready phase, sharded total aggregation, missing gguf_path, and unreadable /proc (7 cases). `tsc -b` and `vite build` on the frontend both clean.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 09:46:22 -07:00
Roland Tannous
514bb3a20e
studio: pin peft to 0.18.1 to fix export subprocess issues (#5015)
* studio: pin peft to 0.18.1 to fix export subprocess issues

peft 0.19.0 causes export subprocess shutdown failures in Studio.
Reverting to 0.18.1 resolves the issue.

* studio: move peft pin to extras-no-deps to prevent torch upgrade

Installing peft via overrides.txt would resolve its deps and pull in
torch>=0.11.0, breaking other pinned packages. Moving the pin to
extras-no-deps.txt ensures --no-deps is used during install.
2026-04-14 20:16:30 +04:00
Datta Nimmaturi
4328d0b4f6
Fix num_items_in_batch GA for Gemma4 (#4998)
* Fix num_items_in_batch GA for Gemma4

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 09:01:10 -07:00
Daniel Han
7252410ccc
studio: stream export worker output into the export dialog (#4897)
* studio: stream export worker output into the export dialog

The Export Model dialog only showed a spinner on the "Exporting..."
button while the worker subprocess was doing the actual heavy lifting.
For Merged to 16bit and GGUF / Llama.cpp exports this meant several
minutes (or more, for large models) of opaque silence, with no way to
tell whether save_pretrained_merged, convert_hf_to_gguf.py, or
llama-quantize was making progress.

This adds a live terminal-style output panel inside the export dialog,
rendered just above the Cancel / Start Export buttons and scrollable
with auto-follow-tail. It shows stdout and stderr from both the worker
process itself and any child process it spawns (GGUF converter,
llama-quantize), coloured by stream.

Backend

- core/export/worker.py: new _setup_log_capture(resp_queue) installed
  before LogConfig.setup_logging. It saves the original stdout/stderr
  fds, creates pipes, os.dup2's the write ends onto fds 1 and 2 (so
  every child process inherits the redirected fds), and spins up two
  daemon reader threads. Each thread reads bytes from a pipe, echoes
  them back to the original fd (so the server console keeps working),
  splits on \n and \r, and forwards each line to the resp queue as
  {"type":"log","stream":"stdout|stderr","line":...,"ts":...}.
  PYTHONUNBUFFERED=1 is set so nested Python converters flush
  immediately.

- core/export/orchestrator.py:
  - Thread-safe ring buffer (collections.deque, maxlen 4000) with a
    monotonically increasing seq counter. clear_logs(),
    get_logs_since(cursor), get_current_log_seq(), is_export_active().
  - _wait_response handles rtype == "log" by appending to the buffer
    and continuing the wait loop. Status messages are also surfaced as
    a "status" stream so users see high level progress alongside raw
    subprocess output.
  - load_checkpoint, _run_export, and cleanup_memory now wrap their
    bodies with the existing self._lock (previously unused), clear the
    log buffer at the start of each op, and flip _export_active in a
    try/finally so the SSE endpoint can detect idle.

- routes/export.py:
  - Wrapped every sync orchestrator call (load_checkpoint,
    cleanup_memory, export_merged_model, export_base_model,
    export_gguf, export_lora_adapter) in asyncio.to_thread so the
    FastAPI event loop stays free during long exports. Without this
    the new SSE endpoint could not be served concurrently with the
    blocking export POST.
  - New GET /api/export/logs/stream SSE endpoint. Honors
    Last-Event-ID and a since query param for reconnect, emits log /
    heartbeat / complete / error events, uses the id field to carry
    the log seq so clients can resume cleanly. On first connect
    without an explicit cursor it starts from the current seq so old
    lines from a previous run are not replayed.

Frontend

- features/export/api/export-api.ts: streamExportLogs() helper that
  authFetches the SSE endpoint and parses id / event / data fields
  manually (same pattern as streamTrainingProgress in train-api.ts).

- features/export/components/export-dialog.tsx:
  - Local useExportLogs(exporting) hook that opens the SSE stream on
    exporting transitions to true, accumulates up to 4000 lines in
    component state, and aborts on cleanup.
  - New scrollable output panel rendered above DialogFooter, only
    shown for Merged to 16bit and GGUF / Llama.cpp (LoRA adapter is
    a fast disk write with nothing to show). Dark terminal styling
    (bg-black/85, emerald text, rose for stderr, sky for status),
    max-height 14rem, auto-scrolls to the bottom on new output but
    stops following if the user scrolls up. A small streaming / idle
    indicator is shown next to the panel title.
  - DialogContent widens from sm:max-w-lg to sm:max-w-2xl when the
    output panel is visible so the logs have room to breathe.

Verified

- Python smoke test (tests/smoke_export_log_capture.py): spawns a
  real mp.get_context("spawn") process, installs _setup_log_capture,
  confirms that parent stdout prints, parent stderr prints, AND a
  child subprocess invoked via subprocess.run (both its stdout and
  stderr) are all captured in the resp queue. Passes.
- Orchestrator log helpers tested in isolation: _append_log,
  get_logs_since (with and without a cursor), clear_logs not
  resetting seq so reconnecting clients still progress. Passes.
- routes.export imports cleanly in the studio venv and /logs/stream
  shows up in router.routes.
- bun run build: tsc -b plus vite build, no TypeScript errors.

No existing export behavior is changed. If the subprocess, the SSE
endpoint, or the frontend hook fails, the export itself still runs to
completion the same way it did before, with or without logs visible.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* export dialog: trim bootstrap noise, scope logs per screen, show realpath

Several follow-ups to the live export log work:

1. Worker bootstrap noise (transformers venv activation, Unsloth banner,
   "Top GGUF/hub models" lists, vision detection, 2k-step weight load
   bar) is dropped from the export-dialog stream. A threading.Event
   gate in worker.py defaults closed and only opens once _handle_export
   actually starts; until then the reader thread still echoes lines to
   the saved console fd for debugging but does not push them onto the
   resp_queue. The orchestrator already spawns a fresh subprocess for
   every checkpoint load, so the gate is naturally reset between runs.

2. tqdm in non-tty mode defaults to a 10s mininterval, which makes
   multi-step bars look frozen in the panel. Set TQDM_MININTERVAL=0.5
   in the worker env so any tqdm-driven progress emits more often.

3. The dialog's useExportLogs hook now also clears its line buffer
   when exportMethod or open changes, so re-opening the dialog into a
   different action's screen no longer shows the previous action's
   saved output. A useElapsedSeconds tick + "Working Xs" badge in the
   log header gives users a visible sign that long single-step phases
   (cache copies, GGUF conversion) are still running when no new lines
   are arriving.

4. ExportBackend.export_{merged,base,gguf,lora} now return
   (success, message, output_path); the worker forwards output_path on
   each export_*_done response, the orchestrator's _run_export passes
   it to routes/export.py, which surfaces it via
   ExportOperationResponse.details.output_path. The dialog's Export
   Complete screen renders the resolved on-disk realpath under "Saved
   to" so users can find their exported model directly.

* fix(cli): unpack 3-tuple return from export backend

ExportOrchestrator.export_{merged,base,gguf,lora} now return
(success, message, output_path) so the studio dialog can show
the on-disk realpath. The CLI still unpacked 2 values, so every
`unsloth export --format ...` crashed with ValueError before
reporting completion. Update the four call sites and surface
output_path via a "Saved to:" echo.

* fix(studio): anchor export log SSE cursor at run start

The export dialog SSE defaulted its cursor to get_current_log_seq()
at connect time, so any line emitted between the POST that kicks
off the export and the client opening the stream was buffered with
seqs 1..k and then skipped (seq <= cursor). Long-running exports
looked silent during their first seconds.

Snapshot _log_seq into _run_start_seq inside clear_logs() and
expose it via get_run_start_seq(). The SSE default cursor now uses
that snapshot, so every line emitted since the current run began
is reachable regardless of when the client connects. Old runs
still can't leak in because their seqs are <= the snapshot.

* fix(studio): reconnect export log SSE on stream drop

useExportLogs launched streamExportLogs once per exporting
transition and recorded any drop in .catch(). Long GGUF exports
behind a proxy with an idle kill-timeout would silently lose the
stream for the rest of the run even though the backend already
supports Last-Event-ID resume. The "retry: 3000" directive emitted
by the backend is only meaningful to native EventSource; this
hook uses a manual fetch + ReadableStream parse so it had no
effect.

Wrap streamExportLogs in a retry loop that tracks lastSeq from
ExportLogEvent.id and passes it as since on reconnect. Backoff is
exponential with jitter, capped at 5s, reset on successful open.
The loop stops on explicit backend `complete` event or on effect
cleanup.

* fix(studio): register a second command so Typer keeps `export` as a subcommand

The CLI export unpacking tests wrap `unsloth_cli.commands.export.export`
in a fresh Typer app with a single registered command. Typer flattens a
single-command app into that command, so the test's
`runner.invoke(cli_app, ["export", ckpt, out, ...])` treats the leading
`"export"` token as an unexpected extra positional argument -- every
parametrized case failed with:

    Got unexpected extra argument (.../out)

Register a harmless `noop` second command so Typer preserves subcommand
routing and the tests actually exercise the 3-tuple unpack path they
were written to guard.

Before: 4 failed
After:  4 passed

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: studio-install <studio@local.install>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-04-14 08:55:43 -07:00
Daniel Han
eca592effe
studio: show HF model download progress in training start overlay (#4894)
* studio: show HF model download progress in training start overlay

During the training setup phase, the overlay only displayed a static
"Loading model..." line while model weights were being downloaded from
Hugging Face. On slow connections this looked like the app had frozen.

This adds a small self-contained progress block inside the existing
TrainingStartOverlay that polls the existing
GET /api/models/download-progress endpoint and renders a Progress bar
with bytes downloaded, total bytes, and percent complete.

Notes:

- Frontend only change. No backend, worker, SSE, or runtime store edits.
- Reuses the existing getDownloadProgress client wrapper and the
  existing /api/models/download-progress endpoint that already scans
  the HF blob cache for completed and .incomplete files.
- selectedModel is read directly from useTrainingConfigStore inside the
  overlay, so no prop drilling and live-training-view.tsx is unchanged.
- Polling runs at 1500 ms and is gated on the HF repo regex
  (^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$), the same regex the backend uses,
  so local paths and empty form state never hit the endpoint.
- Polling stops once progress reaches 1.0 so the bar can stay at 100
  until the overlay hides on the first training step.
- Network errors are silently swallowed, matching the chat side flow
  (the bar simply freezes at the last value).
- When downloadedBytes is 0 the block is hidden entirely, so cached
  models do not flash a progress bar.
- When the HF API cannot determine the total size, the block falls
  back to "X downloaded" with no percent and no bar.

Verified with bun run build (tsc -b plus vite build, no TypeScript
errors).

* training overlay: track dataset download + show on-disk realpath

Adds a dedicated "Downloading dataset..." section to the training-start
overlay alongside the existing model-weights one, so an HF dataset that
is downloading mid-startup is no longer mislabeled as model weights or
hidden entirely. The new GET /api/datasets/download-progress endpoint
mirrors /api/models/download-progress against the datasets-- prefix in
HF_HUB_CACHE.

Both endpoints now also return cache_path, the resolved on-disk
realpath of the snapshot directory (or the cache repo root if no
snapshot is materialized yet). The overlay surfaces this under each
download row so users can immediately see where the model and dataset
landed without digging through server logs.

The frontend's existing useModelDownloadProgress hook is generalized
to a single useHfDownloadProgress(repoId, fetcher) hook that the
model and dataset variants both delegate to, keeping polling, gating,
and completion semantics in one place.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: Polish training start overlay download progress UI (#4957)

* studio: polish training start overlay download progress visuals

* Fix formatCachePath cross-platform support and redundant sizeLabel

- Extend formatCachePath regex to also shorten macOS /Users/<user> paths to ~
- Suppress sizeLabel when no byte info is available (cachePath-only state),
  since the "Preparing" badge already conveys the status

* Fix misleading status badge when download total is unknown

- Hide badge when totalBytes is 0 but downloadedBytes > 0, since we cannot
  determine if the download is still in progress or already complete (happens
  when HF size metadata lookup fails for gated/private repos)
- Keep "Preparing" badge for the zero-bytes cachePath-only state
- Add Windows native path shortening to formatCachePath (C:\Users\<name>)

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: studio-install <studio@local.install>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
2026-04-14 08:54:01 -07:00
Daniel Han
44082cf88e
Studio: anchor ctx-slider warning threshold at 4096 when weights exceed VRAM (#5014)
* Studio: anchor ctx-slider warning threshold at 4096 when weights exceed VRAM

The chat settings sheet's ctx slider reads `max_context_length` from
`/api/inference/status` and renders

    Exceeds estimated VRAM capacity (N tokens). The model may use
    system RAM.

when the user drags the slider above that value. For models whose
weights fit on some GPU subset, `_max_context_length` was already set
to the binary-search cap and the warning fired correctly.

For models whose weights exceed 90% of every GPU subset's free memory
(e.g. MiniMax-M2.7-GGUF at 131 GB on a 97 GB GPU), the ceiling-probe
loop never matched a subset, so `max_available_ctx` stayed at the
native context (e.g. 196608). The slider ran all the way to native
with no indication that any value above the 4096 spec default would
trigger `--fit on` and degrade performance.

Anchor `max_available_ctx` at `min(4096, native_context_length)` when
no subset fits, so the warning fires at the right threshold and the
user sees the correct safe-zone / warning-zone split:

    Before (MiniMax-M2.7 on 97 GB GPU):
      slider 0 .. 196608, warning threshold = 196608  (never fires)

    After:
      slider 0 .. 196608, warning threshold = 4096    (fires correctly)

No frontend changes required: `chat-settings-sheet.tsx` already
consumes `ggufMaxContextLength` (= status.max_context_length) as the
warning threshold and `ggufNativeContextLength` as the slider max.

Adds tests/test_llama_cpp_max_context_threshold.py covering
weights-exceed-VRAM (single / multi-GPU), a native-ctx below the 4096
fallback case (don't lie about supported ctx), fittable-model
regressions (small / multi-GPU / tiny on huge GPU), and the
`max_context_length` property's fallback semantics.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 08:53:49 -07:00
Daniel Han
b2f80f210e
Studio: make GGUF disk-space preflight cache-aware (#5012)
* Studio: make GGUF disk-space preflight cache-aware

The pre-download disk check in LlamaCppBackend.load_model compared the
repo's total GGUF size against free disk without crediting bytes
already present in the Hugging Face cache. Re-loading a large cached
model (e.g. MiniMax-M2.7-GGUF at 131 GB) then failed cold with
"Not enough disk space to download any variant" whenever free disk
was below the full weight footprint, even though nothing actually
needed to be downloaded.

Subtract bytes already on disk via try_to_load_from_cache before
comparing against free space. A partial blob (interrupted download) is
not credited, so a second attempt still allocates room to finish the
download. The log line now also surfaces how much is already cached.

Adds tests/test_llama_cpp_cache_aware_disk_check.py covering the
fully-cached, partial-cache-insufficient-disk, partial-cache-enough-disk,
cold-cache, incomplete-blob, and zero-size-path-info cases. Sparse
tempfiles keep the GB-scale scenarios cheap to simulate.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 08:53:37 -07:00
Daniel Han
767fa8cade
Studio: honor explicit GGUF ctx and default to 4096 when weights exceed VRAM (#5011)
* Studio: honor explicit GGUF ctx and default to 4096 when weights exceed VRAM

The load-time auto-fit in LlamaCppBackend.load_model had two issues for
models whose weights do not fit on any GPU subset (the common case for
large MoE GGUFs such as MiniMax-M2.7, Qwen3.5-397B-A17B, etc.):

1. Auto mode (max_seq_length=0) left effective_ctx at the model's native
   context when no subset passed the 90% fit check. The UI slider then
   landed on e.g. 196608 for MiniMax-M2.7, far above anything usable.
   Default the auto-pick to 4096 so the UI starts at a sane value; the
   slider ceiling stays at the native context so the user can still
   opt in to longer contexts and receive the "might be slower" warning.

2. Explicit ctx was silently shrunk when weights fit but the requested
   KV overflowed the 90% budget. The shrink loop emitted -c <capped>
   -ngl -1 without informing the caller, so a user who had opted into
   a longer context via the UI never actually got it. Drop the shrink
   loop on the explicit path and emit -c <user_ctx> --fit on instead,
   letting llama-server flex -ngl (CPU layer offload).

Adds tests/test_llama_cpp_context_fit.py covering both paths, the
file-size-only fallback when KV metadata is missing, non-regression on
fittable auto-pick, and platform-agnostic input shape.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-14 08:53:25 -07:00
TF-MTGE
a31c82a640
fix(studio): remove 300s cap on load_checkpoint (inherits 3600s default) (#4922)
* fix: increase wait response timeout to 900 sec instead of 300 sec. #4845

* Apply suggestion from @gemini-code-assist[bot]

good catch

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-14 08:53:14 -07:00
Datta Nimmaturi
da78c6be71
[Studio] Install flash attn at setup time for linux (#4979)
* [Studio] Install flash attn at setup time for linux

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup changes

Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test cases

* wheel_utils: narrow url_exists exceptions and log at debug level

---------

Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-04-14 16:40:17 +04:00
Datta Nimmaturi
dccc0ebada
[Studio] Show non exported models in chat UI (#4892)
* Show non exported models in chat UI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Distinguish b/w LoRa and full fine tune saves. Cleanup

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-04-14 15:03:58 +04:00
Bharath Kumar Adinarayan
a50f61009b
fix(studio): default chart view to full training history (#5007)
* fix(studio): default chart view to full training history instead of last 80 steps

Fixes #5003

* chore: windowsize as null code comment

---------

Co-authored-by: imagineer99 <samleejackson0@gmail.com>
Co-authored-by: Wasim Yousef Said <wasimysdev@gmail.com>
2026-04-14 03:29:27 -07:00
Lee Jackson
bfa17330bd
Studio: Polish API key copy button and harden async clipboard fallback (#5006)
* fix: polish clipboard style and fix async clipboard path

* Use copyToClipboardAsync in CopyButton for Safari fallback

CopyButton was calling navigator.clipboard.writeText directly,
bypassing the execCommand fallback added in this same PR. Switch
to copyToClipboardAsync which tries execCommand first (Safari
user-gesture requirement) then falls back to the async clipboard API.

* Fix copyToClipboard sync contract regression and improve async path

- Restore copyToClipboard() to return only the execCommand result,
  preserving the boolean contract that 7 existing callers depend on
  to gate their "Copied!" UI state. The fire-and-forget async fallback
  was returning true before the promise resolved, causing false success.

- Add document.body null guard to copyWithExecCommand for SSR safety.

- Reorder copyToClipboardAsync to try the async Clipboard API first,
  avoiding unnecessary DOM/focus overhead in Radix focus-trapped dialogs
  where execCommand always fails anyway.

* Restore queryCommandSupported guard and fix async catch path

- Restore the queryCommandSupported("copy") guard in copyToClipboard()
  to match the original contract exactly: when execCommand is entirely
  unsupported, fall through to fire-and-forget async clipboard write.

- Fix copyToClipboardAsync catch block: after navigator.clipboard.writeText
  rejects, the user-gesture frame is gone, so execCommand will also fail.
  Return false from catch instead of falling through. The execCommand
  fallback at the bottom only runs when the Clipboard API is absent
  (still in user-gesture frame).

* Restore execCommand fallback in copyToClipboardAsync catch path

The catch block was returning false after clipboard API rejection,
based on the incorrect premise that the user-gesture frame is lost
after an await. Per the HTML spec, transient user activation IS
preserved through promise microtask chains. The real reason
execCommand fails in the Radix dialog is the focus trap intercepting
textarea.focus(), not gesture loss.

For non-dialog callers, execCommand can still succeed after a
clipboard rejection. Inside a Radix modal, execCommand returns
false harmlessly (focus trap blocks it).

* Harden textarea fallback for mobile and continue to async path on failure

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-04-14 14:22:14 +04:00
Wasim Yousef Said
97eafd999e
studio: fix api-keys access + refresh (#5005)
* studio: fix api-keys access + refresh

* studio: guard v1 in spa fallback
2026-04-13 23:48:51 +04:00
AdamPlatin123
d2fc582840
studio: skip training status/metrics polling when idle (#4988)
* fix(studio): skip training status/metrics polling when idle

Add an early return in the status and metrics setInterval callbacks when
the runtime store reports phase === "idle" and hasHydrated is true.
Previously these polls fired unconditionally every 3s/5s, generating
unnecessary network traffic and console errors when no training was
running.

* fix(studio): reduce idle polling to 30s instead of stopping entirely

Review feedback (PR #4988): completely stopping polling when idle risks
permanent UI desync if hydration fails, and misses out-of-band state
changes from other clients. Add a 30s background poll that only fires
when idle to recover gracefully.

* fix: harden idle status polling around hydration and runtime reset

---------

Co-authored-by: AdamPlatin123 <AdamPlatin123@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: imagineer99 <samleejackson0@gmail.com>
2026-04-13 12:02:12 -07:00
Daniel Han
9a261aec5f
Studio: Expose openai and anthropic compatible external API end points (#4956)
* Studio: add API key authentication for programmatic access

External users want to hit the Studio API (chat completions with tool
calling, training, export, etc.) without going through the browser
login flow. This adds sk-unsloth- prefixed API keys that work as a
drop-in replacement for JWTs in the Authorization: Bearer header.

Backend:
- New api_keys table in SQLite (storage.py)
- create/list/revoke/validate functions with SHA-256 hashed storage
- API key detection in _get_current_subject before the JWT path
- POST/GET/DELETE /api/auth/api-keys endpoints on the auth router

Frontend:
- /api-keys page with create form, one-time key reveal, keys table
- API Keys link in desktop and mobile navbar
- Route registered with requireAuth guard

Zero changes to any existing route handler -- every endpoint that uses
Depends(get_current_subject) automatically works with API keys.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use actual origin in API key usage examples

The examples on /api-keys were hardcoded to localhost:8888 which is
wrong for remote users. Use window.location.origin so the examples
show the correct URL regardless of where the user is connecting from.

* Add `unsloth studio run` CLI command for one-liner model serving

Adds a `run` subcommand that starts Studio, loads a model, creates an
API key, and prints a ready-to-use curl command -- similar to
`ollama run` or `vllm serve`.

Usage: unsloth studio run -m unsloth/Qwen3-1.7B-GGUF --gguf-variant UD-Q4_K_XL

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add end-to-end tests for `unsloth studio run` and API key usage

Tests the 4 usage examples from the API Keys page:
1. curl basic (non-streaming) chat completions
2. curl streaming (SSE) chat completions
3. OpenAI Python SDK streaming completions
4. curl with tools (web_search + python)

Also tests --help output, invalid key rejection, and no-key rejection.
All 7 tests pass against Qwen3-1.7B-GGUF.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add /v1/completions, /v1/embeddings, /v1/responses endpoints and --parallel support

- llama_cpp.py: accept n_parallel param, pass to llama-server --parallel
- run.py: plumb llama_parallel_slots through to app.state
- inference.py: add /completions and /embeddings as transparent proxies to
  llama-server, add /responses as application-level endpoint that converts
  to ChatCompletionRequest; thread n_parallel through load_model
- studio.py: set llama_parallel_slots=4 for `unsloth studio run` path

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make /v1/responses endpoint match OpenAI Responses API format

The existing /v1/responses shim returned Chat Completions format, which
broke OpenAI SDK clients using openai.responses.create(). This commit
replaces the endpoint with a proper implementation that:

- Returns `output` array with `output_text` content parts instead of
  `choices` with `message`
- Uses `input_tokens`/`output_tokens` instead of `prompt_tokens`/
  `completion_tokens` in usage
- Sets `object: "response"` and `id: "resp_..."`
- Emits named SSE events for streaming (response.created,
  response.output_text.delta, response.completed, etc.)
- Accepts all OpenAI Responses API fields (tools, store, metadata,
  previous_response_id) without erroring -- silently ignored
- Maps `developer` role to `system` and `input_text`/`input_image`
  content parts to the internal Chat format

Adds Pydantic schemas for request/response models and 23 unit tests
covering schema validation, input normalisation, and response format.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: add Anthropic-compatible /v1/messages endpoint (#4981)

* Add Anthropic-compatible /v1/messages endpoint with tool support

Translate Anthropic Messages API format to/from internal OpenAI format
and reuse the existing server-side agentic tool loop. Supports streaming
SSE (message_start, content_block_delta, etc.) and non-streaming JSON.
Includes offline unit tests and e2e tests in test_studio_run.py.

* Add enable_tools, enabled_tools, session_id to /v1/messages endpoint

Support the same shorthand as /v1/chat/completions: enable_tools=true
with an optional enabled_tools list uses built-in server tools without
requiring full Anthropic tool definitions. session_id is passed through
for sandbox isolation. max_tokens is now optional.

* Strip leaked tool-call XML from Anthropic endpoint content

Apply _TOOL_XML_RE to content events in both streaming and
non-streaming tool paths, matching the OpenAI endpoint behavior.

* Emit custom tool_result SSE event in Anthropic stream

Adds a non-standard tool_result event between the tool_use block close
and the next text block, so clients can see server-side tool execution
results. Anthropic SDKs ignore unknown event types.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Split /v1/messages into server-side and client-side tool paths

enable_tools=true runs the existing server-side agentic loop with
built-in tools (web_search/python/terminal). A bare tools=[...] field
now triggers a client-side pass-through: client-provided tools are
forwarded to llama-server and any tool_use output is returned to the
caller with stop_reason=tool_use for client execution.

This fixes Claude Code (and any Anthropic SDK client) which sends
tools=[...] expecting client-side execution but was previously routed
through execute_tool() and failing with 'Unknown tool'.

Adds AnthropicPassthroughEmitter to convert llama-server OpenAI SSE
chunks into Anthropic SSE events, plus unit tests covering text
blocks, tool_use blocks, mixed, stop reasons, and usage.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix httpcore GeneratorExit in /v1/messages passthrough stream

Explicitly aclose aiter_lines() before the surrounding async with
blocks unwind, mirroring the prior fix in external_provider.py
(a41160d3) and cc757b78's RuntimeError suppression.

* Wire stop_sequences through /v1/messages; warn on tool_choice

Plumb payload.stop_sequences to all three code paths (server-side
tool loop, no-tool plain, client-side passthrough) so Anthropic SDK
clients setting stop_sequences get the behavior they expect. The
llama_cpp backend already accepted `stop` on both generate_chat_
completion and generate_chat_completion_with_tools; the Anthropic
handler simply wasn't passing it.

tool_choice remains declared on the request model for Anthropic SDK
compatibility (the SDK often sets it by default) but is not yet
honored. Log a structured warning on each request carrying a non-
null tool_choice so the silent drop is visible to operators.

* Wire min_p / repetition_penalty / presence_penalty through /v1/messages

Align the Anthropic endpoint's sampling surface with /v1/chat/completions.
Adds the three fields as x-unsloth extensions on AnthropicMessagesRequest
and threads them through all three code paths: server-side tool loop,
no-tool plain, and client-side passthrough.

The passthrough builder emits "repeat_penalty" (not "repetition_penalty")
because that is llama-server's field name; the backend methods already
apply the same rename internally.

* Fix block ordering and prev_text reset in non-streaming tool path

_anthropic_tool_non_streaming was building the response by appending
all tool_use blocks first, then a single concatenated text block at
the end — losing generation order and merging pre-tool and post-tool
text into one block. It also never reset prev_text between synthesis
turns, so the first N characters of each post-tool turn were dropped
(where N = length of the prior turn's final cumulative text).

Rewrite to build content_blocks incrementally in generation order,
matching the streaming emitter's behavior: deltas within a turn are
merged into the trailing text block, tool_use blocks interrupt the
text sequence, and prev_text is reset on tool_end so turn N+1 diffs
against an empty baseline.

Caught by gemini-code-assist[bot] review on #4981.

* Make test_studio_run.py e2e tests pytest-compatible

Add a hybrid session-scoped studio_server fixture in conftest.py that
feeds base_url / api_key into the existing e2e test functions. Three
invocation modes are now supported:

1. Script mode (unchanged) — python tests/test_studio_run.py
2. Pytest + external server — point at a running instance via
   UNSLOTH_E2E_BASE_URL / UNSLOTH_E2E_API_KEY env vars, no per-run
   GGUF load cost
3. Pytest + fixture-managed server — pytest drives _start_server /
   _kill_server itself via --unsloth-model / --unsloth-gguf-variant,
   CI-friendly

The existing _start_server / _kill_server helpers and main() stay
untouched so the script entry point keeps working exactly as before.
Test function signatures are unchanged — the (base_url, api_key)
parameters now resolve via the new fixtures when running under
pytest.

* Rename test_studio_run.py -> test_studio_api.py

The file is entirely about HTTP API endpoint testing (OpenAI-compatible
/v1/chat/completions, Anthropic-compatible /v1/messages, API key auth,
plus a CLI --help sanity check on the command that runs the API). None
of its tests cover training, export, chat-UI, or internal-Python-API
concerns.

The old name misleadingly suggested "tests for the unsloth studio run
CLI subcommand" — the new name reflects the actual scope.

Updates:
- git mv the file (rename tracked, history preserved)
- Rewrite opening docstring to state the API surface focus and call
  out what is explicitly out of scope
- Update all 4 Usage-block path references to the new filename
- LOG_FILE renamed to test_studio_api.log
- conftest.py fixture import rewritten from test_studio_run to
  test_studio_api, plus 7 docstring/comment references updated

No functional changes to test logic, signatures, or main().

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix httpcore asyncgen cleanup in /v1/messages and /v1/completions

The earlier fix in 985e92a9 was incomplete: it closed aiter_lines()
explicitly but still used `async with httpx.AsyncClient()` /
`async with client.stream()` inside the generator. When the generator
is orphaned (e.g. client disconnects mid-stream and Starlette drops
the StreamingResponse iterator without explicitly calling aclose()),
Python's asyncgen finalizer runs the cleanup in a DIFFERENT task than
the one that originally entered the httpx context managers. The
`async with` exits then trigger httpcore's HTTP11ConnectionByteStream
.aclose(), which enters anyio.CancelScope.__exit__ with a mismatched
task and raises RuntimeError("Attempted to exit cancel scope in a
different task"). That error escapes any user-owned try/except
because it happens during GC finalization.

Replace `async with` with manual client/response lifecycle in both
/v1/messages passthrough and /v1/completions proxy. Close the
response and client in a finally block wrapped in
`try: ... except Exception: pass`. This suppresses RuntimeError (and
other Exception subclasses) from the anyio cleanup noise while
letting GeneratorExit (a BaseException, not Exception) propagate
cleanly so the generator terminates as Python expects.

Traceback observed in user report:
  File ".../httpcore/_async/connection_pool.py", line 404, in __aiter__
      yield part
  RuntimeError: async generator ignored GeneratorExit
...
  File ".../anyio/_backends/_asyncio.py", line 455, in __exit__
      raise RuntimeError(
  RuntimeError: Attempted to exit cancel scope in a different task

* Expand unsloth studio run banner with SDK base URL and more curl examples

Add an explicit "OpenAI / Anthropic SDK base URL" line inside the info
box so SDK users don't accidentally copy the bare server URL (without
/v1) into their OpenAI/Anthropic SDK constructors and hit 404s.

Replace the single /v1/chat/completions curl example with three
labeled blocks: chat/completions, Anthropic /messages, and OpenAI
Responses. The Anthropic example includes max_tokens (Anthropic SDKs
require it even though Studio accepts None).

All examples derived from a computed sdk_base_url so the /v1 prefix
stays in sync if the public path ever changes.

* Hash API keys with HMAC-SHA256 + persistent server secret

Stores the HMAC secret in a new app_secrets singleton table. Fixes
CodeQL py/weak-sensitive-data-hashing alert on storage.py:74-76,
394-395. Refresh tokens stay on plain SHA-256 (unchanged _hash_token)
so existing user sessions survive upgrade — API keys are new on this
branch so there is no migration.

* Use PBKDF2 for API key hashing per CodeQL recommendation

HMAC-SHA256 was still flagged by py/weak-sensitive-data-hashing.
Switch to hashlib.pbkdf2_hmac, which is in CodeQL's recommended
allowlist (Argon2/scrypt/bcrypt/PBKDF2). Persistent server-side
salt stays in app_secrets for defense-in-depth. 100k iterations to
match auth/hashing.py's password hasher.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-04-13 21:08:11 +04:00
Roland Tannous
3bb72a557f
Pin kernels==0.12.1 to avoid huggingface_hub dataclass conflict (#5000) 2026-04-13 20:42:02 +04:00
Lee Jackson
21a7895959
Studio: Prompt manager, message deletion, and chat UI improvements (#4938)
* feat(chat): code block styling, delete with Dexie sync, settings sheet polish

* style: config save/delete padding fix

* fix(studio): centralize dark code-block surface and optimize message sync writes

* style: config padding/alignment polish

* fix(studio): upsert custom presets without implicit rename-delete

* fix settings sheet save state polish

* fix settings sheet button widths

* fix chat settings presets

* fix chat delete sync

* fix chat trust remote code flow

---------

Co-authored-by: shine1i <wasimysdev@gmail.com>
2026-04-13 16:42:33 +02:00
AdamPlatin123
3b092bcd46
fix(studio): prevent route transition DOM duplication via AnimatePresence (#4987)
Add mode="wait" and exit={{ opacity: 0 }} to the root AnimatePresence
wrapper so outgoing routes fully unmount before incoming routes render.
Without this, rapid navigation between Studio/Export/Recipes/Chat caused
pages to stack (2x–3x duplication).

Co-authored-by: AdamPlatin123 <AdamPlatin123@users.noreply.github.com>
Co-authored-by: Wasim Yousef Said <wasimysdev@gmail.com>
2026-04-13 01:38:00 -07:00
Manan Shah
80c12ff1a6
Move gemma4 script (#4994)
* updating gemma4 script

* moving gemma4 script to scripts folder
2026-04-12 23:41:15 -07:00
Manan Shah
db3b3a4d9b
updating gemma4 script (#4992)
* updating gemma4 script

* show errors
2026-04-12 23:11:32 -07:00
Daniel Han
93a24f6698
Add ROCm test suite for PR #4720 (#4824)
95 Python tests and 23 shell tests covering ROCm detection,
torch index URL selection, hardware flags, prebuilt asset selection,
and install pathway logic. All tests use mocks -- no AMD hardware required.

Companion to #4720 (AMD ROCm/HIP support).
2026-04-11 04:44:13 -07:00
Daniel Han
53af4a1b3e
Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+ (#4934)
* Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+

Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12
at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing
TRL/model patch flow and are no-ops for models and TRL versions that are not
affected.

Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with
a no-op context manager. TRL 1.0.0+ wraps generation in
`with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`
purely to suppress a cosmetic PyTorch warning ("None of the inputs have
requires_grad=True"). Inside torch.no_grad() the gradient checkpointing
state has no functional effect on the forward pass. On context exit, TRL
calls model.gradient_checkpointing_enable() which dispatches to HF's
generic implementation and overwrites Unsloth's custom
`use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward
numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC
wrapper across generation passes. The patch walks sys.modules dynamically
to also rebind the symbol on every trl.* module that already imported it
(grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer,
grpo_with_replay_buffer_trainer, and any future trainer module).

Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config`
into the top-level `model.config` for multimodal models. Unsloth's GRPO
trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but
for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`,
so the lookup silently defaults to 0 instead of 30.

Backwards compatibility:
- trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch
  early-returns via `hasattr` guard.
- trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL
  1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine
  zero) vs broken baseline loss=1.37e+06, kl=1.76e+09.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1
  is functionally identical (Unsloth's GC wrapper is preserved).
- Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op
  (text_config.final_logit_softcapping is None).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply loop 1 review fixes for PR #4934

- Move Fix 2 from vision.py to rl_replacements.py:858 and :1110 at the
  actual consumer sites. This avoids mutating model.config (which could
  leak into save_pretrained output) and covers text-only Gemma-4 paths
  that do not flow through FastBaseModel.from_pretrained.
- Revert the vision.py injection block entirely.
- Narrow the bare except blocks in patch_trl_disable_gradient_checkpointing
  from `except Exception:` to `(AttributeError, ImportError)` and
  `(AttributeError, TypeError)` to avoid masking unrelated bugs.
- Add logger.warning_once when the noop patch is installed, matching
  patch_trl_openenv and patch_trl_vllm_generation convention.
- Remove the dead per-module `_unsloth_noop_patched` sentinel check inside
  the sys.modules walk. The function-level early return already covers
  this case.
- Move `import sys` and `from contextlib import contextmanager` to the
  module-level imports instead of inside the function body.
- Rewrite the ordering comment in PatchFastRL to accurately describe
  why patch_trl_disable_gradient_checkpointing must run before
  patch_trl_rl_trainers.
- Fix keyword default spacing to match surrounding rl.py style.

End-to-end verified: Gemma-4-E2B GRPO on TRL 1.0.0 + transformers 5.5.0
step 1 loss=2.464e-08 kl=2.921e-05, all 5 steps succeed.

* Apply loop 2 review fix for PR #4934

Extract the final_logit_softcapping fallback logic into a shared helper
`_unsloth_get_final_logit_softcapping(config)` defined in rl_replacements.py
and injected into the compiled cache via RL_PRE_ITEMS["grpo_trainer"]. Both
call sites (`grpo_trainer__generate_and_score_completions` and
`grpo_trainer_compute_loss`) now use the helper instead of inlining the
same text_config fallback block twice.

Verified: compiled cache file lists the helper at module scope and both
consumer sites call it. Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05
(unchanged), all 5 steps pass.

* Apply loop 3 review fix for PR #4934

Extend _unsloth_get_final_logit_softcapping to also fall back to
config.get_text_config() for composite configs such as T5GemmaConfig
where the text sub-config is not exposed via the text_config attribute
but only via the get_text_config() method. Guard against (TypeError,
ValueError) raised by ambiguous composite configs, and skip the
self-referential case where get_text_config() returns self.

This addresses the 6/7 reviewer consensus from the third review loop.

Verified:
- Helper returns 30.0 for Gemma-4, T5Gemma, and Gemma 1/2 configs.
- Helper returns 0 for Llama, Qwen, Mistral, Cohere, Granite, and
  ambiguous configs raising ValueError.
- Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged).
- Llama-3.2-1B GRPO all 5 steps loss=0 kl=0 (no regression).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-10 07:58:15 -07:00
Daniel Han
65b4028560
Pin bitsandbytes to continuous-release_main on ROCm (4-bit decode fix) (#4954)
* Pin bitsandbytes to continuous-release_main on ROCm for 4-bit decode fix

bitsandbytes 0.49.2 on PyPI ships with a broken 4-bit GEMV kernel on
every ROCm target:

  - CDNA (gfx90a / gfx942 / gfx950 = MI210 / MI300X / MI350) via a
    broken blocksize=32/64 warp64 GEMV kernel whose tests were
    explicitly skipped with ROCM_WARP_SIZE_64 guards because the
    code was known broken.
  - RDNA3 / RDNA3.5 (gfx1100-1103 / gfx1150-1152) via a compile-time
    BNB_WARP_SIZE macro in the host-side dispatch that resolves to
    64 when the multi-arch wheel is compiled with CDNA as the
    primary target, so num_blocks is wrong on RDNA and half the GEMV
    output is never written.

At decode shape (1, 1, hidden) both bugs produce NaN. Training is
unaffected because training shapes are (batch, seq_len > 1, hidden)
and never touch the GEMV path. The crash during autoregressive
inference surfaces as _assert_async_cuda_kernel in torch.multinomial
which on HIP becomes a hard HSA_STATUS_ERROR_EXCEPTION instead of
a clean Python error.

Both bugs are fixed by bitsandbytes commit 713a3b8 ("[ROCm] Enable
blocksize 32 4-bit quantization and GEMV kernels on AMD CDNA",
PR #1887, merged 2026-03-09) which replaces BNB_WARP_SIZE with a
runtime hipDeviceGetAttribute query and ships a working CDNA warp64
kernel. That commit has not shipped to PyPI yet, but
continuous-release_main wheels are published on every push to bnb
main via GitHub Releases.

Point the ROCm install path at the continuous-release_main x86_64 and
aarch64 wheels and fall back to PyPI >=0.49.1 when the pre-release is
unreachable (offline installs, firewalled hosts, or architectures not
covered by the pre-release wheels). Drop the pin once bnb cuts a
0.50+ tag on PyPI.

Verified on MI300X (gfx942, ROCm 7.2, torch 2.10.0+rocm7.1): direct
bnb GEMV shape test now returns 0.0078 max abs error at seq_len=1
(no NaN) vs NaN on 0.49.2, and full Unsloth + for_inference + 4-bit
sampling generation works end-to-end.

NVIDIA / CPU / Mac / Windows paths are unaffected -- the helper is
gated on the ROCm torch index and platform.machine() respectively.

* Drop Studio ROCm 16-bit fallback now that bnb 0.50+ fixes 4-bit decode

The 16-bit fallback in studio/backend/core/inference/inference.py was
added as a workaround for a bug that this PR already fixes at the
install layer: bitsandbytes <= 0.49.2 has a broken 4-bit GEMV kernel
on every ROCm target, which NaNs at decode shape (seq_len=1) and
crashes autoregressive inference. bnb PR #1887 (commit 713a3b8, in
0.50.0.dev0+, pinned by install.sh / install_python_stack.py in this
PR) restores correct 4-bit decode on MI300X and verified working
end-to-end with full Unsloth + for_inference + sampling.

Revert the dual code path so ROCm and NVIDIA both go through the
normal FastLanguageModel.from_pretrained + for_inference flow:

  - Remove the conditional `from unsloth import` that skipped the
    import on ROCm. The monkey-patches it was trying to avoid were
    never the cause of the crash; bnb 4-bit GEMV was.
  - Remove the `if _hw_module.IS_ROCM:` branch in load_model that
    loaded with plain transformers + PEFT + bfloat16, and the
    `_resolve_fp16_base` helper it relied on.
  - Remove the `get_chat_template is not None` fallback in
    _load_chat_template_info -- get_chat_template is now always
    imported.
  - Refactor the audio/vision ROCm guard to check _hw_module.IS_ROCM
    directly instead of the removed _IS_ROCM_ENV global. Audio and
    vision on ROCm still need separate validation (FastVisionModel
    and the CSM audio codecs were never tested on HIP) so the guard
    stays for now.

Add _bnb_rocm_4bit_ok() as a runtime safety net for users who
install from this PR before the install.sh bnb pin kicks in, or
whose installer fell back to the PyPI pin because the continuous-
release wheel was unreachable. When the installed bnb is < 0.50 on
ROCm, force load_in_4bit=False and strip any -unsloth-bnb-4bit /
-bnb-4bit suffix from the model path so a pre-quantized repo
resolves to its FP16 sibling instead of pulling bnb back in via
the repo's quantization_config. LoRA adapters whose base is a
pre-quantized repo on old bnb will still fail inside Unsloth's
loader -- the only real fix there is `unsloth studio update`.

Verified on MI300X (gfx942, ROCm 7.2, torch 2.10.0+rocm7.1):

  - HAPPY path (bnb 0.50.0.dev0, load_in_4bit=True, pre-quantized
    repo): loads in 4-bit via the fixed GEMV, generation returns
    "Paris." for greedy and sampling.
  - SAFETY-NET path (simulated old bnb, suffix-stripped to the
    FP16 sibling, load_in_4bit=False): loads in bf16, generation
    returns "Paris." for greedy and sampling.

Net diff is ~45 lines smaller than the pre-revert state because
the entire plain-transformers 16-bit branch is gone.

* Cache _bnb_rocm_4bit_ok() with functools.cache

load_model() can be called many times in a single session but the bnb
version and hardware state cannot change at runtime, so memoise the
check. First call is ~1.9 ms (dominated by the lazy `import bitsandbytes`
inside the try block), subsequent calls drop to sub-microsecond dict
lookups. Zero behavioral change.

* Shorten verbose bnb/ROCm comments

Comment-only cleanup across install.sh, studio/install_python_stack.py,
and studio/backend/core/inference/inference.py. No behavioral change.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove _bnb_rocm_4bit_ok safety net from inference.py

Studio's ROCm support is brand new (PR #4720, merged today) and every
fresh install pulls the bnb continuous-release_main wheel via
install.sh / install_python_stack.py in this same PR. There are no
existing ROCm Studio installs carrying bnb < 0.50, so the defensive
version-check fallback is guarding against a scenario that cannot
actually occur. Delete the helper, the functools import, and the
safety-net block -- inference.py now calls FastLanguageModel.from_pretrained
directly with no ROCm branching.

* Drop audio/vision ROCm guard in inference.py — verified unblocked by bnb fix

Vision inference was blocked by the same bnb 4-bit GEMV bug that affected
text inference (vision models use bnb 4-bit for the LM backbone). With
bnb 0.50+ pinned in install.sh / install_python_stack.py, vision works
end-to-end on MI300X: Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit
loaded in 4-bit via FastVisionModel + for_inference returns a correct
answer to a multimodal prompt.

Audio (CSM) was never actually blocked by HIP — on this hardware CSM
loads and runs its backbone forward pass fine with bnb 0.50, then fails
during generate() with a transformers-level kwarg validation mismatch
in generation_csm.py (`backbone_last_hidden_state` rejected). That's a
pre-existing transformers/CSM integration bug that reproduces identically
on NVIDIA, so the ROCm-gated guard was never actually protecting users
from anything HIP-specific.

Remove the combined audio/vision guard and the now-unused _hw_module
import. Also restore the one-word "Can be" in an inline comment that
drifted during the earlier comment-shortening pass, so the inference.py
delta vs pre-#4720 is exactly the max_seq_length<=0 crash fix and
nothing else.

* Shorten max_seq_length=0 guard comment to one line

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-10 06:25:39 -07:00
Daniel Han
cad8c6ad05
Add AMD ROCm/HIP support across installer and hardware detection (#4720)
* Add ROCm detection to install.sh and expand shell tests

Add AMD ROCm GPU detection to get_torch_index_url() in install.sh.
When nvidia-smi is not found, probe for ROCm via amd-smi, /opt/rocm
version file, hipconfig, dpkg-query, and rpm.

Includes validation guard for malformed _rocm_tag, Debian epoch prefix
stripping, ROCm 7.2+ cap to rocm7.1 index, bitsandbytes AMD install,
and status messaging. Shell tests expanded to 23 cases.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add ROCm torch reinstall support to install_python_stack.py

Add _detect_rocm_version() and _ensure_rocm_torch() to detect when a
Linux host has ROCm but the venv received CPU-only torch, and reinstall
with the correct ROCm wheels. Covers ROCm 6.0 through 7.1 with a
30-second timeout on the torch GPU probe subprocess.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add ROCm support to llama.cpp prebuilt installer

Add has_rocm field to HostInfo, extend detect_host() to probe for ROCm
via hipcc/amd-smi/rocm-smi/ROCM_PATH, and route ROCm hosts to upstream
prebuilts (Linux ROCm 7.2 prebuilt with source fallback, Windows HIP
prebuilt with CPU fallback). Add linux-rocm and windows-hip install
kinds to runtime_patterns_for_choice().

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add IS_ROCM hardware flag and fix AMD error message

Add IS_ROCM flag to hardware.py detect_hardware() (set when
torch.version.hip is present, DeviceType stays CUDA). Export IS_ROCM
from __init__.py. Add "rocm" key to get_package_versions().

Replace "We do not support AMD" error in tokenizer_utils.py with a
helpful message pointing to ROCm installation docs.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Add comprehensive ROCm support test suite (68 tests)

Add tests/studio/install/test_rocm_support.py covering all ROCm code
paths across install_llama_prebuilt.py, install_python_stack.py,
hardware.py, tokenizer_utils.py, and install.sh. All tests use mocks
and run without AMD hardware.

Covers: asset selection (11), runtime patterns (5), HostInfo (4),
ROCm version detection (9), torch reinstall (9), index mapping (8),
hardware flag (8), tokenizer message (2), install.sh structure (10),
and live regression (1).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden ROCm support: probe error handling, version cap, validation

Address review findings from 8 independent reviewers:

- Wrap _ensure_rocm_torch() torch probe in try/except for
  TimeoutExpired and OSError so a hung or broken torch import does not
  crash the installer (8/8 reviewers flagged this)
- Add torch>=2.4,<2.11.0 version cap to the ROCm reinstall path to
  prevent installing unsupported torch 2.11.0 from the rocm7.1 index
- Use with-statement for file reads in _detect_rocm_version() to avoid
  resource leaks
- Handle ROCM_PATH="" correctly (use `or "/opt/rocm"` instead of
  default parameter to avoid relative path resolution)
- Strengthen shell validation guard from rocm[0-9] to rocm[1-9] to
  reject rocm0.x tags that would produce nonexistent PyTorch index URLs
- Switch shell version cap from blocklist to allowlist (rocm6.*|rocm7.0*
  |rocm7.1* pass through, everything else caps to rocm7.1) so future
  ROCm 10+ does not fall through to a nonexistent index
- Add sorted() to _ROCM_TORCH_INDEX lookup for defensive ordering
- Fix test_probe_timeout_handled: replace zero-assertion test with
  proper assertions verifying reinstall proceeds after timeout

* Clean up rocm_paths list construction in detect_host()

Filter None from the ROCM_PATH env var lookup at list construction time
instead of relying on the inline `if p` guard in the any() call.

* Require actual AMD GPU presence before selecting ROCm paths

All 8 reviewers across 2 cycles independently flagged that ROCm
detection used toolkit/filesystem hints (hipcc, /opt/rocm, rocm-core)
as a proxy for GPU presence, which would misroute CPU-only or NVIDIA
hosts that happen to have ROCm tools installed.

Now all 3 detection points (install.sh, install_python_stack.py,
install_llama_prebuilt.py) probe for an actual AMD GPU before
entering the ROCm path:

- install.sh: check rocminfo for gfx* GPU names, or amd-smi list
  for device rows, before version detection
- install_python_stack.py: new _has_rocm_gpu() function probes
  rocminfo and amd-smi list before _ensure_rocm_torch() proceeds
- install_llama_prebuilt.py: detect_host() probes rocminfo/amd-smi
  list instead of just checking tool existence or directory paths

Also:
- Shell test mock amd-smi now handles "list" subcommand
- Python tests updated to mock _has_rocm_gpu where needed
- Added test_no_gpu_with_rocm_tools_skips to verify the new guard
- Test index lookups now use sorted() to match production code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden hipconfig version parsing and torch probe compatibility

- Add parts[1].isdigit() check in hipconfig version parsing to handle
  versions like "6.3-HIP" where the minor component has non-numeric
  suffix (strip "-" prefix before int() conversion)
- Use getattr() in torch probe subprocess to safely handle old or
  custom torch builds that may lack torch.version.hip/cuda attributes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Strengthen AMD GPU detection and add NVIDIA precedence guard

- Change amd-smi list detection from any-non-empty-output to requiring
  "gpu" marker in output, matching the shell-side NR>1 check. Prevents
  false positives from header-only amd-smi list output.
- Add nvidia-smi check at the top of _ensure_rocm_torch() so mixed
  AMD+NVIDIA hosts preserve NVIDIA precedence (matching install.sh and
  install_llama_prebuilt.py behavior).
- Apply the same amd-smi marker fix to install_llama_prebuilt.py
  detect_host() for consistency.

* Add Windows-specific ROCm/HIP detection in detect_host()

The previous detect_host() ROCm check used rocminfo and amd-smi list
which are Linux-only tools. On Windows, has_rocm would always be False,
making the Windows HIP prebuilt path at line 1794 unreachable.

Now detect_host() uses platform-specific detection:
- Linux: rocminfo (check for gfx GPU names) or amd-smi list
- Windows: hipinfo.exe, amd-smi, or amdhip64.dll on PATH

This allows Windows AMD users to get the HIP prebuilt binary instead
of silently falling through to the CPU prebuilt.

* Add AMD ROCm gaps: Mamba/SSM source builds, GPU monitoring, Windows messaging, RDNA expansion

- worker.py: Add HIP detection to causal-conv1d/mamba-ssm probe, check
  for hipcc before ROCm source builds, improve status messages and error
  reporting, add timeout and uv support for the source build fallback
- amd.py: New AMD GPU monitoring module via amd-smi metric --json,
  mirroring nvidia.py structure (utilization, temperature, power, VRAM)
- hardware.py: Branch to amd.py when IS_ROCM is True for GPU utilization,
  visible GPU queries, and physical GPU count
- install_python_stack.py: Detect AMD GPUs on Windows and warn that
  ROCm-enabled PyTorch must be installed manually
- kernels/utils.py: Expand is_rdna() to cover RDNA2 (gfx1030-1032),
  RDNA3 (gfx1102-1103), RDNA3.5 (gfx1150-1152) alongside existing entries
- tests: Add 32 new tests covering all changes (95/95 pass)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden ROCm detection, fix VRAM heuristic, and expand RDNA2 coverage

- Windows ROCm detection: validate actual GPU presence via hipinfo/amd-smi
  output markers instead of just checking tool existence on PATH
- _ensure_rocm_torch: validate nvidia-smi actually reports a GPU before
  giving NVIDIA precedence (fixes AMD-only hosts with stale NVIDIA tools)
- amd.py _parse_numeric: handle dict-shaped metric objects from newer
  amd-smi versions ({"value": 10, "unit": "W"}) and strip MiB/GiB units
- amd.py VRAM heuristic: raise threshold from 100k to 10M to correctly
  handle MI300X (192 GB = 196608 MB) and other high-VRAM GPUs
- amd.py visible GPU: use AMD-reported GPU IDs instead of enumerate index
  so non-dense sets like CUDA_VISIBLE_DEVICES=1,3 report correctly
- install.sh: add ROCm <6.0 minimum version guard (no PyTorch wheels
  exist for older versions); fix rocm7.1* glob to not match rocm7.10+
- is_rdna: add gfx1033-1036 for RDNA2 mobile GPUs (RX 6600M etc.)
- worker.py: increase ROCm source build timeout from 600s to 1800s;
  fix success log message for ROCm source builds
- Tests: update mocks for _has_usable_nvidia_gpu, add RDNA2 target asserts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add HIP_VISIBLE_DEVICES support, unit-aware VRAM parsing, Windows GPU validation

- hardware.py: check HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES on ROCm
  before falling back to CUDA_VISIBLE_DEVICES, so multi-GPU AMD setups with
  HIP-specific env vars report the correct visible device set
- amd.py: add _parse_memory_mb() that reads "unit" from dict-shaped amd-smi
  JSON (e.g. {"value": 192, "unit": "GiB"}) and converts to MB correctly;
  fixes MI300X VRAM misreported as 0.19 GB instead of 192 GB
- install_python_stack.py: Windows AMD warning now validates actual GPU
  presence via hipinfo/amd-smi output markers before printing
- install_llama_prebuilt.py: restore amdhip64.dll fallback for Windows HIP
  detection after tool-based checks, so Windows HIP installs without CLI
  tools on PATH are still detected
- hardware.py: fix IS_ROCM comment to accurately describe its role

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix HIP_VISIBLE_DEVICES empty-string handling in GPU visibility spec

Use explicit None checks instead of Python `or` operator when reading
HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES, so that an empty string
("") is correctly honored as "no visible GPUs" rather than silently
falling through to CUDA_VISIBLE_DEVICES on mixed ROCm+CUDA systems.

* Fix IS_ROCM test assertion for multi-line formatting

* Cap torchvision/torchaudio versions, remove amdhip64.dll fallback, fix visible GPU count

- Cap torchvision<0.26.0 and torchaudio<2.11.0 alongside torch<2.11.0 in
  both install.sh and install_python_stack.py to prevent resolver from
  selecting incompatible companion packages from ROCm wheel index
- Remove amdhip64.dll fallback in Windows ROCm detection (DLL presence
  without hipinfo/amd-smi is not proof of GPU existence)
- Fix get_visible_gpu_count() to use _get_parent_visible_gpu_spec() which
  respects HIP_VISIBLE_DEVICES/ROCR_VISIBLE_DEVICES on ROCm hosts

* Attribute is_rdna() RDNA2/3/3.5/4 expansion to PR #4428

The is_rdna() expansion to cover RDNA2 (gfx1030-1036), RDNA3
(gfx1100-1103), RDNA3.5 (gfx1150-1152), and RDNA4 (gfx1200-1201)
architectures is based on the original work from PR #4428.

Co-authored-by: GoldenGrapeGentleman <yueyuan@amd.com>
Co-authored-by: billishyahao <bill.he@amd.com>

* Support AMD Radeon for studio (#4770)

Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>

* Remove ROCm test files from main PR

Move test_rocm_support.py and shell test additions to a separate PR
to keep the main ROCm support PR focused on implementation changes.

* Fix installer and hardware detection issues for PR #4720

- Fix empty _tri_arg passed to uv pip install in Radeon path (causes
  "Empty field is not allowed for PEP508" error)
- Fix Radeon fallback: use ROCm index instead of CPU-only when
  repo.radeon.com is unreachable (TORCH_INDEX_URL already has ROCm)
- Use $TORCH_CONSTRAINT in fallback paths instead of hardcoded strings
- Fix _pick_radeon_wheel: relax suffix to match manylinux_2_28_x86_64
  wheels (AMD Radeon repo does not use bare linux_x86_64 platform tag)
- Fix IS_ROCM export: use __getattr__ so callers always see the live
  value after detect_hardware() runs
- Fix apply_gpu_ids: set HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES
  on ROCm so _get_parent_visible_gpu_spec picks up narrowed GPU set
- Fix _parse_memory_mb: distinguish GB (1000 MB) from GiB (1024 MiB)
- Add amd-smi version as a fallback in _detect_rocm_version
- Fix trailing whitespace and missing newline at EOF in install.sh

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix GPU detection false positives and add missing health groups

- Fix _has_rocm_gpu() false positive: require "GPU: <number>" data rows
  from amd-smi list, not just header containing "gpu"
- Apply same fix in detect_host() in install_llama_prebuilt.py
- Add runtime_payload_health_groups for linux-rocm and windows-hip so
  partial/corrupt ROCm/HIP prebuilt installs are properly detected
- Add bitsandbytes install to Radeon fallback paths (was only in the
  success path, skipped when repo.radeon.com was unreachable)
- Keep DEVICE/CHAT_ONLY as direct imports in __init__.py (matching main)
  and only use __getattr__ for IS_ROCM

* Fix _ensure_rocm_torch and Windows AMD warning false positives

- _ensure_rocm_torch: only skip when HIP is already present, not for
  CUDA builds (which are unusable on AMD-only hosts). Fixes the case
  where a venv has a stale CUDA wheel and the repair step is skipped.
- Windows AMD warning: use GPU data row check (same as Linux fix) to
  avoid false positives from amd-smi list header-only output.

* Fix amd-smi GPU detection for GPU[N] output format

Older amd-smi versions output "GPU[0] : Card series: ..." instead of
"GPU: 0". The regex now matches both "GPU: <digit>" and "GPU[<digit>"
formats to detect actual GPU data rows.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden AMD GPU detection against false positives

- install.sh: replace weak amd-smi list check (awk 'NR>1 && NF') with
  strict pattern matching GPU data rows (/^GPU[[:space:]]*[:\[]/)
- All files: reject rocminfo gfx000 (CPU HSA agent) by requiring
  gfx[1-9] instead of gfx[0-9] in the rocminfo GPU probe
- Fixes false positives on hosts with ROCm tools but no AMD GPU

* Remove duplicate comment from pre-commit merge

* Refactor: deduplicate AMD detection, consolidate bitsandbytes, clean up imports

- Extract _has_amd_rocm_gpu() shell function to avoid duplicating the
  rocminfo/amd-smi GPU detection logic in get_torch_index_url and
  the Radeon auto-detect block
- Consolidate bitsandbytes install into a single case block after torch
  install (was duplicated 4 times across Radeon success/fallback paths)
- Move math and re imports to top of amd.py (were inline in functions)
- Add _smi_query() helper in hardware.py to centralize IS_ROCM backend
  selection for get_gpu_utilization and get_visible_gpu_utilization

Addresses Gemini code review suggestions.

* Fix VRAM parsing for string values and GB/GiB consistency

- Extract unit from string-valued VRAM fields (e.g. "192 GiB") so
  _parse_memory_mb correctly applies the unit multiplier instead of
  treating the value as bare MB
- Treat GB and GiB identically (both as binary x1024) since GPU tools
  including amd-smi use binary units even when labeling them "GB"
- Fixes incorrect VRAM reporting on MI300-class cards (was showing
  ~0.19 GB instead of 192 GB for string-valued outputs)

* Add --no-cache to uv for ROCm HIP source builds

Avoid stale cache artifacts from partial HIP source builds when
uv is used for causal-conv1d/mamba-ssm compilation on ROCm.
The pip path already uses --no-cache-dir; this adds the uv equivalent
(--no-cache) only when is_hip is True.

* Fix critical: initialize _amd_gpu_radeon before case block

_amd_gpu_radeon was only set inside the */rocm*) case arm, so on
NVIDIA/CPU/macOS paths where TORCH_INDEX_URL does not contain "rocm",
the variable was unbound. With set -u (nounset) enabled, this crashes
the installer for every non-AMD user.

Move initialization to before the case block so it is always defined.

* Fix Windows AMD: route has_rocm hosts to HIP prebuilt path

resolve_release_asset_choice was selecting windows-cpu for all Windows
x86_64 hosts including those with has_rocm=True. Windows AMD users
should fall through to resolve_upstream_asset_choice which tries the
HIP prebuilt first. Add "not host.has_rocm" guard to the published
windows-cpu selection.

* Harden ROCm detection, Radeon wheel fallback, and HIP visibility

Addresses review findings from parallel reviewers on PR #4720:

- install.sh: add _has_usable_nvidia_gpu() helper requiring nvidia-smi -L
  to actually list a GPU before treating the host as NVIDIA. Fixes the
  stale-nvidia-smi-on-PATH regression where AMD-only hosts fell into the
  CUDA branch.
- install.sh: fix hipconfig awk blocks to propagate a non-zero exit code
  when the output is not a recognisable version string, so the ||-chain
  continues to dpkg-query / rpm instead of terminating early.
- install.sh: fail-closed on Radeon wheel fallback. When torch,
  torchvision or torchaudio is missing from the Radeon repo for the
  active Python tag, fall back to the standard ROCm index instead of
  silently mixing Radeon wheels with PyPI defaults. Quote all wheel
  arguments individually so wheel filenames cannot be word-split or
  glob-expanded.
- install_llama_prebuilt.py: detect_host() now requires nvidia-smi -L to
  list a GPU before setting has_physical_nvidia. Routes AMD ROCm hosts
  with a broken leftover nvidia-smi to the ROCm path instead of
  misclassifying them as NVIDIA.
- install_llama_prebuilt.py: scan upstream assets for any rocm-<version>
  prebuilt instead of hard-coding rocm-7.2, so ROCm 6.x / 7.0 / 7.1 / 7.3+
  users pick up a matching upstream prebuilt when one exists.
- install_llama_prebuilt.py: validate_server() adds --n-gpu-layers 1 for
  linux-rocm and windows-hip hosts, so new HIP prebuilts are preflighted
  on the GPU path instead of passing validation on CPU only.
- install_llama_prebuilt.py: restore the published windows-cpu fallback
  for AMD Windows hosts without a HIP prebuilt so hash-approved bundles
  are still preferred over the raw upstream CPU asset.
- install_python_stack.py: drop the /opt/rocm / hipcc gate in
  _ensure_rocm_torch() and rely on _has_rocm_gpu(). Runtime-only ROCm
  installs (package-managed minimal installs, Radeon software) that ship
  amd-smi / rocminfo without hipcc can now repair a CPU-only venv via
  "unsloth studio update". Adds an explicit IS_WINDOWS / IS_MACOS guard.
- studio/backend/utils/hardware/amd.py: honour HIP_VISIBLE_DEVICES /
  ROCR_VISIBLE_DEVICES / CUDA_VISIBLE_DEVICES in
  get_primary_gpu_utilization(). A process restricted to GPU 2 now
  reports metrics for GPU 2 instead of physical GPU 0. Tighten the plain
  bytes unit detection to an explicit allowlist.
- studio/backend/utils/hardware/hardware.py: route
  get_backend_visible_gpu_info()'s backend_cuda_visible_devices field
  through a helper that reads HIP_VISIBLE_DEVICES on ROCm. Drop the
  unconditional "(rocm=False)" suffix in apply_gpu_ids() logs.

* Fix round 2 regressions: ROCm validate_server and Windows HIP routing

Follow-up to 810b833b addressing review findings on the first round of
hardening commits:

- install_llama_prebuilt.py validate_server: gate --n-gpu-layers on the
  resolved install_kind instead of host.has_rocm. AMD Windows hosts
  without a HIP prebuilt fall back to windows-cpu and must not be
  validated with GPU layers; thread install_kind through from the
  caller.
- install_llama_prebuilt.py resolve_release_asset_choice: reinstate the
  "not has_rocm" guard on the published windows-cpu bundle so AMD
  Windows hosts reach resolve_upstream_asset_choice() where the new
  HIP prebuilt path lives. Prefer a published windows-hip bundle first
  when one exists, fall through to upstream HIP + upstream CPU
  otherwise.
- install_llama_prebuilt.py detect_host: also set has_physical_nvidia
  when the secondary --query-gpu block confirms a working NVIDIA GPU,
  so older nvidia-smi versions without -L support do not silently skip
  the Linux diagnostics that key off has_physical_nvidia.
- install_llama_prebuilt.py: drop redundant "import re as _re" /
  "import re as _re_rocm" local aliases in favour of the existing
  top-level "import re".
- install_python_stack.py _ensure_rocm_torch: run the AMD
  bitsandbytes install unconditionally after the HIP-torch probe so
  "unsloth studio update" on venvs that already have ROCm torch still
  gains the AMD bitsandbytes build.
- install.sh: add a non-x86_64 early-exit to get_torch_index_url() so
  aarch64 / arm64 Linux hosts do not hit the ROCm wheel index
  (PyTorch only publishes ROCm wheels for linux_x86_64).
- install.sh: add bitsandbytes install to the migrated-environment
  branch so upgrades pick it up for ROCm hosts instead of only the
  fresh-install path.
- install.sh: in the Radeon wheel path, pass version constraints +
  --no-index --find-links to uv instead of explicit wheel URLs so a
  version-compatible torch / torchvision / torchaudio triple is
  resolved, rather than picking the highest-version wheel for each
  package independently.
- studio/backend/utils/hardware/amd.py _first_visible_amd_gpu_id: fall
  through to lower-priority visibility env vars when the first entry
  is malformed (leading comma, all-whitespace first token) instead of
  silently returning GPU 0.

* Fix round 3 findings: x86_64 guard, ROCm version clip, Radeon deps

Address issues surfaced by the round 3 reviewers on top of 8636fa63:

- install_python_stack.py _ensure_rocm_torch: add the same `x86_64`
  guard that install.sh already has. Linux aarch64 / arm64 ROCm hosts
  must skip the repair path entirely; PyTorch only publishes ROCm
  wheels for linux_x86_64, and without this guard
  `unsloth studio update` aborts with a missing-wheel error on non
  x86_64 hosts.
- install_llama_prebuilt.py resolve_upstream_asset_choice: add a
  best-effort _detect_host_rocm_version() helper (reading
  /opt/rocm/.info/version, amd-smi version, hipconfig --version) and
  filter rocm_candidates to entries whose major.minor is <= host
  version. Falls back to the newest candidate only when no compatible
  one exists, so a ROCm 6.4 host downloads rocm-6.4 instead of being
  handed the numerically newest rocm-7.2 bundle (which fails preflight
  and forces a source build).
- install.sh: remove the round 2 --no-index switch from the Radeon
  wheel branch. --no-index forced uv to ignore PyPI entirely, which
  broke transitive dependency resolution (filelock, sympy, networkx,
  jinja2, fsspec, setuptools, typing-extensions, ...) on a fresh venv.
  Restore the round 1 explicit wheel URL invocation but add a
  torch / torchvision / torchaudio version-pair sanity check so a
  mismatched trio (e.g. torch 2.9.1 + torchvision 0.23.0 + torchaudio
  2.9.0) falls back to the standard ROCm index instead of installing a
  broken combination.
- install_python_stack.py _ensure_rocm_torch: restructure the
  "tag is None" path so it no longer short-circuits the bitsandbytes
  install. On a ROCm runtime older than anything in
  _ROCM_TORCH_INDEX, print the "no wheel" warning but still run the
  AMD bitsandbytes install.
- studio/backend/core/training/worker.py: restore the pre-PR
  "no timeout" behaviour for non-HIP causal-conv1d / mamba-ssm source
  builds. The round 2 "timeout = 1800 if is_hip else 300" cap aborts
  slow non-HIP builds (Linux aarch64, unsupported torch/CUDA combos)
  after 5 minutes; omit timeout for the non-HIP branch so the cap
  only applies to ROCm source builds.

* Fix round 4 findings: apply_gpu_ids env inheritance, Radeon X.Y, bitsandbytes gate

Address remaining issues surfaced by the round 4 reviewers:

- studio/backend/utils/hardware/hardware.py apply_gpu_ids: mirror the
  selection into HIP_VISIBLE_DEVICES / ROCR_VISIBLE_DEVICES whenever
  the caller already had a ROCm visibility env var set, not only when
  IS_ROCM has already been set by detect_hardware(). Training and
  inference workers call apply_gpu_ids() before detect_hardware()
  runs, so the old guard would leave a forked ROCm worker with a
  stale HIP_VISIBLE_DEVICES mask that no longer matched the
  narrowed CUDA_VISIBLE_DEVICES selection.
- install.sh get_radeon_wheel_url: accept X.Y ROCm versions in
  addition to X.Y.Z. The `/opt/rocm/.info/version` file and some
  hipconfig versions report only two components, and the Radeon
  repository publishes both rocm-rel-X.Y.Z/ and rocm-rel-X.Y/
  directories, so treating X.Y as invalid caused Radeon hosts to fall
  back to the generic ROCm index even when a matching AMD wheel set
  existed.
- install_python_stack.py _ensure_rocm_torch: only install the AMD
  bitsandbytes build when the venv actually has a ROCm-compatible
  torch (either already present or just installed by this function).
  Previously the bitsandbytes install ran unconditionally, which
  could leave an AMD bitsandbytes layered on top of a CPU/CUDA torch
  on hosts where the ROCm runtime is older than any entry in
  _ROCM_TORCH_INDEX. Also add --force-reinstall so an existing
  CPU/CUDA bitsandbytes is replaced by the AMD build during upgrades.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gemini findings: amd-smi metric envelope validation and dict-wrapped GPU id

Two medium-severity defensive fixes from the gemini-code-assist review on
the AMD monitoring backend:

1. _extract_gpu_metrics may return a dict where every value is None when
   amd-smi succeeds (zero exit) but the JSON envelope contains no usable
   fields (error response, unsupported card). The new _has_real_metrics
   helper lets get_primary_gpu_utilization surface available:False and
   lets get_visible_gpu_utilization skip ghost device rows so the UI
   does not render placeholder cards with empty numbers.

2. Newer amd-smi versions wrap scalar fields as {"value": 0, "unit":
   "none"}, including the per-GPU id. The previous int(raw_id) call
   silently fell back to the enumeration index in that case, losing the
   real GPU id. Routing raw_id through the existing _parse_numeric
   helper handles bare ints, floats, strings, and the dict shape
   uniformly, with a debug log on parse failure.

* Fix gemini round 2 findings: explicit length guard on ROCm version file parser

Both _detect_rocm_version (install_python_stack.py) and
_detect_host_rocm_version (install_llama_prebuilt.py) read /opt/rocm/.info/version
or $ROCM_PATH/lib/rocm_version, split on "." and unconditionally accessed
parts[1]. The surrounding broad `except Exception: pass` already swallowed
the resulting IndexError, so a one-component file like "6\n" did fall
through to the next detection source -- but the control flow relied on
exception handling instead of an explicit check.

Add `if len(parts) >= 2:` guards in both helpers so the loop falls through
on its own without raising. Behaviour is unchanged for the common multi-
component case; the previously-silent IndexError path becomes an explicit
no-op.

* Fix gemini round 3: include has_rocm in validate_server fallback path

When validate_server is called without an explicit install_kind (older
call sites that have not been updated), the fallback was only enabling
--n-gpu-layers for NVIDIA and macOS arm64 hosts. AMD ROCm Linux hosts
fell through to the CPU validation path even though the prebuilt being
exercised was a HIP binary.

Add host.has_rocm to the fallback expression so the GPU offload flag is
applied consistently with the install_kind=='linux-rocm' / 'windows-hip'
branches above.

* Fix gemini round 4: remove risky bytes-vs-MB heuristic in _parse_memory_mb

The previous heuristic divided any bare number above 10_000_000 by
1024*1024 on the assumption that large unit-less values were bytes.
This misclassified small VRAM allocations: 5 MB of used VRAM reported
as 5_242_880 bytes without a unit would be taken at face value and
render as 5_242_880 MB (~5 TB) in the monitoring UI.

Modern amd-smi always provides explicit units (MiB/GiB dict form),
and legacy amd-smi returns bare numbers in MB -- the heuristic never
had a real workload to handle. Drop it and default to MB for bare
numeric input, keeping the existing unit-aware branches for dict /
string inputs unchanged.

The unrelated gemini suggestion to "default minor to 0" in the
amd-smi version awk parser was intentionally NOT applied: rocm7.0
and rocm7.1 ship different wheel sets, so silently substituting 0
for a missing minor could install the wrong wheels. The existing
reject-and-fall-through behaviour is safer.

* Fix gemini round 5: POSIX compliance and leading-comma visibility parsing

Three medium findings from gemini-code-assist addressed in this commit:

1. _pick_radeon_wheel used grep -o and sort -V, both GNU extensions
   that are not in POSIX and break on BSD/BusyBox coreutils. install.sh
   has a #!/bin/sh shebang so the whole pipeline was rewritten as a
   single awk script that extracts all href="..." hits on each line,
   filters to wheels matching the package prefix and python tag, and
   picks the newest version via zero-padded lexical comparison. No
   external sort or grep is needed.

2. _first_visible_amd_gpu_id in the AMD monitoring backend treated a
   leading comma (e.g. HIP_VISIBLE_DEVICES=",1") as "fall through to
   the next env var", which is surprising given the clear intent to
   narrow to device 1. Filter empty tokens after the split and return
   the first real one. An all-commas value ("," / ",,,") still falls
   through because no real tokens exist; the empty-string and "-1"
   explicit-zero cases are unchanged.

The unrelated amd-smi version awk parser suggestion was not applied
(see round 4 commit message for rationale: defaulting a missing minor
to 0 could silently install the wrong ROCm wheel set).

* Fix 20-reviewer.py findings: base drift, Radeon %2B, dpkg/rpm fallback, bnb, backend label

Consolidated fix batch from a 20-parallel reviewer.py run on the current
head. Each fix is drawn from a high-consensus finding and addresses a
real bug or feature gap, not a stylistic preference.

1. install.sh: bump `unsloth>=2026.4.2` -> `unsloth>=2026.4.4` at five
   call sites so this branch no longer regresses main's version floor
   (main bumped to 2026.4.4 in #4876). Without this, merging 4720 would
   silently downgrade the minimum version pin for fresh installs.

2. install.sh: URL-decode Radeon wheel names before extracting the
   torch / torchvision / torchaudio version strings. Real wheel URLs
   from repo.radeon.com are percent-encoded ("torch-2.10.0%2Brocm7.2.0...")
   so the previous `[+-]` terminator in the sed regex never matched,
   `_torch_ver` stayed empty, `_radeon_versions_match` stayed false,
   and every Radeon consumer install silently fell back to the generic
   ROCm index. Now decode %2B -> + first, then extract, then validate.

3. install.sh: the two AMD bitsandbytes install lines were running
   `uv pip install "bitsandbytes>=0.49.1"` without `--force-reinstall`,
   so upgrades where the venv already has a CPU/CUDA bitsandbytes
   satisfying the constraint would keep the stale non-AMD wheel. Add
   `--force-reinstall --no-cache-dir` to both call sites, matching the
   pattern already used in install_python_stack.py::_ensure_rocm_torch.

4. install_python_stack.py and install_llama_prebuilt.py: add
   `dpkg-query -W rocm-core` and `rpm -q rocm-core` fallbacks to the
   Python-side ROCm version detectors so they match the chain in
   install.sh::get_torch_index_url. Package-managed ROCm installs
   (Debian/Ubuntu/RHEL/Fedora distro packages) can expose GPUs via
   rocminfo/amd-smi but still lack /opt/rocm/.info/version, hipconfig,
   or amd-smi `version` output -- without these fallbacks, `unsloth
   studio update` on such hosts returned None and skipped the ROCm
   torch repair. Also strip the dpkg epoch prefix ("1:6.3.0-1") before
   parsing so epoch-annotated packages parse correctly.

5. hardware.py: add a `_backend_label(device)` helper that returns
   "rocm" when IS_ROCM is set and the device is DeviceType.CUDA, and
   use it for every `"backend": ...` emission in JSON responses served
   to the Studio frontend. Internally we still represent ROCm hosts as
   DeviceType.CUDA (ROCm torch reuses the whole torch.cuda.* API
   surface), but the user-facing API now correctly reports "rocm" on
   AMD boxes instead of labeling them as "cuda".

All 250 simulation scenarios pass (was 233 before this batch: added 17
new regression tests covering the version pin, %2B decoding, bnb
force-reinstall flags, dpkg/rpm fallback presence, and the
_backend_label helper's four-way truth table).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gemini round 6 + URL audit: amd.py defensive checks, rocm6.5+ clip to 6.4

Two rounds of fixes in one commit, plus a full URL audit of every PyPI /
download.pytorch.org / repo.radeon.com reference the PR introduces.

amd.py (4 medium gemini findings on commit b3627bc2):

1. _extract_gpu_metrics used `and vram_total_mb` as part of the vram_util
   gate. The follow-up `vram_total_mb > 0` already handles the division
   guard, but the truthiness check was redundant and slightly surprising
   for a 0.0 valid value. Replace with explicit `is not None and > 0`
   for both vram_util and power_util.

2. get_physical_gpu_count called `data.get("gpu", ...)` without guarding
   for non-dict envelopes. A scalar / string JSON response from amd-smi
   would raise AttributeError. Add an isinstance(data, dict) check and
   return None for unexpected shapes.

3. get_visible_gpu_utilization had the same .get() exposure on the outer
   envelope. Rewrite the gpu_list extraction as an explicit
   list/dict/else cascade so a malformed scalar envelope produces
   gpu_list=[data] and continues without raising.

4. The same function's per-entry loop also called gpu_data.get() on
   whatever was inside gpu_list. If a scalar ever leaks into the list
   (directly or via the previous fix's fallback), _extract_gpu_metrics
   would raise on the first .get() inside the helper. Skip non-dict
   entries in the loop before extracting metrics.

install.sh (URL audit finding, previously flagged by 20-reviewer as #13):

5. get_torch_index_url used `rocm6.*` in the rocm tag case statement,
   which matched rocm6.5 and rocm6.6 and emitted
   download.pytorch.org/whl/rocm6.5 -- which returns HTTP 403 because
   PyTorch only publishes rocm 5.7, 6.0-6.4, 7.0-7.2. Enumerate the
   supported 6.x minors explicitly and add a rocm6.* fallback branch
   that clips to rocm6.4 (the last supported 6.x wheel set).

URL audit results (all URLs PR 4720 references):
- 14/14 download.pytorch.org/whl/{cpu,cu118,cu124,cu126,cu128,cu130,
  rocm6.0..6.4,rocm7.0..7.2} return HTTP 200.
- 9/9 repo.radeon.com/rocm/manylinux/rocm-rel-{5.7,6.0,6.1,6.2,6.3,
  6.4,7.0,7.1,7.2}/ return HTTP 200.
- X.Y.Z patch directories exist for 7.0.2, 7.1.1, 7.2.1 but NOT for
  6.3.0, 6.4.0, 6.2.1 -- install.sh already handles this via the X.Y.Z
  -> X.Y fallback sed in the Radeon wheel install block.
- Docs links (rocm.docs.amd.com, docs.unsloth.ai AMD guide) and the
  llama.cpp GitHub releases API endpoint all return 200.

Test suite: 255 -> 258. New regression coverage:
- U17: get_physical_gpu_count tolerates scalar amd-smi envelope
- U18: get_visible_gpu_utilization tolerates scalar envelope
- U19a-c: vram_util / power_util return None on zero total, but
  vram_total_gb still echoes 0.0 (not None)
- A_rocm{6.5,6.6,6.9}_clips_to_rocm64: install.sh clips unsupported
  6.x minors to rocm6.4 instead of producing a 403 index URL

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix reviewer.py round 2: tokenizer AMD multi-GPU, --no-torch bnb, main.py backend label

Three high-confidence findings from a second 20-parallel reviewer.py run
on commit 7effb3ae. Triaged 15 total findings and applied the three that
were confirmed as real bugs; the rest were either false positives (e.g.
"migrated AMD venv not repaired" -- _ensure_rocm_torch runs downstream
via setup.sh regardless), design decisions (e.g. visibility mask env
vars not consulted in installer detection), or edge cases the existing
fallback logic already handles.

1. unsloth/tokenizer_utils.py [6/20]: the multi-GPU guard's shell probe
   runs `nvidia-smi --query-gpu=memory.used`, catches the failure, then
   only raises if `torch.cuda.is_available()` is False. On ROCm torch,
   torch.cuda.is_available() returns True (ROCm reuses the torch.cuda.*
   API), so the guard becomes dead code on AMD hosts and multi-GPU AMD
   setups slip through even though unsloth does not support them yet.
   Add a torch.cuda.device_count() > 1 fallback inside the except so
   AMD multi-visible-device setups are flagged consistently with the
   original CUDA memory check.

2. install.sh [1/20]: the fresh-install bitsandbytes block for AMD ROCm
   ran unconditionally when TORCH_INDEX_URL matched `*/rocm*`, even when
   SKIP_TORCH=true (from --no-torch or Intel Mac auto-detect). A user
   running `install.sh --no-torch` on an AMD host would still pull in
   bitsandbytes despite explicitly asking for GGUF-only mode. Wrap the
   case block in an outer `[ "$SKIP_TORCH" = false ]` guard.

3. studio/backend/main.py [3/20]: the /api/system endpoint returned
   `"device_backend": get_device().value`, which is "cuda" on ROCm
   hosts (because ROCm torch piggybacks on torch.cuda). Other endpoints
   (hardware.py) already use the _backend_label helper which swaps
   "cuda" -> "rocm" when IS_ROCM. Route /api/system through the same
   helper so the Studio UI reports the backend consistently across all
   endpoints.

4. studio/backend/tests/test_utils.py: update test_backend_matches_device
   to call _backend_label(get_device()) instead of raw get_device().value
   so the test matches the new contract and still passes on CUDA hosts.

Tests: 258 -> 261. New regression coverage:
- X08 main.py /api/system uses _backend_label
- X09 tokenizer multi-GPU guard has device_count() fallback
- X10 fresh-install bnb case block gated on SKIP_TORCH=false

* fix: prevent bitsandbytes from overwriting ROCm torch with CUDA wheels

During install, bitsandbytes was installed without --no-deps, causing
uv to resolve torch from PyPI (CUDA build) and silently overwrite the
ROCm wheels that were just installed in the previous step.

This happened in three places:
- install.sh: bitsandbytes install in both migrated and fresh paths
- install_python_stack.py: bitsandbytes install inside _ensure_rocm_torch()

Additionally, multiple install steps in install_python_stack.py (extras,
overrides, studio deps) can pull in CUDA torch via transitive
dependencies. A final _ensure_rocm_torch() call at the end of the
install sequence ensures ROCm torch is always in place at runtime.

All changes are gated behind ROCm-specific conditions and do not affect
NVIDIA, CPU-only, macOS, or Windows install paths.

Tested on AMD Instinct MI300X VF with ROCm 7.2.0 -- confirms
torch==2.10.0+rocm7.1 with HIP 7.1.25424 after install.

* fix: ROCm inference fallback -- skip Unsloth patching and bnb 4-bit on HIP

On AMD ROCm (HIP), two issues prevent the normal Unsloth inference path:

1. Unsloth's global monkey-patching of transformers model classes
   (LlamaRotaryEmbedding, attention modules) triggers
   _assert_async_cuda_kernel crashes on HIP during generation.
   Training uses different code paths and works fine.

2. bitsandbytes 4-bit matmul kernels also trigger HIP assertion
   failures on MI300X (CDNA3 / gfx942), even without Unsloth patching.

This commit adds a ROCm-specific inference fallback that:
- Skips importing Unsloth at module level (prevents global patching)
- Loads models in 16-bit with plain transformers + PEFT instead
- Resolves pre-quantized model names (e.g. "xxx-bnb-4bit" -> "xxx")
  since pre-quantized HF repos still trigger bnb codepaths
- Guards get_chat_template calls (unavailable without Unsloth import)
- Fixes max_seq_length=0 being passed to from_pretrained (GGUF
  semantics don't apply to transformers path)

The NVIDIA path is completely unchanged -- Unsloth import and
for_inference() optimization remain active. GGUF inference (via
llama-server/HIP) is unaffected since it never imports Python model
classes. AMD GPUs typically have large VRAM (e.g. 192GB on MI300X)
so 16-bit loading is practical for inference.

Tested on AMD Instinct MI300X VF (ROCm 7.2, HIP 7.1.25424):
- Simple generation: PASS
- Compare mode (base vs finetuned): PASS
- GGUF inference + tool calling: PASS (unaffected by this change)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: guard audio/vision inference on ROCm, remove unused import

- Add clear RuntimeError for audio/vision model inference on ROCm
  (these paths use Unsloth's FastModel/FastVisionModel which would
  crash on HIP; GGUF inference is the supported path on AMD)
- Remove unused `import os as _os` from the ROCm changes

* fix: amd-smi parsing for newer output format (gpu_data wrapper, mem_usage, temperature)

amd-smi on recent ROCm versions (7.x) wraps metric output in a
{"gpu_data": [...]} envelope instead of returning a raw list. This
caused get_primary_gpu_utilization() and get_visible_gpu_utilization()
to fail silently (returning available=False) because the GPU data
dict was never unwrapped.

Additionally:
- VRAM data moved from "vram" to "mem_usage" with "total_vram" /
  "used_vram" keys. Added fallback key lookup.
- Temperature "edge" sensor returns "N/A" on MI300X VF; the previous
  dict.get() chain returned the "N/A" string instead of falling
  through to "hotspot". Changed to a loop that checks each key until
  a parseable value is found.

Tested on AMD Instinct MI300X VF (ROCm 7.2, amd-smi 24.x):
- GPU utilization: 0% (idle), up to 100% during training
- Temperature: 40-44C (from hotspot sensor)
- VRAM: 0.28/191.69 GB (idle)
- Power: 158-211W draw

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bug fix detecting radeon (#4940)

* Bug fix detecting radeon

* Expanding GPU target for gfx1100*

* Generalize gfx family-prefix filter to cover gfx10/gfx12 as well

rocminfo on ROCm 6.1+ emits LLVM generic-family ISA lines alongside the
specific GPU (e.g. gfx11-generic next to gfx1100). The outer grep captures
the bare family prefix from the generic line, and passing that to
-DGPU_TARGETS breaks the HIP build because clang only accepts specific
gfxNNN ids.

The previous filter only special-cased gfx11. Generalize it so any bare
2-digit family prefix (gfx10, gfx11, gfx12, ...) is dropped whenever a
specific sibling target is present in the same list. No real AMD GPU has
a 2-digit gfx id, so the filter can only ever drop family prefixes and
never a real target.

Covers the existing gfx11 cases unchanged, and extends the same fix to
gfx10-1-generic / gfx10-3-generic (RDNA1/2) and gfx12-generic (RDNA4),
which would otherwise hit the same build failure on newer rocminfo.

---------

Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>

---------

Co-authored-by: Eda Z <eda.zhou@amd.com>
Co-authored-by: GoldenGrapeGentleman <yueyuan@amd.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: billishyahao <bill.he@amd.com>
Co-authored-by: Iswarya Alex <47045679+iswaryaalex@users.noreply.github.com>
Co-authored-by: Iswarya Alex <iswarya.alex@amd.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-10 01:56:12 -07:00
Roland Tannous
33503ea248
Revert "updated models template mappers. added lfm2.5vl450m to transformers 5…" (#4945)
This reverts commit bcf4fd6bd3.
2026-04-09 23:14:57 -07:00
Roland Tannous
bcf4fd6bd3
updated models template mappers. added lfm2.5vl450m to transformers 5… (#4939)
* updated models template mappers. added lfm2.5vl450m to transformers 5.3.0 whitelist

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-09 23:36:42 +04:00
Ricardo-M-L
d5525e8bbb
fix: check find() return value before adding offset in try_fix_tokenizer (#4923)
* fix: check find() return value before adding offset in try_fix_tokenizer

The `str.find()` result was checked for -1 only after adding
`len(find_text)`, turning the guard into dead code. When the substring
is absent, `start` becomes `len(find_text) - 1` (a positive number),
so the `if start == -1: continue` never triggers and the subsequent
slice extracts garbage from the tokenizer string.

Split the find and offset into two steps so the -1 check works correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add defensive guards for token_id None and end find() returning -1

- Skip loop iteration early when token_id is None to avoid constructing
  a find_text that can never match valid JSON
- Guard end = tokenizer_string.find('",', start) against -1 to prevent
  silent garbage extraction from malformed tokenizer strings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-09 06:15:46 -07:00
Lee Jackson
dc16e0c65b
Studio: keep chat input visible and fix compare pane clipping (#4924)
* fix(chat): sticky composer bar in thread

* fix(chat): fix compare pane clipping

* fix(chat): tighten scroll-to-bottom placement and compare footer spacing

* Fix TypeScript build break and clean up ViewportFooter classes

- Remove unused `compact` prop from ThreadScrollToBottom call site
  (component is FC with no props, passing it caused TS2322)
- Extract shared classes (sticky, bottom-0, z-20, bg-transparent) from
  ternary branches into the unconditional className string
- Restore `relative` on normal-mode footer so the inner absolute
  bg-background strip has a positioning context
- Remove redundant md:pb-3 / md:pb-4 (same value as base pb-3 / pb-4)
- Remove no-op `sticky bottom-0` from SharedComposer wrapper in both
  LoraCompareContent and GeneralCompareContent (flex layout with
  shrink-0 already pins it at the bottom; parent has no scrollable
  overflow for sticky to bind to)
- Fix truncated comment on pointer-events rationale

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-09 06:00:56 -07:00
kiankyars
ad5972492d
Fix raw text paragraph break normalization (#4884)
* Fix raw text paragraph break normalization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Normalize horizontal whitespace before stripping non-ASCII and collapse leftover doubles

Run the [^\S\n]+ horizontal-whitespace collapse before the non-ASCII strip
so that Unicode whitespace (\u00A0, \u202F, \u2009, \u3000, \v, \f, etc.)
becomes a single ASCII space instead of being deleted outright. The prior
ordering silently merged adjacent words on HTML/PDF/OCR-sourced text:
"hello\u00a0world" used to produce "helloworld" after this PR; it now
produces "hello world".

Also drop \t from the allow-list since the horizontal-whitespace collapse
already normalizes tabs to a single space, and add a targeted [ ]{2,} pass
right after the non-ASCII strip so that a non-whitespace non-ASCII character
sitting between two spaces ("word1 (c) word2") does not leave an interior
double space. Without this extra pass, clean_text was not idempotent on
such inputs: the first call produced "word1  word2" and only the second
call collapsed it to "word1 word2". Fuzz testing over 10000 random inputs
now satisfies the idempotence invariant in every case.

* Add regression tests for Unicode/control whitespace and non-ASCII edge cases

Cover:
- Unicode horizontal whitespace separators (NBSP, narrow NBSP, thin space,
  en/em space, ideographic space, vertical tab, form feed) normalizing to
  a single ASCII space instead of being deleted.
- Mixed paragraph + Unicode whitespace realistic input ("Section\u00a01\r\n\r\nBody\ftext\u202Fhere").
- Tab collapsing and space trimming around newlines.
- Non-whitespace non-ASCII characters (copyright, accented letters, emoji)
  sitting between spaces: must not leave an interior double space, and
  clean_text must be idempotent on these inputs.
- Non-ASCII characters adjacent to a newline: stripping must not leave
  stray leading or trailing spaces on the neighbouring line, and must not
  swallow an adjacent paragraph break.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-09 04:45:43 -07:00
cheehook
7aa442289b
Fix Mistral DPO/preference training crash on non-xformers platforms (e.g. Intel XPU) (#4889)
* Fix Mistral training crash when xformers is unavailable

* Fix/adjust Mistral DPO training crash fix for PR #4889

- Clarify comment in MistralForCausalLM_fast_forward: the DPO embed-masking
  block runs BEFORE attention_mask is nulled out, and it is the consumer that
  requires a 2D mask.
- Add defensive attention_mask.ndim == 2 guard to the LlamaModel_fast_forward
  DPO embed-masking block so it self-protects if a 4D mask ever reaches it.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-09 04:38:44 -07:00
Daniel Han
da2ef6dce6
Only run ldconfig CUDA-linking recovery when we have permission (#4930)
* Only run ldconfig CUDA-linking recovery when we have permission

When `import unsloth` runs on a non-root environment (shared HPC,
locked-down container, CI runner, etc.) the CUDA-linking recovery path
shells out to `os.system("ldconfig /usr/lib64-nvidia")`, which fails
loudly with "Permission denied". It's especially noisy for users who
don't even have bitsandbytes installed - they're doing 16bit or full
finetuning and the line immediately above told them "16bit and full
finetuning works!". The reason the recovery runs at all in that case
is that `bnb.functional.lib.cdequantize_blockwise_fp32` raises
AttributeError on `bnb is None`, the bare `except:` swallows it, and
the code drops into the recovery unconditionally.

Fix: gate the recovery body on `os.geteuid() == 0`. When we don't
have permission to run ldconfig, silently skip the recovery. When we
do, the recovery runs UNCHANGED - same `os.system()` calls, same
reload + retry, same warnings. `libcuda_dirs()` is used by both triton
and bitsandbytes, so we still want to run the recovery whenever we
have permission, regardless of whether bnb is installed.

For non-root users who DO have bitsandbytes installed and broken,
emit a single remediation warning telling them how to fix it manually
(`sudo ldconfig /usr/lib64-nvidia`). This preserves the diagnostic
guidance from the original code without the Permission denied noise.

Scope:
- Only the `DEVICE_TYPE == "cuda"` branch is touched.
- The `hip` (AMD ROCm) and `xpu` (Intel) branches are unchanged.
- On a real CUDA box running as root, behavior is byte-identical to
  main: same os.system() calls, same reload, same retry, same warnings.
  AST-verified by /tmp/verify_minimal/verify.py.
- `hasattr(os, "geteuid")` guards against Windows where `os.geteuid`
  doesn't exist.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <info@unsloth.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-09 00:07:25 -07:00
dependabot[bot]
5fa8683b27
build(deps): bump the bun-frontend group across 1 directory with 16 updates (#4586)
* build(deps): bump the bun-frontend group across 1 directory with 16 updates

Bumps the bun-frontend group with 16 updates in the /studio/frontend directory:

| Package | From | To |
| --- | --- | --- |
| [@dagrejs/dagre](https://github.com/dagrejs/dagre) | `2.0.4` | `3.0.0` |
| [@dagrejs/graphlib](https://github.com/dagrejs/graphlib) | `3.0.4` | `4.0.1` |
| @hugeicons/core-free-icons | `3.3.0` | `4.0.0` |
| [@streamdown/cjk](https://github.com/vercel/streamdown/tree/HEAD/packages/streamdown-cjk) | `1.0.2` | `1.0.3` |
| [@streamdown/code](https://github.com/vercel/streamdown/tree/HEAD/packages/streamdown-code) | `1.0.2` | `1.1.1` |
| [lucide-react](https://github.com/lucide-icons/lucide/tree/HEAD/packages/lucide-react) | `0.577.0` | `1.6.0` |
| [recharts](https://github.com/recharts/recharts) | `3.7.0` | `3.8.0` |
| [shadcn](https://github.com/shadcn-ui/ui/tree/HEAD/packages/shadcn) | `3.8.5` | `4.1.0` |
| [streamdown](https://github.com/vercel/streamdown/tree/HEAD/packages/streamdown) | `2.3.0` | `2.5.0` |
| [@biomejs/biome](https://github.com/biomejs/biome/tree/HEAD/packages/@biomejs/biome) | `1.9.4` | `2.4.8` |
| [@eslint/js](https://github.com/eslint/eslint/tree/HEAD/packages/js) | `9.39.4` | `10.0.1` |
| [@types/node](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/node) | `24.12.0` | `25.5.0` |
| [eslint](https://github.com/eslint/eslint) | `9.39.4` | `10.1.0` |
| [eslint-plugin-react-refresh](https://github.com/ArnaudBarre/eslint-plugin-react-refresh) | `0.4.26` | `0.5.2` |
| [globals](https://github.com/sindresorhus/globals) | `16.5.0` | `17.4.0` |
| [typescript](https://github.com/microsoft/TypeScript) | `5.9.3` | `6.0.2` |



Updates `@dagrejs/dagre` from 2.0.4 to 3.0.0
- [Release notes](https://github.com/dagrejs/dagre/releases)
- [Changelog](https://github.com/dagrejs/dagre/blob/master/changelog.md)
- [Commits](https://github.com/dagrejs/dagre/compare/v2.0.4...v3.0.0)

Updates `@dagrejs/graphlib` from 3.0.4 to 4.0.1
- [Release notes](https://github.com/dagrejs/graphlib/releases)
- [Changelog](https://github.com/dagrejs/graphlib/blob/master/changelog.md)
- [Commits](https://github.com/dagrejs/graphlib/compare/v3.0.4...v4.0.1)

Updates `@hugeicons/core-free-icons` from 3.3.0 to 4.0.0

Updates `@streamdown/cjk` from 1.0.2 to 1.0.3
- [Release notes](https://github.com/vercel/streamdown/releases)
- [Changelog](https://github.com/vercel/streamdown/blob/main/packages/streamdown-cjk/CHANGELOG.md)
- [Commits](https://github.com/vercel/streamdown/commits/@streamdown/cjk@1.0.3/packages/streamdown-cjk)

Updates `@streamdown/code` from 1.0.2 to 1.1.1
- [Release notes](https://github.com/vercel/streamdown/releases)
- [Changelog](https://github.com/vercel/streamdown/blob/main/packages/streamdown-code/CHANGELOG.md)
- [Commits](https://github.com/vercel/streamdown/commits/@streamdown/code@1.1.1/packages/streamdown-code)

Updates `lucide-react` from 0.577.0 to 1.6.0
- [Release notes](https://github.com/lucide-icons/lucide/releases)
- [Commits](https://github.com/lucide-icons/lucide/commits/1.6.0/packages/lucide-react)

Updates `recharts` from 3.7.0 to 3.8.0
- [Release notes](https://github.com/recharts/recharts/releases)
- [Changelog](https://github.com/recharts/recharts/blob/main/CHANGELOG.md)
- [Commits](https://github.com/recharts/recharts/compare/v3.7.0...v3.8.0)

Updates `shadcn` from 3.8.5 to 4.1.0
- [Release notes](https://github.com/shadcn-ui/ui/releases)
- [Changelog](https://github.com/shadcn-ui/ui/blob/main/packages/shadcn/CHANGELOG.md)
- [Commits](https://github.com/shadcn-ui/ui/commits/shadcn@4.1.0/packages/shadcn)

Updates `streamdown` from 2.3.0 to 2.5.0
- [Release notes](https://github.com/vercel/streamdown/releases)
- [Changelog](https://github.com/vercel/streamdown/blob/main/packages/streamdown/CHANGELOG.md)
- [Commits](https://github.com/vercel/streamdown/commits/streamdown@2.5.0/packages/streamdown)

Updates `@biomejs/biome` from 1.9.4 to 2.4.8
- [Release notes](https://github.com/biomejs/biome/releases)
- [Changelog](https://github.com/biomejs/biome/blob/main/packages/@biomejs/biome/CHANGELOG.md)
- [Commits](https://github.com/biomejs/biome/commits/@biomejs/biome@2.4.8/packages/@biomejs/biome)

Updates `@eslint/js` from 9.39.4 to 10.0.1
- [Release notes](https://github.com/eslint/eslint/releases)
- [Commits](https://github.com/eslint/eslint/commits/v10.0.1/packages/js)

Updates `@types/node` from 24.12.0 to 25.5.0
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/node)

Updates `eslint` from 9.39.4 to 10.1.0
- [Release notes](https://github.com/eslint/eslint/releases)
- [Commits](https://github.com/eslint/eslint/compare/v9.39.4...v10.1.0)

Updates `eslint-plugin-react-refresh` from 0.4.26 to 0.5.2
- [Release notes](https://github.com/ArnaudBarre/eslint-plugin-react-refresh/releases)
- [Changelog](https://github.com/ArnaudBarre/eslint-plugin-react-refresh/blob/main/CHANGELOG.md)
- [Commits](https://github.com/ArnaudBarre/eslint-plugin-react-refresh/compare/v0.4.26...v0.5.2)

Updates `globals` from 16.5.0 to 17.4.0
- [Release notes](https://github.com/sindresorhus/globals/releases)
- [Commits](https://github.com/sindresorhus/globals/compare/v16.5.0...v17.4.0)

Updates `typescript` from 5.9.3 to 6.0.2
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.9.3...v6.0.2)

---
updated-dependencies:
- dependency-name: "@dagrejs/dagre"
  dependency-version: 3.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: "@dagrejs/graphlib"
  dependency-version: 4.0.1
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: "@hugeicons/core-free-icons"
  dependency-version: 4.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: "@streamdown/cjk"
  dependency-version: 1.0.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: bun-frontend
- dependency-name: "@streamdown/code"
  dependency-version: 1.1.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: bun-frontend
- dependency-name: lucide-react
  dependency-version: 1.6.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: recharts
  dependency-version: 3.8.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: bun-frontend
- dependency-name: shadcn
  dependency-version: 4.1.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: streamdown
  dependency-version: 2.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: bun-frontend
- dependency-name: "@biomejs/biome"
  dependency-version: 2.4.8
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: "@eslint/js"
  dependency-version: 10.0.1
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: "@types/node"
  dependency-version: 25.5.0
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: eslint
  dependency-version: 10.1.0
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: eslint-plugin-react-refresh
  dependency-version: 0.5.2
  dependency-type: direct:development
  update-type: version-update:semver-minor
  dependency-group: bun-frontend
- dependency-name: globals
  dependency-version: 17.4.0
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
- dependency-name: typescript
  dependency-version: 6.0.2
  dependency-type: direct:development
  update-type: version-update:semver-major
  dependency-group: bun-frontend
...

Signed-off-by: dependabot[bot] <support@github.com>

* Revert dagrejs upgrades

Keep @dagrejs/dagre at ^2.0.4 and @dagrejs/graphlib at ^3.0.4.

* Revert biome, eslint, typescript, and recharts upgrades

These upgrades break studio/frontend locally:

- @biomejs/biome 2.4.10 fails to parse the existing biome.json
  (files.ignore and organizeImports keys removed in v2; schema
  version mismatch).
- typescript 6.0.2 emits TS5101 on tsconfig.app.json baseUrl
  ("Option 'baseUrl' is deprecated and will stop functioning in
  TypeScript 7.0"), so tsc -b exits 2.
- eslint 10.2.0 conflicts with eslint-plugin-react-hooks@7.0.1,
  which peers on eslint ^9; npm install fails with ERESOLVE.
- recharts 3.8.1 widened LegendPayload.dataKey to include a
  function type, which breaks the React key={item.dataKey} usage
  in src/components/ui/chart.tsx (TS2322).

Hold these at their current pinned versions until the upstream
peer deps and config migrations are ready.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-08 04:34:33 -07:00
Wasim Yousef Said
8e977445d4
Let recipes use the model loaded in Chat (#4840)
* feat: inject local model provider into recipe jobs via JWT

* feat: auto-generate JWT for local model providers in recipes

* feat: add is_local flag to model provider config types and utils

* fix(studio): skip endpoint validation for local providers

* feat(studio): add local/external model source toggle to provider dialog

* feat(studio): thread localProviderNames through model config dialog chain

* feat(studio): show 'Local model (Chat)' label for local model_provider configs

* fix: hardcode loopback for local endpoint, clear stale creds on toggle

* fix: document TOCTOU/JWT rotation, add deferred import comments, fix is_local serialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): clear stale local model state on provider toggle and validation

* fix(studio): override empty local endpoint in validation and skip model gate for unused providers

* fix(studio): resolve loopback port from app.state, clear stale local provider fields, sync model id on toggle

Address review feedback on the local-model-provider flow:

- Backend (jobs.py): _resolve_local_v1_endpoint now reads the actual bound
  port from app.state.server_port (set in run.py after binding) instead of
  parsing it out of request.base_url, which is wrong behind any reverse
  proxy or non-default port. The two duplicated urlparse blocks are gone.
- Backend (jobs.py): defensively pop api_key_env, extra_headers, extra_body
  from local providers so a previously external provider that flipped to
  local cannot leak invalid JSON or rogue auth headers into the local /v1
  call. Also dedupe the post-loop assignment and tighten the local-name
  intersection so empty names cannot match.
- Backend (jobs.py): hoist datetime and urllib.parse imports to the top
  import block for consistency with the rest of the file.
- Backend (run.py): expose the bound port on app.state.server_port after
  the uvicorn server is constructed.
- Frontend (model-provider-dialog.tsx): clear extra_headers and extra_body
  when toggling to local mode. Hidden inputs would otherwise keep stale
  JSON blocking validate/run.
- Frontend (model-config-dialog.tsx): factor the local-aware provider
  selection logic into applyProviderChange and call it from both
  onValueChange and onBlur, so manually typing a provider name and tabbing
  away keeps the model field consistent.
- Frontend (recipe-studio.ts store): handle both directions of the
  is_local toggle in the cascade. external -> local now backfills
  model: "local" on already-linked model_configs so they pass validation
  immediately, mirroring the existing local -> external clear path.
- Frontend (validate.ts + build-payload.ts): thread localProviderNames
  into validateModelConfigProviders and skip the "model is required"
  check for local-linked configs. Local providers do not need a real
  model id since the inference endpoint uses the loaded Chat model.

* fix(studio): narrow store cascade types, sync model placeholder on graph relink and node removal, harden ephemeral port path

Loop 2 review fixes:

- recipe-studio.ts: type-narrow next.is_local by also checking
  next.kind === "model_provider". TS otherwise raised TS2339 because
  next was typed as the union NodeConfig after the spread. The behavior
  is unchanged but the code now compiles cleanly.
- model-config-dialog.tsx: convert the lastProviderRef / providerInputRef
  ref-during-render pattern (pre-existing react-hooks/refs lint error)
  to a useEffect that syncs providerInputRef from config.provider. The
  combobox blur path still uses applyProviderChange and remains stable.
- recipe-graph-connection.ts: when a graph drag links a model_provider
  to a model_config, mirror the dialog applyProviderChange behavior:
  fill model: "local" if the new provider is local and the model field
  is blank, clear model when relinking from a local placeholder to an
  external provider, otherwise leave the model alone.
- reference-sync.ts: when a referenced provider node is removed, clear
  the synthetic model: "local" placeholder along with the provider
  field, so a future relink to an external provider does not pass
  validation with a stale value that fails at runtime.
- run.py: only publish app.state.server_port when the bound port is a
  real positive integer; for ephemeral binds (port==0) leave it unset
  and let request handlers fall back to request.base_url.
- jobs.py: _resolve_local_v1_endpoint also falls back when
  app.state.server_port is non-positive, and uses `is None` instead of
  the truthy fallback so a literal 0 is handled correctly.

* fix(studio): strict is_local check, narrow loaded-model gate to LLM-reachable configs, add scope-server port fallback

Loop 3 review fixes:

- jobs.py, validate.py: require `is_local is True` instead of truthy
  check. Malformed payloads such as is_local: "false" or is_local: 1
  would otherwise be treated as local and silently rewritten to the
  loopback endpoint.
- jobs.py: _resolve_local_v1_endpoint now tries request.scope["server"]
  (the actual uvicorn-assigned (host, port) tuple) as a second
  resolution step before falling back to parsing request.base_url.
  This covers direct-uvicorn startup paths and ephemeral binds that
  never publish app.state.server_port.
- jobs.py: new _used_llm_model_aliases helper collects the set of
  model_aliases that an LLM column actually references, and the
  "Chat model loaded" gate is now only triggered when a local
  provider is reachable from that set. Orphan model_config nodes on
  the canvas no longer block unrelated recipe runs.

* fix(studio): force skip_health_check on local-linked configs, skip JSON parsing for local providers, local-aware inline editor

Loop 4 review fixes:

- jobs.py: after rewriting local providers, also force
  skip_health_check: true on any model_config linked to a local
  provider. The /v1/models endpoint only advertises the real loaded
  model id, so data_designer's default model-availability health check
  would otherwise fail against the placeholder "local" id before the
  first chat completion call. The inference route already ignores the
  model id in chat completions, so skipping the check is safe.
- builders-model.ts: buildModelProvider now short-circuits for local
  providers and emits only { name, endpoint: "", provider_type, is_local }
  without running parseJsonObject on the hidden extra_headers/extra_body
  inputs. Imported or hydrated recipes with stale invalid JSON in those
  fields no longer block client-side validate/run.
- inline-model.tsx: the model_config branch now accepts an optional
  localProviderNames prop and mirrors the dialog applyProviderChange
  behavior. Changing provider to/from a local one auto-fills or clears
  the "local" placeholder consistently with the other edit paths.
- recipe-graph-node.tsx: derive localProviderNames from the store via
  useMemo (stable identity) and pass it through renderNodeBody to
  <InlineModel>. Hooks order is preserved by declaring them above the
  early return for markdown_note nodes.
- run.py: minor comment tweak - loop 3 already added the scope-server
  fallback path, note that in the comment.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: danielhanchen <info@unsloth.ai>
2026-04-08 03:48:22 -07:00
Daniel Han
c3d2d58046
Update dependabot.yml (#4915) 2026-04-08 03:39:50 -07:00
dependabot[bot]
0087515d5c
build(deps): bump oxc-parser (#4776)
Bumps the npm-oxc-validator group in /studio/backend/core/data_recipe/oxc-validator with 1 update: [oxc-parser](https://github.com/oxc-project/oxc/tree/HEAD/napi/parser).


Updates `oxc-parser` from 0.121.0 to 0.123.0
- [Release notes](https://github.com/oxc-project/oxc/releases)
- [Changelog](https://github.com/oxc-project/oxc/blob/main/napi/parser/CHANGELOG.md)
- [Commits](https://github.com/oxc-project/oxc/commits/crates_v0.123.0/napi/parser)

---
updated-dependencies:
- dependency-name: oxc-parser
  dependency-version: 0.123.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: npm-oxc-validator
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-08 03:35:40 -07:00
dependabot[bot]
67e9db4921
build(deps): bump oxc-parser (#4776)
Bumps the npm-oxc-validator group in /studio/backend/core/data_recipe/oxc-validator with 1 update: [oxc-parser](https://github.com/oxc-project/oxc/tree/HEAD/napi/parser).


Updates `oxc-parser` from 0.121.0 to 0.123.0
- [Release notes](https://github.com/oxc-project/oxc/releases)
- [Changelog](https://github.com/oxc-project/oxc/blob/main/napi/parser/CHANGELOG.md)
- [Commits](https://github.com/oxc-project/oxc/commits/crates_v0.123.0/napi/parser)

---
updated-dependencies:
- dependency-name: oxc-parser
  dependency-version: 0.123.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: npm-oxc-validator
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-08 03:35:33 -07:00
pre-commit-ci[bot]
c2184af079
[pre-commit.ci] pre-commit autoupdate (#4879)
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.8 → v0.15.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.8...v0.15.9)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-07 22:50:48 -07:00
Roland Tannous
f801e59c29
split venv_t5 into tiered 5.3.0/5.5.0 and fix trust_remote_code (#4878)
* split venv_t5 into venv_t5_530 and venv_t5_550 for tiered transformers 5.x support

* fix bfloat16 crash on T4 for FORCE_FLOAT32 models and disable trust_remote_code auto-enable for native t5 models

* revert FORCE_FLOAT32 dtype change

* restrict trust_remote_code auto-enable to Nemotron models only

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use config.json model_type for tier detection, add unsloth/nvidia namespace guard

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit fb43d468e2.

* Revert "use config.json model_type for tier detection, add unsloth/nvidia namespace guard"

This reverts commit fc49ae2453.

* add unsloth/nvidia namespace guard to Nemotron trust_remote_code auto-enable

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reorder tier checks: all substring matches before config.json fetches

* extract shared activate_transformers_for_subprocess into transformers_version.py

* narrow Nemotron trust_remote_code to nemotron_h/nemotron-3-nano, add to export worker

* clean venv_t5 dirs before re-install in setup.sh, clarify version alias comment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* run venv_t5 migration outside deps fast-path gate in both setup scripts

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-07 20:05:01 +04:00
Daniel Han
1d8160376e
Bump minimum unsloth version to 2026.4.4 in install scripts (#4876) 2026-04-06 09:46:35 -07:00
Daniel Han
b295daf932 Update _utils.py 2026-04-06 09:39:06 -07:00
Lee Jackson
8c89b84bb6
Studio: Fix empty chat threads on navigation and stabilize new chat flow (#4872)
* fix(chat): prevent implicit empty thread creation and stabilize new-chat flow

* fix(chat): harden compare thread sync and simplify sidebar thread query

* fix(chat): harden new-thread state sync and isolate compare active thread updates

* fix(chat): stabilize new-thread state sync and prevent compare/session bleed

* Fix thread restoration, handleNewThread guard, sidebar filter, and delete flow

- Remove __LOCALID_ filter from getInitialSingleChatView: in this
  Dexie-backed adapter, AUI's __LOCALID_ prefixed IDs ARE the real
  persistent thread IDs stored by initialize(). Filtering them out
  breaks thread restoration on navigation.

- Simplify handleNewThread to synchronous: the async Dexie message
  check is redundant (persistence is already deferred to first append)
  and strands users on legacy empty threads. Use a simple guard that
  checks the store's activeThreadId to detect unsent drafts.

- Add message-count filter to sidebar: filter threads to only show
  those with at least one message, hiding legacy empty threads.

- Add store-based sidebar highlighting fallback: use activeThreadId
  from the store when view.threadId is not set (nonce-backed chats).

- Fix handleDelete to call onNewThread() instead of onSelect(), and
  clear activeThreadId, so the runtime properly resets after deleting
  the active thread.

* Fix handleDelete nonce path and restore __LOCALID_ filter

handleDelete was calling onNewThread() after clearing activeThreadId,
but the handleNewThread guard sees !view.threadId && !activeThreadId
and returns early, leaving the UI stuck on the deleted thread.
Fix by directly calling onSelect with a new nonce instead.

Restore __LOCALID_ filter in getInitialSingleChatView to prevent
restoring unpersisted AUI local thread IDs on navigation. Without
this filter, navigating away from /chat before sending a message
would restore a non-existent thread that Dexie cannot fetch.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-06 09:32:54 -07:00
Daniel Han
4c83e3540e Update 2026-04-06 09:20:17 -07:00
Daniel Han
723bfb2363
Add unit tests for HfFileSystem glob skip guard (#4854)
Tests verifying that HfFileSystem().glob() is correctly skipped when
is_model or is_peft is False, matching the guard added in PR #4852.
2026-04-06 08:54:36 -07:00
JYYYYYT
aa4c6010e1
fix(studio): custom folder scan fails to find GGUF variants when pointing directly at a model directory (#4860)
Fix custom folder scanning when pointing directly at a model directory.

When a user adds a custom scan folder that points directly at a model
directory (e.g. /path/to/gemma-4-e2b-it-gguf/ containing config.json
and gemma-4-E2B-it-BF16.gguf), the model list previously showed
individual .gguf files as separate entries instead of recognizing the
directory as a single model. Clicking any entry showed "No GGUF
variants found" because list_local_gguf_variants received a file path
and immediately returned empty.

Changes:
- Add _is_model_directory() helper that detects directories with both
  config metadata and actual model weight files (excludes mmproj GGUFs
  and non-weight .bin files like tokenizer.bin)
- _scan_models_dir: detect self-model and return single directory entry
- _scan_lmstudio_dir: surface model directories directly instead of
  descending into them as publisher folders; handle both root and child
  model directories
- Add _resolve_gguf_dir() helper for GGUF path resolution that only
  falls back to parent directory when parent has model metadata
- list_local_gguf_variants / _find_local_gguf_by_variant: use resolver
  so .gguf file paths inside model directories work correctly
2026-04-06 08:31:07 -07:00
Roland Tannous
0835f0a61b
fix: skip redundant HfFileSystem().glob() calls in loader.py (#4852)
* fix: skip redundant HfFileSystem().glob() calls in loader.py

Guard the SUPPORTS_LLAMA32 glob blocks with `is_model and is_peft` so
the HfFileSystem HTTP call is only made when both configs could actually
exist. This prevents indefinite hangs on slow/unreliable networks since
the glob result is redundant when either AutoConfig or PeftConfig
already failed to load.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove test file from main PR - moved to separate PR

Tests for the glob skip guard belong in their own PR to keep
the loader change minimal and reviewable.

* Harden HfFileSystem glob: fix Windows path splitting, add try/except

- Use str.rsplit("/", 1) instead of os.path.split to extract filenames
  from HfFileSystem paths. HfFileSystem always returns POSIX-style paths,
  but os.path.split uses the OS separator, so on Windows the entire path
  was returned as the "filename" and the config name comparison always
  failed.
- Wrap the HfFileSystem().glob() call in try/except to gracefully handle
  network failures (offline mode, timeouts, unreachable Hub). On failure
  both_exist stays False, which is the safe default.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant HfFileSystem().glob() call for remote repos

When is_model and is_peft are both True, AutoConfig and PeftConfig
have already loaded successfully, proving both config.json and
adapter_config.json exist. The HfFileSystem network call to re-verify
this was redundant and could cause hangs on slow networks.

Replace the glob + try/except block with a direct both_exist = True
assignment.

* Remove unused HfFileSystem import

HfFileSystem was only used for the glob() calls that were replaced
with direct both_exist = True assignments in the previous commit.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-06 07:46:39 -07:00
Daniel Han
07b6fcc344
Remove Gemma-4 from FORCE_FLOAT32 (#4875)
Gemma-4 does not need FORCE_FLOAT32. Testing shows that both float16 and
bfloat16 work correctly without the forced float32 override:

- Inference: identical outputs for float16 and bfloat16 (greedy decoding)
- Training (100 steps, 4-bit LoRA, SFT on FineTome-100k):
  - float16 final loss: 3.048
  - bfloat16 final loss: 3.065
  - Losses converge to within 0.02 by step 60
  - Grad norms healthy and comparable for both dtypes

The FORCE_FLOAT32 path was actually causing training divergence. With
it enabled, the compiled float32 run diverged at step ~28 with grad norms
collapsing to near zero and loss plateauing at ~12.4. Without it, both
dtypes train normally.

This enables float16 on Tesla T4 and other GPUs without bfloat16 support.
2026-04-06 07:33:28 -07:00
Daniel Han
ab65b47c73
Add tests for is_vision_model() caching behaviour (#4855)
* Add tests for is_vision_model() caching behaviour

* Fix review feedback: remove dead helper, fix exception test

- Remove unused _make_config() helper function (dead code)
- Fix test_exception_result_cached to actually exercise the exception path
  by mocking load_model_config to raise OSError instead of using
  side_effect=[False] which only tested normal False returns

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use strict mock specs so tests exercise intended detection paths

Use MagicMock(spec=[]) for all config mocks so hasattr() only returns
True for explicitly set attributes. Without this, MagicMock defaults
make all hasattr checks truthy, allowing tests to pass via unintended
detection paths (e.g. img_processor instead of vision_config).

---------

Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-06 06:41:40 -07:00
Roland Tannous
278f462996
[Studio][Optimization]Add vision detection cache to is_vision_model() (#4853)
* Add vision detection cache to is_vision_model() to avoid redundant subprocess spawns

is_vision_model() is called 4-5 times per training run for the same model
with zero caching. For transformers 5.x models, each call spawns a full
subprocess (~6s each). This adds a module-level _vision_detection_cache dict
following the same pattern as the existing _audio_detection_cache used by
detect_audio_type(). The function is refactored into a thin cache wrapper
around _is_vision_model_uncached(), saving ~12s per training run.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Include hf_token in vision cache key for gated model correctness

Cache key is now (model_name, hf_token) instead of just model_name.
This prevents stale False results when an unauthenticated probe for a
gated model is followed by an authenticated call.

* Remove test file from main PR - will be submitted separately

* Fix vision cache: normalize model names and skip caching transient failures

- Normalize model names in cache key using resolve_cached_repo_id_case()
  to avoid duplicate entries for different casings of the same HF repo
  (aligns with case normalization from #4822)
- Return None instead of False on transient failures (network errors,
  subprocess timeouts, HF API issues) so the cache layer can distinguish
  "definitely not a vision model" from "failed to check"
- Only cache definitive True/False results; transient failures are retried
  on the next call instead of being permanently locked in as False

* Refine failure handling: cache deterministic failures, guard normalization

- Subprocess non-zero exit, JSON errors, and general exceptions return
  False (deterministic, cached) instead of None (retryable). Only
  subprocess.TimeoutExpired returns None since timeouts are transient.
- Wrap cache key normalization in try/except so resolve_cached_repo_id_case
  or normalize_path failures fall back to raw model_name instead of
  crashing callers.

* Harden vision detection cache: fix transient failure handling, thread safety, token security

- All subprocess failure paths now return None (transient) instead of False,
  preventing permanent misclassification of VLMs after temporary HF/auth/network errors
- Use SHA256 fingerprint for hf_token in cache key instead of raw bearer token
- Add threading.Lock with double-checked locking to prevent thundering herd
  of concurrent subprocess spawns for the same uncached model
- Distinguish permanent failures (RepositoryNotFoundError, GatedRepoError,
  ValueError) from transient ones in _is_vision_model_uncached
- Pass resolved/normalized model name to detection (not just cache key)
- Log normalization fallback at debug level instead of silent swallow
- Thread hf_token through callers in routes/models.py and trainer.py
  that previously omitted it

* Refine lock strategy and token fingerprint

- Move detection computation outside the lock to avoid serializing
  long-running subprocess spawns (60s timeout) and HF API calls across
  all concurrent model checks. Lock is now only held for cache writes.
- Use full SHA256 digest for token fingerprint instead of truncated
  16-char prefix to eliminate collision risk.

* Fix huggingface_hub import fallback and use atomic cache read

- Add fallback import path for RepositoryNotFoundError/GatedRepoError
  from huggingface_hub.utils (older hub versions) when .errors is
  not available
- Use sentinel-based dict.get() for single atomic cache read instead
  of two-step in/[] pattern (future-proof for no-GIL runtimes)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-06 06:41:20 -07:00
Leo Borcherding
68965988cf
Fix/studio colab button message: Add fallback message for Colab Studio button when proxy URL fails (#4866)
* Add fallback message for Colab Studio button when localhost link doesn't work

* Make fallback message darker grey for better readability

* Make fallback message bold for better visibility

---------

Co-authored-by: LeoBorcherding <LeoBorcherding@users.noreply.github.com>
2026-04-05 21:57:45 -07:00
Daniel Han
6100867447
Bump minimum unsloth version to 2026.4.2 in install scripts (#4842) 2026-04-03 15:14:28 -07:00
Daniel Han
170c4b9b99 Update _utils.py 2026-04-03 15:02:14 -07:00
Daniel Han
4020a70a93
Add tests for cache case resolution (from PR #4822) (#4823)
Tests for resolve_cached_repo_id_case and get_model_config case
resolution, separated from the runtime changes in PR #4822.
2026-04-03 13:58:26 -07:00
Daniel Han
4f65cc94bc
Add Gemma 4 model sampling defaults (#4838)
Add per-model YAML configs and MODEL_NAME_MAPPING entries for all 8
Gemma 4 models (4 instruct + 4 base):
- gemma-4-31B-it / gemma-4-31B
- gemma-4-26B-A4B-it / gemma-4-26B-A4B
- gemma-4-E2B-it / gemma-4-E2B
- gemma-4-E4B-it / gemma-4-E4B

GGUF variants (only for -it models) resolve via the gemma-4 family
entry in inference_defaults.json.

Sampling defaults: temperature=1.0, top_p=0.95, top_k=64, min_p=0.0,
no repetition or presence penalty. Matches gemma-3n and gemma-3.
2026-04-03 13:57:15 -07:00
Daniel Han
a32b871f0e
studio: add speculative decoding support (ngram-mod, on by default) (#4836)
* studio: add speculative decoding support (ngram-mod, on by default)

Enable n-gram speculative decoding for GGUF models in Unsloth Studio.
Uses llama.cpp's ngram-mod mode which gives 10-40% faster generation
with zero VRAM cost via a 4MB fixed hash table that auto-resets on
low acceptance rates.

Backend:
- Add speculative_type field to LoadRequest, LoadResponse, and
  InferenceStatusResponse pydantic models
- Add speculative_type parameter to LlamaCppBackend.load_model()
  with allowlist validation (ngram-simple, ngram-mod)
- Pass --spec-type, --spec-ngram-size-n 16, --draft-max 24 flags
  to llama-server when ngram-mod is active
- Default to ngram-mod for non-vision GGUF models server-side
- Silently skip speculative decoding for vision models (unsupported
  in llama.cpp server-context.cpp)

Frontend:
- Add speculative_type to TS API types
- Add speculativeType/loadedSpeculativeType to chat runtime store
  with default value of "ngram-mod"
- Add On/Off toggle in Model settings section (GGUF only, hidden
  for vision models), included in dirty check for Apply/Reset
- Wire speculative_type through model load request and response
- Restore speculative type state on page refresh/reconnect

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: remove server-side speculative decoding override

The backend was overriding speculative_type=None to "ngram-mod" for
non-vision GGUF models, which prevented users from disabling spec
decoding via the UI toggle. The frontend store already defaults to
"ngram-mod", so the backend fallback was redundant and blocked the
explicit "Off" setting.

* fix: use recommended ngram-mod params from llama.cpp docs

Update speculative decoding params to match the recommended values
from llama.cpp docs (docs/speculative.md):
  --spec-ngram-size-n 24 (was 16, docs say small n not recommended)
  --draft-min 48 (was 0)
  --draft-max 64 (was 24, docs note MoEs need long drafts)

Also fix comment: ngram-mod uses ~16 MB (4M entries * 4 bytes),
not 4 MB.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add benchmark table and references to speculative decoding comment

Include speedup numbers from llama.cpp PRs #18471 and #19164 as an
inline comment so future readers understand the expected gains.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-03 13:56:59 -07:00
Daniel Han
2c73ab7871
fix(studio): harden sandbox security for terminal and python tools (#4827)
* fix(studio): harden sandbox security for terminal and python tools

The existing command blocklist used naive str.split() which is trivially
bypassable via quoting, full paths, nested shells, variable expansion,
and cross-tool pivoting through Python os.system/subprocess. Fixes #4818.

Changes:
- Replace str.split() blocklist with shlex.split() + os.path.basename()
  tokenization and regex scanning at shell command boundaries
- Add sanitized subprocess environment (_build_safe_env) that strips
  credentials (HF_TOKEN, WANDB_API_KEY, GH_TOKEN, AWS_*, etc.) and
  restricts PATH to /usr/local/bin:/usr/bin:/bin
- Add PR_SET_NO_NEW_PRIVS via prctl on Linux so sudo/su/pkexec fail
  at the kernel level regardless of how they are invoked
- Add RLIMIT_NPROC (256) and RLIMIT_FSIZE (100MB) to prevent fork
  bombs and disk filling attacks
- Extend AST safety checker to detect os.system(), os.popen(),
  subprocess.run/Popen/call/check_output, os.exec*, os.spawn* calls
  containing blocked commands or dynamic (non-literal) arguments
- Add cross-platform support: cmd.exe on Windows, bash on Unix;
  CREATE_NO_WINDOW flag on Windows, preexec_fn on Unix
- Expand blocklist from 7 to 14 commands: add su, chown, passwd,
  mount, umount, fdisk, kill, killall, pkill
- Apply all layers to both _bash_exec and _python_exec

Zero measurable performance overhead -- shlex parsing and a single
prctl syscall per subprocess fork.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix review findings: exception_catching dead code, false positives, process substitution

- Include exception_catching reasons in _check_code_safety so bare
  except-in-loop timeout evasion is actually blocked (was computed in
  _check_signal_escape_patterns but never read by the caller)
- Remove base.split() inner loop that caused false positives on quoted
  text arguments containing blocked words (e.g. echo "kill this process")
- Add targeted nested shell detection for bash/sh/zsh -c arguments
  instead, which catches bash -c 'sudo whoami' without false positives
- Add <() process substitution to the regex character class so
  diff <(rm -rf /path) is also caught
- Fix error message to say "unsafe patterns" instead of specifically
  mentioning signal manipulation when other categories trigger

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review feedback: regex paths, keyword args, list element scanning

- Regex now matches blocked commands after optional path prefix at shell
  boundaries (catches ls; /usr/bin/sudo and similar)
- Nested shell detection uses os.path.basename so bash -c "/bin/rm" is
  caught
- AST checker now inspects keyword arguments (not just positional) so
  subprocess.run(args="sudo ...", shell=True) is detected
- List elements in subprocess calls are now checked via
  _find_blocked_commands for consistency (catches subprocess.run(["bash",
  "-c", "rm -rf /"]))
- Dynamic argument check uses _is_safe_literal that validates list
  contents are all string literals

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix nested shell scan to only check the script body, not positional args

bash -c 'script' arg0 arg1 -- only tokens[i+1] is the script body;
subsequent tokens are $0, $1 positional parameters passed to the script
and are not executed as shell commands. Scanning all remaining tokens
caused false positives.

* Add subshell parentheses to regex command boundary detection

(sudo whoami) was not caught because ( was not in the regex character
class for shell command boundaries. Add ( to the set alongside ;, &,
|, backtick, newline.

* Address high-priority review findings from 7 parallel reviewers

- Track from-imports of dangerous functions (from os import system,
  from subprocess import run as r, etc.) via shell_exec_aliases dict
  so bare-name calls are detected by the AST checker
- Include the active Python interpreter and virtualenv directories
  in the sanitized PATH so pip, uv, and Studio packages remain
  accessible in the sandbox
- Add Windows-specific blocked commands (rmdir, takeown, icacls,
  runas, powershell, pwsh) only on win32 platform
- Add os.posix_spawn and os.posix_spawnp to _SHELL_EXEC_FUNCS
- Handle tuple literals same as list literals in AST argument
  inspection (both _extract_strings_from_list and _is_safe_literal)

* Fix false positive on check=True kwargs and recursive nested shell scanning

- Only inspect command-carrying keyword arguments (args, command,
  executable, path, file) in the AST checker, not control flags like
  check=True, text=True, capture_output=True which are booleans and
  were incorrectly flagged as non-literal dynamic arguments
- Replace split() in nested shell detection with recursive call to
  _find_blocked_commands so that quoted commands (bash -c '"sudo"
  whoami') and semicolons (bash -c "sudo;ls") within nested shells
  are properly detected through the full shlex + regex pipeline

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move preexec_fn imports to module level and use find_library for libc

Addresses two Gemini review findings:

1. preexec_fn thread safety: _sandbox_preexec previously imported ctypes
   and resource inside the function body, which runs between fork() and
   exec() in the child process. In a multi-threaded server, this could
   deadlock if the import machinery locks were held by another thread at
   fork time. Now all imports and the libc handle are resolved once at
   module load time, so _sandbox_preexec only calls C-level functions
   (prctl, setrlimit) with no Python import activity.

2. Hardcoded libc.so.6 path: replaced with ctypes.util.find_library("c")
   which works on glibc (libc.so.6), musl (libc.musl-*.so.1), and other
   Linux distributions where libc has a different soname.

* Apply Gemini style suggestions: combined regex, dict.fromkeys, constant hoisting

- Combine per-word regex loop into a single re.findall with alternation
  pattern, avoiding repeated regex compilation and searching
- Replace manual dedup loop with dict.fromkeys for PATH entries
- Hoist _CMD_KWARGS frozenset out of visit_Call to avoid recreating it
  on every AST node visit

* Add cmd /c nested shell detection for Windows parity

The nested shell scan only checked for Unix shells (bash -c, sh -c, etc).
Add cmd /c and cmd.exe /c detection so that Windows nested shell
invocations are also recursively scanned for blocked commands. The token
scan already catches blocked commands at any position, so this is
defense-in-depth for consistency across platforms.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Handle combined shell flags (-lc, -xc) and interleaved flags (--login -c)

The nested shell scan only matched token == "-c" with the immediately
preceding token being a shell name. This missed:
- Combined flags: bash -lc 'rm ...' (-lc ends with c, is a valid
  combined flag meaning -l -c)
- Interleaved flags: bash --login -c 'sudo ...' (--login sits between
  bash and -c)

Now matches any short flag ending in 'c' (e.g. -lc, -xc, -ic) and
walks backwards past intermediate flags to find the shell binary.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix /bin/bash bypass, remove RLIMIT_NPROC, reduce AST false positives

Addresses three high-consensus findings from 20-reviewer pass:

1. /bin/bash -c 'sudo whoami' bypassed nested shell scan because the
   backwards flag-skip logic treated paths starting with / as flags.
   Now only skips tokens starting with - as Unix flags; on Windows
   only skips short /X flags (not /bin/bash style paths). [9/20]

2. RLIMIT_NPROC=256 caused subprocess.run to fail with EAGAIN because
   Linux enforces NPROC per real UID, not per process tree. Removed
   RLIMIT_NPROC entirely; RLIMIT_FSIZE and PR_SET_NO_NEW_PRIVS remain
   as the primary resource and privilege controls. [5/20]

3. AST checker rejected safe dynamic subprocess usage like
   cmd=["git","status"]; subprocess.run(cmd) as shell_escape_dynamic.
   Now only flags dynamic args for shell-string functions (os.system,
   os.popen, subprocess.getoutput, etc.) or when shell=True is
   explicitly set. List-based subprocess calls with shell=False (the
   default) do not pass through a shell and are not flagged. [12/20]

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Handle Windows drive letter paths and .exe extensions in command detection

Gemini review found that Windows absolute paths (C:\Windows\System32\
shutdown.exe) and executable extensions (.exe, .com, .bat, .cmd) were
not handled:

- Token scan now strips .exe/.com/.bat/.cmd extensions before checking
  the blocklist, so sudo.exe matches sudo, shutdown.bat matches shutdown
- Regex pattern now includes optional Windows drive letter prefix
  ([a-zA-Z]:[/\\]) and optional executable extension suffix, so commands
  after shell metacharacters with full Windows paths are also caught

* Handle **kwargs dict expansion, non-literal shell=, and except Exception false positive

Addresses three findings from second 20-reviewer pass:

1. **kwargs dict expansion (9/20): subprocess.run(**{"args": "rm ...",
   "shell": True}) bypassed the AST checker because **kwargs were
   treated as opaque. Now expands literal dict **kwargs to inspect
   their keys, and flags opaque **kwargs (variable dicts) as unsafe.

2. Non-literal shell= values (7/20): shell=variable was treated as
   shell=False (safe). Now any shell= value that is not literally
   False is treated as potentially True (conservative default).

3. except Exception false positive (1/20): except Exception in a loop
   was flagged as timeout evasion, but Exception does not catch
   SystemExit or KeyboardInterrupt which are used for timeout
   enforcement. Narrowed to only flag except BaseException and
   except TimeoutError in loops.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-03 13:33:42 -07:00
Neodon
c027ec192e
fix(studio): ensure first chat tool call starts in session sandbox (#4810)
Fixes #4809

On a new Studio chat, the first tool call could start before the frontend
initializes the thread ID. That meant the first request could go out without
a session_id, so the backend started the tool in the shared sandbox root
instead of the chat's session sandbox.

Frontend:
- Eagerly initialize the thread when switching to a new chat
- Resolve the thread ID once at request time and keep it stable through
  async model-load waits
- Disable ActiveThreadSync during new-chat initialization to prevent
  stale thread IDs from being written back
- Add error handling for thread initialization failures
- Clear activeThreadId on all compare-mode entry paths to prevent
  cross-session leakage
- Fix exitCompare to restore context usage from the saved view
- Coerce falsy thread IDs to undefined for consistent backend/frontend
  fallback behavior
- Use _default as the image sessionId fallback to match the backend

Backend:
- Use ~/studio_sandbox/_default when a request arrives without a session_id
2026-04-03 11:44:22 -07:00
Lee Jackson
a29b4e23fd
studio: reuse HF cached repo casing to prevent duplicate downloads (#4822)
* fix(studio): reuse HF cached repo casing to prevent duplicate downloads

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move cache case resolution tests to separate PR

Tests for resolve_cached_repo_id_case and get_model_config case resolution
belong in their own PR to keep this change focused on the runtime fix.

* fix(studio): debug-log HF_HUB_CACHE fallback in path_utils

* Fix stale memoization in resolve_cached_repo_id_case

- Check exact-case path before memo to ensure a newly-appeared exact
  match always wins over a previously memoized variant
- Validate memoized entries still exist on disk before returning them
  to prevent stale results when cache dirs are deleted/recreated

* Minor cleanups for cache case resolution

- Use .is_dir() instead of .exists() for exact-case cache check
  (cache entries are always directories)
- Remove redundant fallback in _detect_audio_from_tokenizer since
  get_cache_path already handles case resolution and returns None
  when the model is not cached

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-03 05:48:24 -07:00
Wasim Yousef Said
50dede11cc
Allow non-LLM recipes to run and move Data tab first in executions (#4805)
* feat: allow non-LLM recipes to run without provider block

* feat: reorder execution tabs and add generation-aware data tab empty state

* fix: add accessibility attrs to data tab spinner and use literal ellipsis

* fix(studio): use shared spinner, stub provider, and hide unused LLM metrics

Backend: inject stub model provider for sampler-only recipes so
DataDesigner init does not reject empty provider lists.

Frontend: use shared Spinner component, hide LLM columns metric
and model usage card when recipe has no LLM columns.

* Fix tab reset and terminal auto-scroll regressions for PR #4805

Reset detailTab to "data" when switching between executions so
the Data tab default is applied consistently, not only on first
mount. Also add detailTab to the terminal scroll effect deps so
auto-scroll-to-bottom fires when the user opens the Overview tab
after landing on Data.

* Guard terminal scroll reset to only fire on Overview tab

The previous scroll effect ran on every tab switch, which could
reset the user's manual scroll position if they scrolled up in
the terminal and briefly switched tabs. Now the scroll-to-bottom
and sticky-bottom reset only fires when navigating to the
Overview tab.

* Use None for stub provider api_key instead of literal string

The stub ModelProvider that satisfies the DataDesigner registry
for non-LLM recipes should not carry a fake credential string.
Using None avoids sending an Authorization header if the provider
is ever inadvertently invoked.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-03 05:37:26 -07:00
Wasim Yousef Said
5b7c0615f3
feat(studio): differentiate web search and URL fetch in chat tool UI (#4802)
Differentiate web_search query searches from URL fetches in the Studio chat UI.

Backend (llama_cpp.py):
- Emit "Reading: hostname" for URL fetches and "Searching: query" for query searches in SSE status events
- Only show hostname for valid http/https URLs; schemeless/non-http URLs get "Reading page..." generic fallback
- Strip www. prefix for consistency with the frontend

Frontend (tool-ui-web-search.tsx):
- Tool card shows "Read hostname" / "Reading hostname..." for URL fetches
- Shows "Searched query" / "Searching for query..." for query searches
- Uses new URL() with protocol check; falls back to "Read page" / "Reading page..." for non-http URLs
2026-04-03 05:03:27 -07:00
Daniel Han
8981e6c804
Update test_pr4562_bugfixes.py for simplified install policy (#4817)
- Add TestFetchJsonRetries for JSON retry logic and max_pages
- Update TestSourceCodePatterns for simplified --simple-policy flow
- Add tests for installed prebuilt release reporting
- Add test for CUDA toolkit version-sorted nvcc discovery
- Remove assertions for removed --resolve-install-tag / --resolve-source-build paths
2026-04-03 04:06:14 -07:00
DoubleMathew
ac562bac66
Fix/llama.cppbuilding (#4804)
* Simplify llama.cpp install logic

* print release tag

* Retry failed json decode

* don't pull all ggml releases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove test file changes from main PR

Test changes for test_pr4562_bugfixes.py will be submitted in a separate PR to keep this PR focused on the install path simplification.

* Fix setup.sh executable bit and direct tag lookup for pinned releases

- Restore setup.sh file mode to 100755 (was accidentally changed to 100644)
- Add direct GitHub API tag lookup in iter_release_payloads_by_time for
  non-latest requested tags (e.g. b7879) instead of relying on paginated
  release scans that may miss older releases beyond the 5-page limit
- Update stale DEFAULT_PUBLISHED_REPO comment to match new value

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix force-compile default ref and remove dead code in setup.ps1

- Change FORCE_COMPILE_DEFAULT_REF from "main" to "master" in all three
  files (install_llama_prebuilt.py, setup.sh, setup.ps1) since
  ggml-org/llama.cpp uses "master" as its default branch, not "main".
  Using "main" would cause git clone --branch to fail when
  UNSLOTH_LLAMA_FORCE_COMPILE=1 with UNSLOTH_LLAMA_TAG=latest.
- Remove dead if ($SkipPrebuiltInstall) block inside the else branch of
  setup.ps1 that could never be reached (the outer elseif already
  handles $SkipPrebuiltInstall=true).
- Maintain setup.sh executable bit (100755).

* Improve iter_release_payloads_by_time error handling for direct tag lookup

When a pinned release tag is not found (HTTP 404), fall through to the
paginated release scan instead of silently returning empty results.
Non-404 errors (network failures, rate limits) are propagated to the
caller so users get actionable error messages.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-03 00:34:20 -07:00
Michael Han
c1685b9459
Gemma 4 update.md 2026-04-02 22:54:03 -07:00
Manan Shah
a7e6964117
Fix/gemma4 install script (#4815)
* transformer 5.5.0 has now been released

* fallback for python < 3.10
:
2026-04-02 22:03:35 -07:00
Roland Tannous
6644a771b4
fix: patch PEFT for Gemma4ClippableLinear in loader checkpoint path (fixes export) (#4807)
* fix: patch PEFT for Gemma4ClippableLinear in loader checkpoint path

The same Gemma4ClippableLinear monkey-patch that exists in vision.py
for training is needed in loader.py for loading existing checkpoints
(used by export and inference).

Gemma4ClippableLinear wraps nn.Linear but does not subclass it, so
PEFT's LoRA injection fails with "Target module not supported".
The patch redirects PEFT to target the inner .linear child instead.

Applied only to the vision model PeftModel.from_pretrained path.
Temporary fix until PEFT adds native support (peft#3129).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: wrap ClippableLinear patch in try/finally to always restore

Ensures _create_and_replace is restored even if PeftModel.from_pretrained
raises, preventing leaked global state across subsequent model loads.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-03 04:03:54 +04:00
Roland Tannous
f91ef8f9b0
fix(studio): lazy-import transformers in model_config to fix 5.x version switch (#4806)
* fix(studio): lazy-import AutoConfig in model_config.py to fix transformers 5.x version switch

Move `from transformers import AutoConfig` from module level to inside
load_model_config() where it is actually used.

model_config.py is transitively imported at module load time via:
  core/inference/__init__ → llama_cpp → utils.models → model_config

In inference subprocesses (mp.spawn), this chain runs before
_activate_transformers_version() can prepend .venv_t5/ to sys.path.
The eager import caches transformers 4.57.6 in sys.modules, and the
subsequent sys.path change has no effect — Python always checks
sys.modules before sys.path.

Making the import lazy ensures transformers is not loaded until after
version activation, so the subprocess picks up the correct version.

* fix(studio): also lazy-import extract_model_size_b in llama_cpp.py

Belt-and-suspenders: make the import that originally triggered the
chain lazy as well, so future module-level AutoConfig additions in
utils.models cannot reintroduce the problem.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-03 02:56:01 +04:00
Daniel Han
e553a8ad0b
fix(studio): suppress fatal error when prebuilt manifest is missing (#4799)
When DEFAULT_PUBLISHED_REPO is ggml-org/llama.cpp, the prebuilt
resolver raises PrebuiltFallback because ggml-org releases do not
include a llama-prebuilt-manifest.json asset. This was caught by the
generic Exception handler and printed as "fatal helper error" to
stderr, which triggers NativeCommandError on PowerShell.

Catch PrebuiltFallback separately in the top-level __main__ handler
and exit with EXIT_FALLBACK (code 2) instead of EXIT_ERROR (code 1).
The message is still logged but without the "fatal helper error"
prefix. The shell scripts already handle non-zero exits and fall
back to source builds.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-02 12:18:11 -07:00
Daniel Han
8ffd5826f2 Gemma-4 2026-04-02 11:59:37 -07:00
Daniel Han
934478ae31
fix(studio): revert llama.cpp default tag to latest (#4797)
* fix(studio): revert llama.cpp default tag to latest

The latest ggml-org/llama.cpp release (b8637) now includes Gemma 4
support. Revert the temporary "b8637" pin from #4796 to "latest" so
the prebuilt resolver always picks the newest release automatically
without needing manual tag bumps.

* docs: add comment explaining latest vs master for llama.cpp tag

Document in all three files why "latest" is preferred over "master"
and when "master" should be used as a temporary override.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-02 11:52:37 -07:00
Daniel Han
401621618b
fix(studio): don't set trust_remote_code for Gemma 4 training (#4795)
Gemma 4 is a native transformers 5.5 model and does not need
trust_remote_code=True. The auto-enable logic (added for NemotronH)
was catching all transformers 5.x models, including Gemma 4.

When trust_remote_code=True, unsloth_compile_transformers() returns
early without running the compiler. This disables the fused cross
entropy patch, causing logged training loss to be inflated by the
gradient_accumulation_steps factor.

Exclude models matching "gemma-4" or "gemma4" from the auto-enable
so the compiler runs and applies fused cross entropy correctly.
2026-04-02 11:44:26 -07:00
Daniel Han
8d1712b4ea
fix(studio): pin llama.cpp to b8637 release (Gemma 4 support) (#4796)
ggml-org/llama.cpp b8637 includes Gemma 4 support (ggml-org/llama.cpp#21309).
Revert the temporary "master" default back to a pinned release tag.

This eliminates the HTTP 422 errors from the prebuilt resolver (which
could not find a release matching "master"), avoids unnecessary source
builds, and restores prebuilt binary downloads on all platforms.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-02 11:43:53 -07:00
DoubleMathew
7ae9b7f45f
fix windows llama.cpp compile from source issue (#4793)
* fix windows llama.cpp compile from source issue

* undo local repo usage

* fix llama.cpp install

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix windows

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: route resolve-source-build call through Invoke-LlamaHelper

The --resolve-source-build call at the source-build resolution path
was still calling install_llama_prebuilt.py directly instead of going
through Invoke-LlamaHelper. On PS7+ with ErrorActionPreference=Stop,
stderr from the 422 response (when tag is "master") would trigger a
terminating NativeCommandError and crash setup.

* fix: suppress stderr error records from Invoke-LlamaHelper

ErrorActionPreference=Continue prevents termination but PowerShell
still displays stderr lines as visible ErrorRecord objects. Capture
all output via 2>&1 and split stdout from stderr manually so that
stderr lines never appear on the console. When StderrPath is given
the stderr content is written to that file for diagnostics.

* fix: always rebuild llama.cpp on Windows when tag is master

When the requested llama.cpp tag is "master" (a moving target), skip
the "already built" early exit so the build path runs and syncs to
the latest commit. Without this, existing llama-server binaries from
an older build (e.g. b8635 which lacks Gemma 4 support) are reused
and model loading fails.

Pinned tags (e.g. b8635) still skip the rebuild when the binary
already exists, since the tag is immutable.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-02 11:43:46 -07:00
Daniel Han
7023e2a4ff
fix(studio): prioritize curated defaults over HF download ranking in Recommended (#4792)
The model list merge order was `top_gguf + top_hub + static_models`,
which meant the HF download-ranked models always came first. New models
like Gemma 4 have low download counts and were not in the HF top-40,
so they got buried after 80 other models despite being at the top of
the curated static defaults in defaults.py.

Flip the merge to `static_models + top_gguf + top_hub` so editorial
picks (new model launches, promoted models) always appear first in the
Recommended section, with HF popularity backfilling after.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-04-02 10:46:53 -07:00
Roland Tannous
0446d46689
fixed name (#4791) 2026-04-02 21:04:42 +04:00
Daniel Han
1ce83c40aa
fix(studio): build llama.cpp from master instead of latest release tag (#4790)
The latest ggml-org/llama.cpp release (b8635) does not include Gemma 4
support (ggml-org/llama.cpp#21309 merged after the release was cut).
This causes `llama-server` to fail with "unknown model architecture:
gemma4" when loading Gemma 4 GGUFs.

Temporarily default _DEFAULT_LLAMA_TAG to "master" so all new installs
build from the llama.cpp master branch which includes Gemma 4 support.
Once a new upstream release is cut with Gemma 4, this can be reverted
back to "latest".

Changes:
- setup.sh: add _DEFAULT_LLAMA_TAG="master" maintainer default
- setup.ps1: add $DefaultLlamaTag="master" maintainer default
- install_llama_prebuilt.py: change DEFAULT_LLAMA_TAG fallback to "master"

Users can still override via UNSLOTH_LLAMA_TAG env var.
2026-04-02 09:45:56 -07:00
Daniel Han
2af53bf9a6
Pin transformers and huggingface-hub in main Studio venv (#4788)
Revert the >= loosening from f9c4b08 back to exact pins.
Using transformers>=4.57.6 allows pip to install 5.x into the main
Studio venv, which breaks huggingface_hub imports
(is_offline_mode removed in newer hub versions).

The main venv must stay on transformers==4.57.6 and
huggingface-hub==0.36.2. The 5.x version lives only in .venv_t5/
and is dynamically switched via sys.path at runtime.
2026-04-02 09:21:30 -07:00
Daniel Han
a241c58d84
Use transformers v5.5-release branch and pin to 5.5.0 (#4786)
The v5.5-release branch now exists on huggingface/transformers.
Use transformers==5.5.0 for all install paths and
git+transformers.git@v5.5-release for the MLX installer.

Also bumps huggingface_hub from 1.7.1 to 1.8.0 in setup.sh and
setup.ps1 to stay consistent.
2026-04-02 09:10:02 -07:00
Daniel Han
a353557249
Force llama.cpp to always use mainline ggml-org (#4785)
Hardcode the release repo to ggml-org/llama.cpp and remove the
UNSLOTH_LLAMA_RELEASE_REPO and UNSLOTH_LLAMA_SOURCE env var overrides
so that all users always build/download from mainline llama.cpp.
2026-04-02 09:03:00 -07:00
Daniel Han
f1c3b9caa9
Pin Gemma-4 transformers requirement to 5.5.0 stable (#4784)
Gemma-4 support landed in transformers main
(huggingface/transformers#45192). Update the version pin from
5.5.0.dev0 to 5.5.0 across loader, Studio version switcher,
and the MLX installer. Also fix install_gemma4_mlx.sh which
referenced a non-existent v5.5-release branch -- pin it to
the correct commit (91b1ab1) instead.
2026-04-02 08:59:21 -07:00
Daniel Han
4f9986ecb9
fix(studio): improve tool-calling re-prompt for small models (#4783)
Small GGUF models (<9B) frequently generate full code or lengthy
explanations instead of calling tools, bypassing the existing
plan-without-action re-prompt mechanism. Three issues:

1. _REPROMPT_MAX_CHARS=500 was too low -- models that output full
   HTML/code responses (often 1000+ chars) never triggered the
   re-prompt at all, since it only fires on short responses.

2. _MAX_REPROMPTS=1 gave the model only one chance to comply.
   Small models often need 2-3 nudges before switching from
   text generation to tool calling.

3. The re-prompt text ("Please use the available tools...") was
   too polite for small models to follow reliably.

4. Tool-calling detection missed chat templates using Jinja
   whitespace-trimming syntax ({%- if tools -%}) since only
   ({%- if tools %}) and ({% if tools %}) were checked.

Changes:
- Raise _REPROMPT_MAX_CHARS from 500 to 2000 so longer responses
  (code blocks, multi-paragraph plans) still trigger re-prompts
- Raise _MAX_REPROMPTS from 1 to 3 for more retry budget
- Use direct, imperative re-prompt language that small models
  follow more reliably ("STOP. You MUST call a tool NOW.")
- Strengthen the system prompt tool nudge to explicitly forbid
  outputting code blocks (redirect to the python tool instead)
- Add Jinja whitespace-trimmed variants to the tool_markers
  list so all template styles are detected correctly
2026-04-02 08:59:02 -07:00
Daniel Han
f9c4b08726
UI Changes (#4782)
* UI Changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unrelated test file

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-02 08:05:55 -07:00
Roland Tannous
3b613eb1e8
ui improvement (#4781)
* ui

* ui

* ui
2026-04-02 07:57:47 -07:00
Daniel Han
c8d311a053
feat(studio): display images from Python tool execution in chat UI (#4778)
* feat(studio): display images from Python tool execution in chat UI

When the model calls the Python tool to create a matplotlib plot or
other image file, the image now displays inline in the chat output
instead of being invisible to the user.

Backend:
- Detect new image files (png/jpg/gif/webp/bmp) after Python subprocess
  completes by diffing os.listdir before/after execution
- Append __IMAGES__ sentinel to tool result for frontend consumption
- Strip sentinel before injecting result into LLM context (role: tool)
  so the model never sees file paths
- Add GET /sandbox/{session_id}/{filename} endpoint with JWT auth
  (header or query param), path traversal protection, extension
  allowlist, realpath containment check, and nosniff header

Frontend:
- Parse __IMAGES__ sentinel in tool_end SSE events, create structured
  result with text/images/sessionId
- Render <img> tags in Python tool UI pointing at the sandbox endpoint

Also fixes a bug where SyntaxError in user code was misreported as
"unsafe code detected" instead of showing the actual Python traceback.
The _check_code_safety function now lets SyntaxError pass through to
the subprocess for a proper error message.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): improve SVG detection and strip XML preamble

Handle <?xml ...?> declarations before <svg> tags in code fences,
strip XML declaration from SVGs before data URI rendering, and
update the sloth suggestion prompt to request showing code.

* fix(studio): persist parentId so retries survive reload

The append() handler was destructuring only { message } from
ExportedMessageRepositoryItem and discarding parentId. When loading
a saved thread, load() used ExportedMessageRepository.fromArray()
which chains all messages sequentially, flattening retry branches
into a linear list.

Now append() writes parentId to the MessageRecord, and load()
reconstructs the tree when parentIds are present. Old threads
without parentId fall back to the existing fromArray() behavior.

* fix(studio): address review findings for image display and retry persistence

Image detection:
- Use mtime comparison instead of filename-only diff so overwritten
  files (e.g. plt.savefig("chart.png") called twice) are detected

Sentinel parsing:
- Use rsplit/lastIndexOf instead of split/indexOf so user code that
  prints __IMAGES__: does not collide with the backend sentinel

Mixed legacy/new threads:
- For old messages without a stored parentId, infer sequential parent
  from the previous message instead of null, preventing multiple roots

Sandbox endpoint:
- Change Cache-Control from "public, max-age=3600" to "private,
  no-store" since these are authenticated responses

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-02 05:08:16 -07:00
Lee Jackson
5a5f1a4f34
studio: fix chat font changes leaking outside chat page (#4775)
* fix(frontend): scope sans font overrides to chat thread only

* fix(frontend): use font-sans fallback for heading stack and simplify chat font rules

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-02 05:04:23 -07:00
DoubleMathew
1ce8a8e7cd
Feat/custom llama prebuilt (#4771)
* update logic to incorporate custom prebuilt installs

* bug fixes

* update for review comments

* fix tags

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Separate test changes from main PR

Move test file changes out of this PR to keep the diff focused on
the install_llama_prebuilt.py and setup script changes. Test updates
will be submitted in a follow-up PR.

* Fix branch ref normalization and harden JSON parsing

- Add checkout_friendly_ref() to strip refs/heads/ prefix from branch
  refs before emitting them in SourceBuildPlan. git clone --branch does
  not accept fully qualified refs like refs/heads/main.
- Apply normalization in source_build_plan_for_release() and the
  direct-ref fallback in resolve_source_build_plan().
- Allow validated_checksums_for_bundle() to accept releases that carry
  only an exact-commit source archive without the legacy upstream-tag
  source tarball.
- Add 2>/dev/null || true guards to all inline python -c JSON parsing
  in setup.sh so a malformed payload does not abort the script under
  set -e.

* Fix Windows CUDA asset ordering and tag ref normalization

- Reorder windows_cuda_upstream_asset_names to prefer the main binary
  archive (llama-{tag}-bin-win-cuda-*) over the cudart sidecar archive
  (cudart-llama-bin-win-cuda-*). The cudart ZIP only contains CUDA
  runtime DLLs, not llama-server or llama-quantize binaries.
- Extend checkout_friendly_ref to also strip refs/tags/ prefix for tag
  refs, matching the refs/heads/ handling for branch refs.

* Simplify JSON parsing consistency in setup.sh

Use json.load(sys.stdin) consistently for all inline JSON parsing
in setup.sh, instead of the more complex json.loads(raw) pattern
on the install-tag resolution path. The 2>/dev/null || true guard
already handles empty/malformed input gracefully.

* Fix source build plan fallback for commit ref kind in PR #4771

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <daniel@unsloth.ai>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-02 04:52:26 -07:00
Daniel Han
b20efc370a
Add regression tests for custom llama prebuilt installer (#4772)
Expand test coverage for install_llama_prebuilt.py:
- Add tests for source build plan resolution with custom repos
- Add tests for branch/commit/PR ref matching and normalization
- Add tests for manifest checksum validation
- Add tests for Windows CUDA upstream asset name patterns
- Update capsys checks to capture stderr after log() redirect
2026-04-02 04:45:09 -07:00
Michael Han
e2fd946fe1
Add files via upload 2026-04-02 03:00:10 -07:00
Michael Han
31d6aeb197
Unsloth new logo 2026-04-02 02:58:21 -07:00
Daniel Han
e4d1499230
fix(studio): prevent small models from stalling on tool-calling tasks (#4769)
* fix(studio): prevent small models from stalling on tool-calling tasks

Small GGUF models (< 9B params) in "Think, Search, Code" mode would
often describe what they planned to do ("Let me create this dashboard")
and then stop generating without ever calling a tool.

Three changes:

1. Simplify web_tips for small models: remove the "fetch its full content
   by calling web_search with the url parameter" guidance for models < 9B.
   This multi-step instruction causes small models to plan elaborate
   search-then-fetch-then-code sequences they cannot reliably execute.

2. Add "always call tools directly" imperative to the system prompt nudge
   so models act immediately instead of narrating their intentions.

3. Add plan-without-action re-prompt in the agentic loop: when the model
   emits planning text (matching patterns like "let me", "I'll", etc.)
   without calling any tool, inject a nudge asking it to call the tool
   and continue the loop. Capped at 2 re-prompts per request.

Benchmarked with Qwen3.5-4B-GGUF (N=5 trials per variant):
- Baseline: 40% of requests had any tool call
- Combined fix: 100% of requests had at least one tool call

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-02 02:11:07 -07:00
Daniel Han
dc0729aadf
Add regression test for shell injection fix in GGML conversion (#4773)
AST-based test ensures subprocess.Popen calls in GGML conversion functions
use argv lists instead of shell=True. Companion to PR #4768.
2026-04-02 00:10:47 -07:00
mateeaaaaaaa
752cef3299
fix(security): shell injection in GGML export conversion (#4768)
* Fix shell injection in GGML conversion paths

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove test file from security fix PR

Move test_save_shell_injection.py to a separate PR to keep this PR focused on the security fix itself.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-02 00:10:43 -07:00
AdamPlatin123
ba8081fc96
fix(chat): correct loading text for cached models during inference (#4764)
Distinguish between actual network downloads and GPU memory loading for cached LoRA adapters in Studio chat.

- Add isCachedLora detection for local LoRA adapter paths using comprehensive cross-platform regex (Unix, Windows, UNC, WSL, tilde)
- Thread isCachedLora through loadInfo to chat-page inline status for proper 3-way distinction (cached / local LoRA / downloading)
- Skip download progress polling for cached LoRA models (no useless /download-progress API calls)
- Fix initial toast state to use isCachedLoad consistently instead of only checking isDownloaded
- Fix cancelLoading toast to not mention background downloads for cached/local loads
- Keep download-specific text ("Downloading model..." / "Download complete") inside the download-only polling block
2026-04-01 20:24:48 -07:00
Lee Jackson
ca4ea8b9fb
studio: align composer/code, unify fonts, and remove tool collapse jitter (#4763)
- Add min-w-0 guards to thread/message/markdown containers to prevent
  content overflow past the composer width
- Unify chat typography from Hellix/Space Grotesk to the sans stack,
  keeping monospace for code blocks and inline code
- Restructure desktop navbar right-side controls with shrink-0 wrappers
  for consistent spacing across HoverCard roots
- Soften tool-call label styling (font-medium + text-foreground/85
  instead of bold)
- Add responsive code block sizing via @container queries
- Add horizontal scrolling for wide code blocks within the thread column
- Scope list-item code block alignment CSS to .aui-thread-root
- Preserve useScrollLock in tool-fallback and tool-group collapsibles
- Fall back to bg-background on ViewportFooter when hideComposer is true
- Widen inline code monospace selector to cover th, blockquote, and
  heading elements
- Remove unused @fontsource-variable/space-grotesk import
2026-04-01 19:57:10 -07:00
DoubleMathew
71b934ef9d
Fix custom llama.cpp source builds and macos metal source builds (#4762)
* Fix script unbound variable error

* remove stale test script, add llama.cpp metal source builds, update tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Metal precedence, test sync, and add behavioral tests

- Move macOS arm64 Metal check before CUDA/ROCm in GPU backend
  decision chain so Metal is not bypassed when nvcc is in PATH
- Remove RPATH flags from CPU fallback CMAKE_ARGS (only needed
  for Metal library linking)
- Update test_llama_pr_force_and_source.py to match _CLONE_ARGS
  rename from _CLONE_BRANCH_ARGS in setup.sh
- Add confirm_install_tree guard test for
  existing_install_matches_choice
- Add TestMacOSMetalBuildLogic bash subprocess tests verifying
  Metal flag selection, nvcc precedence, and CPU fallback behavior

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Metal CPU fallback to also cover cmake build failures and update tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. _GPU_BACKEND_FRAGMENT synced -- removed dead CPU_FALLBACK_CMAKE_ARGS= init (6/8)
2. RPATH assertion replaced -- new test_macos_arm64_cpu_fallback_args_exclude_rpath checks the actual runtime CPU_FALLBACK_CMAKE_ARGS output for @loader_path and -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON (6/8)
3. _TRY_METAL_CPU_FALLBACK=false reset after both configure-failure and build-failure fallback branches in setup.sh (4/8)
4. macOS test now removes libmtmd.0.dylib instead of the platform-agnostic convert_hf_to_gguf.py (3/8)
5. Empty-string tag test added -- test_empty_tag_omits_branch_flag for resolved_tag= (2/8)
6. RPATH checks on cmake call logs -- both fallback tests now assert @loader_path and -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON are absent from CPU fallback cmake calls, plus baseline flag preservation (multiple)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests clean up

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 14:06:39 -05:00
Daniel Han
39fe23ded8
Tests for architecture-aware KV cache estimation (#4760)
* test: add 66 tests for architecture-aware KV cache estimation

Covers all 5 estimation paths (MLA, Hybrid Mamba, Sliding Window,
Standard GQA, Legacy), GGUF parser for 8 new metadata fields,
_can_estimate_kv gate conditions, quantization scaling, edge cases,
path priority ordering, and lifecycle (init/unload/reparse).

Zero external dependencies beyond pytest. No GPU or network required.
Cross-platform (Linux, macOS, Windows, WSL).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:13:37 -07:00
Daniel Han
653eb3819a
fix(studio): allow context length slider to reach model's native limit (#4746)
* fix(studio): allow context length slider to reach model's native limit

The context length slider was hard-capped to the VRAM-estimated maximum,
preventing users from requesting higher context even though the backend
already handles it safely (multi-GPU selection, --fit fallback). Expose
the model's native context length from GGUF metadata as a separate API
field and use it as the slider ceiling instead. Add an amber warning
when the selected context exceeds the estimated VRAM capacity.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Raise VRAM budget to 90% and add native_context_length tests

Increase the GPU memory utilization threshold from 70% to 90% across
_select_gpus and _fit_context_to_vram, allowing longer context lengths
before VRAM capping kicks in.

Add 33 tests for the native_context_length feature covering the backend
property, context value separation invariants, Pydantic models, route
completeness, edge cases, and cross-platform binary I/O.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:12:52 -07:00
Daniel Han
d22b2a18f9
fix: add tokenizers to no-torch deps and TORCH_CONSTRAINT for arm64 macOS py313+ (#4748)
* fix: add tokenizers to no-torch runtime deps and add TORCH_CONSTRAINT for arm64 macOS py313+

Two installer fixes:

1. Add `tokenizers` to `no-torch-runtime.txt` before `transformers`.
   Without it, `from transformers import AutoConfig` crashes on startup
   because `--no-deps` skips transitive dependencies.

2. Add `TORCH_CONSTRAINT` variable to `install.sh`. On arm64 macOS with
   Python 3.13+, tighten the torch requirement to `>=2.6` since torch
   <2.6 has no cp313 arm64 wheels. The variable replaces the previously
   hard-coded constraint in the uv pip install line.

Includes 66 tests (42 pytest + 24 bash) covering:
- Structural checks on install.sh, install.ps1, no-torch-runtime.txt
- Shell snippet tests with mocked python for 13 platform/version combos
- Mock uv integration verifying correct constraint string
- E2E venv tests on Python 3.12 and 3.13 confirming AutoConfig works
- Negative control proving AutoConfig fails without tokenizers
- Full no-torch sandbox regression guards (safetensors, huggingface_hub)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix incomplete no-torch manifest and align E2E tests with real --no-deps path

- Add missing transitive deps to no-torch-runtime.txt that are required
  under --no-deps: regex, typing_extensions, filelock, httpx, httpcore,
  certifi, idna, anyio, sniffio, h11. Without these, `from transformers
  import AutoConfig` still fails after install.sh --no-torch.

- Change all E2E tests to use --no-deps (matching what install.sh does)
  instead of normal dep resolution. Previous tests passed even with an
  incomplete manifest because uv backfilled transitive deps.

- Rewrite negative control to derive from the real no-torch-runtime.txt
  with tokenizers stripped, proving the specific fix matters.

- Replace GNU-only sed -i with heredoc in shell test for macOS compat.

- Remove unused os/sys imports from Python test file.

- Quote SKIP_TORCH and mock uv paths in bash -c strings.

* Assert install succeeds before checking import results in E2E tests

Address review feedback: test_torch_not_importable and
test_tokenizers_directly_importable in Group 3 now assert that
uv pip install returns 0 before checking import behavior. This
prevents false positives when the install itself fails silently.

* Assert install succeeds in negative control and tighten error check

- Add missing install-success assertion in test_negative_control_no_tokenizers
  to prevent false positives from network/install failures.

- Tighten error message check to look for "tokenizers" in stderr or
  ModuleNotFoundError, rather than the generic "No module" substring
  which could match unrelated import failures.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:12:17 -07:00
Daniel Han
76cb48be0b
fix: studio web search SSL failures and empty page content (#4754)
- Fix SSL handshake failures (SSLV3_ALERT_HANDSHAKE_FAILURE, CERTIFICATE_VERIFY_FAILED) when fetching HTTPS pages by introducing _PinnedHTTPSConnection that separates TCP connect (to pinned IP) from TLS handshake (with real hostname for SNI/cert verification)
- Fix SSRF DNS-rebinding vulnerability: previous impl swapped conn.host before connect(), causing fresh DNS resolution; new subclass keeps TCP pinned to validated IP
- Fix SPA/JS-rendered doc sites returning empty content by rotating real browser User-Agents (Chrome/Firefox/Safari)
- Strip nav/footer from HTML-to-Markdown output so article content is not buried under navigation chrome
- Increase raw fetch cap from 64KB to 512KB so SSR article content is reached on GitBook/Docusaurus/Next.js pages
- Fix IPv6 address bracketing in URL netloc construction
- Hoist SSL context, handler classes, and stdlib imports to module level (created once, not per-call)
- Use consistent UA across redirect hops to avoid breaking session-aware bot detection
2026-04-01 06:12:02 -07:00
Daniel Han
f84c2d03d3
Add installer test coverage for prebuilt llama.cpp changes (#4756)
Split out from #4741 to keep the main PR focused on installer logic.

- New test_install_llama_prebuilt_logic.py: tests for resolve logic,
  fallback behavior, env_int, busy/lock handling
- New test_validate_llama_prebuilt.py: validator tests for staged
  release_tag/upstream_tag handling
- New test_llama_pr_force_and_source.py: tests for PR_FORCE and
  LLAMA_SOURCE maintainer defaults
- Updated test_selection_logic.py: expanded selection/fallback coverage
- Updated test_pr4562_bugfixes.py: updated bugfix tests for new logic
- Updated smoke_test_llama_prebuilt.py: minor update
2026-04-01 06:06:29 -07:00
DoubleMathew
428efc7d95
Resolve latest usable published llama.cpp release instead of fixed pinned tag (#4741)
Replaces the fixed prebuilt llama.cpp tag with dynamic published-release
resolution, adds bounded fallback across older published releases, and
introduces maintainer-editable defaults for PR/source overrides.

Changes:
- Resolve latest from the latest usable published release in unslothai/llama.cpp
- Use the selected release upstream_tag as the authoritative llama.cpp version
- Prefer Unsloth-published platform assets when available
- Fall back to same-tag upstream ggml-org/llama.cpp assets where allowed
- Keep Linux CUDA anchored to Unsloth-published CUDA bundles only
- Add bounded fallback across older Unsloth published releases
- Add separate busy/in-use install handling (exit code 3)
- Skip reinstall when the installed bundle already matches the selected candidate
- Add maintainer-editable _DEFAULT_LLAMA_PR_FORCE and _DEFAULT_LLAMA_SOURCE
- Harden env parsing so malformed installer env vars do not crash import-time fallback logic
- Honor UNSLOTH_LLAMA_RELEASE_TAG in all resolve steps
- Always sync git remote URL in existing-checkout path
2026-04-01 06:06:17 -07:00
Daniel Han
5d7d882ce6
Fix save_pretrained_merged for full-finetuned models (#4755)
* Fix save_pretrained_merged for full-finetuned models

save_pretrained_merged and push_to_hub_merged silently do nothing when
the model is not a PeftModel (i.e. full finetuning without LoRA).
merge_and_overwrite_lora returns None immediately for non-PeftModel,
and unsloth_generic_save does not check the return value.

Add a non-PeftModel branch in unsloth_generic_save that falls back to
model.save_pretrained / model.push_to_hub. When save_method contains
"16bit", cast weights to bfloat16 (or float16) via a state_dict copy
to honor the user's intent without mutating the live model.

The existing PeftModel (LoRA) code path is unchanged.

* Forward create_pr and revision to tokenizer.push_to_hub

The tokenizer push_to_hub call was missing create_pr and revision,
which could cause the tokenizer to push to the wrong branch or
bypass PR creation when the model push uses them.

* Honor merged_16bit dtype contract for full-finetuned models

Cast state_dict to bfloat16/float16 when save_method contains "16bit"
to match the documented behavior of save_pretrained_merged. Also pass
state_dict and save kwargs consistently to both save_pretrained and
push_to_hub paths.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review feedback for PR #4755

- Simplify PeftModel isinstance check (PeftModelForCausalLM inherits
  from PeftModel)
- Add is_main_process guard for distributed training
- Forward variant to save_pretrained
- Set tokenizer padding_side to "left" before saving (matches other
  save paths)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:05:37 -07:00
Daniel Han
77e1a9edc9
feat(studio): architecture-aware KV cache VRAM estimation (#4757)
* feat(studio): architecture-aware KV cache VRAM estimation

Replace the single legacy formula (2 * n_kv_heads * head_dim * n_layers
* n_ctx * bpe) with 5-path estimation that reads 8 additional GGUF
metadata fields:

  1. MLA (DeepSeek-V2/V3, GLM-4.7, GLM-5, Kimi-K2.5) -- K-only cache
     using compressed KV latent + RoPE; no separate V allocation
  2. Hybrid Mamba (Qwen3.5-27B, Qwen3.5-35B-A3B) -- only attention
     layers (1 in N) carry KV; Mamba layers have none
  3. Sliding Window (Gemma-3, gpt-oss) -- SWA layers cache
     min(ctx, window) tokens instead of the full context
  4. Standard GQA -- uses explicit key_length/value_length from GGUF
     instead of embed // n_heads (which is wrong for many models)
  5. Legacy fallback -- identical to old formula for old GGUFs

New GGUF fields parsed: attention.key_length, attention.value_length,
attention.sliding_window, full_attention_interval,
attention.kv_lora_rank, attention.key_length_mla, ssm.inner_size,
ssm.state_size.

Validated against 9 real GGUF files (72/72 field checks pass).
The legacy formula was off by +682% for Gemma-3 and -81% for
DeepSeek-V3.1.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix MLA fallback and SWA global/local ratio heuristic

Two fixes based on review findings:

1. MLA fallback now uses key_length_mla from GGUF metadata instead of
   hardcoded rope_dim=64. Falls back to 64 only when key_length_mla is
   absent. This ensures correct estimates for MLA variants that use
   rope dimensions other than 64.

2. SWA global/local layer ratio changed from 50/50 to 1/4 (25% global,
   75% SWA). Most sliding window architectures have predominantly local
   layers (Gemma-3 uses ~17% global, gpt-oss uses ~50%). The 1/4
   heuristic is closer to the common case and still a large improvement
   over the legacy formula which ignores SWA entirely.

* Tighten _can_estimate_kv gate and treat sliding_window=0 as disabled

Two additional fixes from review round 1 (5/8 and 4/8 reviewer consensus):

1. _can_estimate_kv now requires BOTH key_length AND value_length for
   the explicit-dims path. Previously key_length alone was enough,
   which could cause silent fallthrough to the legacy formula with
   fabricated defaults (n_kv=1, head_dim=128) when value_length was
   absent from the GGUF.

2. SWA path now requires sliding_window > 0. Some GGUFs use 0 as a
   disabled sentinel. Without this guard, min(ctx, 0) would zero out
   all SWA layer contributions, severely underestimating KV cache.

* Fix MLA n_kv safety and use ceiling division for hybrid path

Addresses Gemini Code Assist review findings:

1. MLA path now uses n_kv_mla = n_kv_heads or 1 (not n_heads). This
   prevents a 128x overestimate for DeepSeek-V3 if head_count_kv is
   absent from the GGUF (n_heads=128 would have been used instead).

2. Hybrid path now uses ceiling division for attention layer count.
   This prevents undercounting by 1 when n_layers is not perfectly
   divisible by full_attention_interval.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:04:12 -07:00
Daniel Han
3f3757b143
Fix forward compatibility with transformers 5.x (#4752)
* Fix forward compatibility with transformers 5.x

Tested on transformers 4.57.6, 5.3.0, and 5.4.0. All changes are no-ops
on transformers 4.x.

1. Skip exec-based config patching for transformers >= 5.0

   Config classes in v5 use @strict, @auto_docstring, and interval()
   which break exec(inspect.getsource(...)). Those configs already use
   rope_parameters (the v5 replacement for rope_scaling).

2. Slice position_ids to last token in fast_forward_inference

   Transformers 5.x generate() accumulates position_ids as
   [batch, full_seq_len] across decode steps instead of [batch, 1].
   cos[position_ids] then produces the wrong shape for rotary
   embeddings. Fixed in llama, qwen3, falcon_h1, gemma2, cohere,
   granite. No-op on 4.x since position_ids is already [batch, 1].

3. Handle @strict config kwargs for sequence classification

   num_labels, max_position_embeddings, id2label etc. are set on the
   config object and passed via config= instead of as kwargs.
   AutoModelForSequenceClassification routing added to FastModel loader.

4. Exclude modernbert from flex_attention

   ModernBERT with flex_attention hits CUDA illegal memory access in
   create_block_mask. Falls back to eager attention safely.

5. Propagate token_type_ids and mm_token_type_ids through GRPO VLM path

   Gemma3 Vision requires token_type_ids during training. Qwen3VL
   requires mm_token_type_ids for M-RoPE. Extract from inputs in
   compute_loss, pass to grpo_accumulated_loss, and extend
   mm_token_type_ids for completion tokens in
   _generate_and_score_completions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add try/except safety net around config exec for pre-release transformers versions

* Pop config-level kwargs in seqclass path and use except Exception

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 06:04:03 -07:00
Roland Tannous
41df4ec437
feat(studio): strip org prefix in model search to surface unsloth variants (#4749)
When searching for a specific publisher model (e.g. `openai/gpt-oss-20b`), the
unsloth search used the full `openai/gpt-oss-20b` string with `author=unsloth`,
which returned zero results because no unsloth model contains the publisher
prefix in its name. Users never discovered unsloth variants.

This PR strips the org prefix for publisher-qualified queries so unsloth variants
surface, then pins the original publisher model after a small batch of unsloth
results. Plain queries (no slash) and unsloth-prefixed queries are unchanged.

- Strict regex (`/^([^/\s]+)\/([^/\s]+)$/`) only triggers on valid `owner/repo`
  identifiers; incomplete typeahead, multi-slash, and URL-like inputs are rejected
- Queries for `unsloth/...` models (case-insensitive) keep the full 20-result
  prefetch and secondary sort
- Pinned model lookup fires in parallel with the unsloth prefetch
- Canonical-name dedup prevents duplicates when HF normalizes casing
- Publisher detection extracted into a single `useMemo` block
2026-04-01 04:37:28 -07:00
Leo Borcherding
63ad6dbd6d
Fix OOM model styling in Studio model selectors (#4738)
Replace strikethrough + opacity-50 OOM styling with gray text and red pill badge across all Studio model selectors (chat, training, onboarding).

- Use gray-500/gray-400 for OOM model names (better contrast than strikethrough)
- Red pill badge for OOM indicator with light/dark mode support
- Scope GGUF gray override to quant name only so downloaded/recommended labels keep colors
- Add !important on TIGHT/OOM badges to resist ComboboxItem hover overrides
2026-04-01 02:06:49 -07:00
Daniel Han
6c0826a9e4
Fix Windows local GGUF model loading crash (#4730)
* Fix Windows "Non-relative patterns are unsupported" when loading local GGUF models

When a user loads a GGUF model from a local Windows path (e.g.
C:\Users\danie\.lmstudio\models\unsloth\functiongemma-270m-it-GGUF),
the model identifier contains backslashes and a drive letter. Both
load_model_defaults() and _has_specific_yaml() constructed a YAML
filename from the full absolute path and passed it to Path.rglob(),
which rejects non-relative patterns on Windows.

Fixed by detecting Windows-style paths (drive letters, UNC paths,
backslashes) in addition to Unix-style paths, and using only the
directory basename for the YAML filename lookup when the identifier
is a local filesystem path.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor: reuse is_local_path helper, fix case-sensitive suffix lookup

- Replace inline local-path detection in model_config.py and
  inference_config.py with the existing is_local_path() from utils.paths,
  which already handles Unix, Windows drive-letter, UNC, and backslash paths
- Fix case-sensitive suffix lookup in load_model_defaults(): the
  _REVERSE_MODEL_MAPPING is lowercase-keyed, so suffix comparisons must use
  .lower() to match paths like /path/to/Spark-TTS-0.5B/LLM

* Fix WSL path parsing and _has_specific_yaml suffix lookup

- Use normalize_path() before Path() operations so backslash Windows
  paths (e.g. C:\Users\...\model) are correctly split on POSIX/WSL hosts
  where pathlib treats backslashes as literal characters
- Add suffix-based (2-component and 1-component) lookup to
  _has_specific_yaml() so it matches the same resolution rules as
  load_model_defaults(), fixing wrong inference params for local
  suffix-mapped models like Spark-TTS-0.5B/LLM

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-04-01 01:38:09 -07:00
Datta Nimmaturi
256c6e4884
Refactor flex attn to prefer flash if possible (#4734)
Replaces prefer_flex_attn_if_supported (which only returned flex_attention or None) with determine_attention_implementation, a centralized hierarchy: FA2 > Flex > SDPA > Eager.

Changes:
- New determine_attention_implementation function in _utils.py with clear priority chain
- _set_attn_impl helper to stamp config consistently
- _FLEX_EXCLUDED_MODELS / _FLEX_EXCLUDED_PREFIXES for model-specific exclusions
- Gemma3N explicit eager override in vision.py (timm vision towers)
- Preserved sdpa fallback for unmapped/remote-code vision configs
- Config re-stamped to eager when supports_sdpa guard fires

Co-authored-by: Datta Nimmaturi <Datta0@users.noreply.github.com>
2026-04-01 00:30:21 -07:00
Wasim Yousef Said
d63cc57e1e
fix: clear tool status badge immediately after tool execution (#4733)
* fix: clear tool status badge immediately after tool execution

The tool status timer badge (Searching 1s, 2s...) persisted after
tool calls finished because the status clear event was only sent
at the start of the next generation iteration, not after tool
execution completed.

Backend: yield status clear after all tools finish in the agentic
loop iteration, before continue starts the next generation pass.

Frontend: debounce badge visibility by 300ms so sub-second tool
calls dont flash the badge.

* Fix debounce regression for consecutive tool calls

Only apply the 300ms show-delay when transitioning from idle to
tool-active. When switching between consecutive tools in the same
turn (e.g. web_search -> python), keep the badge visible immediately
so it does not flicker or disappear during multi-tool runs.

* Delay wasActiveRef reset to bridge inter-iteration tool gaps

The backend emits a status-clear event between tool iterations,
which was resetting wasActiveRef immediately and causing the next
tool to be re-debounced (300ms hidden gap between consecutive tools
in the same turn). Now the ref reset is delayed by 500ms so a
follow-up tool within the same agentic turn shows the badge
immediately, while a genuinely new turn still gets the debounce.

* Use thread lifecycle to track tool-run boundaries

Replace the 500ms wall-clock timeout with the actual thread.isRunning
state to determine when wasActiveRef should reset. This properly
handles all cases:
- Consecutive tools within the same run stay visible without flicker
- The badge hides only when the thread run actually ends
- New turns always get a fresh 300ms debounce on the first tool
- No heuristic timeout that can misfire on slow or fast inference

* Consolidate wasActiveRef reset into single effect

Removes the separate isThreadRunning effect to avoid a race where
the ref resets before the tool-status effect reads it (when
isThreadRunning flips to false before setToolStatus(null) from
the adapter's finally block). Now wasActiveRef resets only when
both toolStatus is null AND the thread run has ended, eliminating
any flicker on the last tool of a run.

* Simplify debounce: use visible state instead of ref tracking

Drop wasActiveRef entirely and use the visible state as the
debounce gate. When the badge is not yet on screen, debounce
for 300ms before showing. When already visible from a prior tool,
keep showing immediately. This correctly handles all cases:
- All fast tools (<300ms) are suppressed, not just the first
- Consecutive tools after the badge is shown stay visible
- Badge persists across inter-iteration clears while thread runs
- New turns get a fresh debounce after visible resets

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-01 00:28:38 -07:00
Wasim Yousef Said
4fb9778988
feat: move folder management into model selector dropdown (#4731)
* refactor: move folder management from sidebar into model selector

* Fix folder management: restore LoRA picker sync, error handling, caching

- Restore onFoldersChange callback to keep LoRA adapter picker in sync
  when scan folders are added/removed (fixes regression from sidebar move)
- Thread onFoldersChange through ModelSelector -> HubModelPicker prop chain
- Add module-level _scanFoldersCache to prevent folder list flash on re-open
- Surface error toast on folder removal failure instead of silently ignoring
- Guard handleAddFolder against concurrent double-submit via folderLoading
- Clear folderInput on Escape key dismiss to prevent stale input on re-open
- Add refreshLocalModelsList and refreshScanFolders to useEffect dep array

* Fix compare-mode folder sync, Escape key propagation, cancel toggle state

- Wire onFoldersChange through CompareContent/GeneralCompareContent so
  compare-mode selectors also refresh local models after folder changes
- Add e.stopPropagation() on Escape key in folder input to prevent
  Radix Popover from closing the entire model selector dropdown
- Add e.preventDefault() on Enter key to prevent form submission
- Clear folderInput and folderError when cancel toggle hides the input,
  matching the Escape key behavior for consistency

* Fix folder mutation state ordering and touch accessibility

- Use optimistic updates for add/remove so the folder list reflects
  changes immediately instead of waiting on a second listScanFolders
  round-trip that could silently fail.
- Move refreshScanFolders out of the finally block in handleRemoveFolder
  so it runs after the cache update, not after onFoldersChange.
- Make the remove button visible on touch/mobile devices and reachable
  via keyboard focus (opacity-100 on small screens, focus-visible).
- Add aria-label to the remove button for screen readers.

* Deduplicate optimistic folder add to match backend behavior

The backend returns the existing ScanFolderInfo row when adding a
path that is already registered. The optimistic update was blindly
appending the returned row, producing duplicate entries and React
key warnings. Now checks by id before appending.

* Add aria-label to folder toggle button and strengthen dedup check

- Add aria-label to the +/cancel icon button for screen readers.
- Extend optimistic dedup check to also compare by path, not just id,
  to handle edge cases where the cache is stale.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-31 23:15:50 -07:00
Lee Jackson
2cac3e8e4d
studio: Polish Windows installer/setup logs (#4736)
* style(windows): clean installer/setup log output and remove seeded credential banner

* Keep startup credential hint without exposing plaintext password

Print the username and .bootstrap_password file path on first-run
admin creation instead of the raw password. Headless / Docker / SSH
operators still get a startup-time hint for initial sign-in, and the
plaintext credential no longer appears in terminal output or logs.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-31 23:12:42 -07:00
Daniel Han
6984e118eb
Bump installer minimum version pin to 2026.3.18 (#4729)
Matches the latest PyPI release.
2026-03-31 07:00:51 -07:00
Daniel Han
cfeb8c3245 Versioning 2026-03-31 06:51:34 -07:00
Wasim Yousef Said
1e8875584d
feat: custom scan folders for GGUF model discovery (#4723)
* feat: add scan_folders table and CRUD functions to studio_db

* feat: add scan folders API endpoints and integrate into model scan

* feat: add scan folders API client and update source types

* feat: add custom source to model filters and selector

* feat: add Model Folders section to chat settings sidebar

* style: fix biome formatting in ModelFoldersSection

* fix: address review findings for custom scan folders

empty string bypass, concurrent delete crash guard,
Windows case normalization, response_model on endpoints,
logging, deduplicated filter/map, module level cache for
custom folder models, consistent source labels, handleRemove
error surfacing, per folder scan cap

* fix: show custom folders section regardless of chatOnly mode

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor: extract shared refreshLocalModelsList in pickers

* Harden custom scan folder validation and scanning

- Validate path exists, is a directory, and is readable before persisting
- Apply per-folder model cap during traversal instead of after (avoids
  scanning millions of inodes in large directories)
- Wrap per-folder scan in try/except so one unreadable folder does not
  break the entire /api/models/local endpoint for all callers
- Normalize case on Windows before storing so C:\Models and c:\models
  dedup correctly
- Extend macOS denylist to cover /private/etc and /private/tmp (realpath
  resolves /etc -> /private/etc, bypassing the original denylist)
- Add /boot and /run to Linux denylist

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Improve scan robustness and preserve Windows path casing

- Preserve original Windows path casing in DB instead of lowercasing
  (normcase used only for dedup comparison, not storage)
- Catch PermissionError per child directory so one unreadable subdirectory
  does not skip the entire custom folder scan
- Wrap list_scan_folders() DB call in try/except so a DB issue does not
  break the entire /api/models/local endpoint

* fix: scan custom folders for both flat and HF cache layouts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Windows case-insensitive path dedup with COLLATE NOCASE

Use COLLATE NOCASE on the scan_folders.path column so that the UNIQUE
constraint correctly deduplicates C:\Models and c:\models on Windows
without lowercasing the stored path. Also use COLLATE NOCASE in the
pre-insert lookup query on Windows to catch existing rows with
different casing.

* Restore early-exit limit in _scan_models_dir for custom folders

Keep the limit parameter so _scan_models_dir stops iterating once
enough models are found, avoiding unbounded traversal of large
directories. The post-traversal slice is still applied after combining
with _scan_hf_cache results.

* feat: scan custom folders with LM Studio layout too

* Fix custom folder models being hidden by dedup

Custom folder entries were appended after HF cache and models_dir
entries.  The dedup loop kept the first occurrence of each model id,
so custom models with the same id as an existing HF cache entry were
silently dropped -- they never appeared in the "Custom Folders" UI
section.

Use a separate dedup key for custom-source entries so they always
survive deduplication.  This way a model can appear under both
"Downloaded" (from HF cache) and "Custom Folders" (from the
user-registered directory) at the same time.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden LM Studio scan and fix COLLATE NOCASE on Linux

- Add per-child and per-publisher OSError handling in _scan_lmstudio_dir
  so one unreadable subdirectory does not discard the entire custom
  folder's results
- Only apply COLLATE NOCASE on the scan_folders schema on Windows where
  paths are case-insensitive; keep default BINARY collation on Linux
  and macOS where /Models and /models are distinct directories

* Use COLLATE NOCASE in post-IntegrityError fallback SELECT on Windows

The fallback SELECT after an IntegrityError race now uses the same
case-insensitive collation as the pre-insert check, so a concurrent
writer that stored the path with different casing does not cause a
false "Folder was concurrently removed" error.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-31 06:40:31 -07:00
Daniel Han
9a8b622306
Studio: simplify tool-call dedup and replace html2text with builtin converter (#4722)
* Simplify tool-call dedup: drop hashlib, inline helpers

The duplicate tool-call detector only compares calls within a single
request from the same JSON parser, so dict key order is guaranteed
identical for identical calls (Python 3.7+ insertion-ordered dicts).

- Replace hashlib.md5(json.dumps(...)) with name + str(args)
- Inline _tool_call_key, _is_duplicate_call, _record_tool_call
  since each was a one-liner used once
- Remove unused hashlib import

* Remove tool_calling_benchmark_results.md from repo

* Replace html2text with builtin HTML-to-Markdown converter

Drop the external html2text (GPL-3.0) dependency and its regex
fallback. Add _html_to_md.py (~190 lines, stdlib only) using
html.parser.HTMLParser that handles headings, links, bold/italic,
lists, tables, blockquotes, code blocks, and entity decoding.
Strips script/style/head tags entirely.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use json.dumps(sort_keys=True) for tool-call dedup key

str(dict) is sensitive to insertion order, so semantically identical
calls with different key ordering would bypass duplicate detection.
Switch to json.dumps with sort_keys=True for a canonical representation.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert dedup key to str(arguments)

json.dumps(sort_keys=True) is unnecessary here -- the arguments dict
always comes from the same JSON parser within a single request, so
key insertion order is deterministic (Python 3.7+).  str() is faster
and sufficient for consecutive-call dedup.

* Address review comments on _html_to_md.py

- Remove "hr" from _BLOCK_TAGS so the dedicated hr handler is reachable
- Prefix all newlines with ">" inside blockquotes (multi-line support)
- Emit full ![alt](url) for images instead of alt text only
- Replace newlines with spaces inside table cells
- Track header cells per-row (_row_has_th) instead of last-cell-only
- Strip trailing tabs in addition to spaces in cleanup regex

* Fix blockquote rendering, truncated-HTML buffer flush, and dedup key canonicalization

_html_to_md.py:
- Rewrite blockquote handling with stack-based buffer approach so nested
  blockquotes, pre blocks inside blockquotes, and multi-paragraph quotes
  all render correctly with proper "> " prefix on every line.
- Add flush_pending() to recover content from truncated HTML where closing
  tags are missing (common when _fetch_page_text caps the download size).
  Flushes open <a>, <td>, <pre>, and blockquote buffers.
- Skip <img> tags to match prior html2text ignore_images=True behavior
  and avoid data-URI amplification consuming the output budget.
- Collapse all whitespace (including newlines) in non-pre content per
  standard HTML whitespace rules: \s+ -> single space.
- Escape pipe characters in table cell content to prevent column breakage.
- Emit separator row after the first row for tables without <th> headers.
- Guard against IndexError on _ol_counter for orphan <li> elements.
- Normalize CRLF line endings before parsing.

llama_cpp.py:
- Restore canonical dedup key with json.dumps(sort_keys=True) so that
  semantically identical tool calls with different JSON key order are
  correctly detected as duplicates.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix table optional end tags, inline code whitespace, and link text normalization

_html_to_md.py:
- Extract _finish_cell() and _finish_row() helpers to handle HTML tables
  that omit optional </td>, </th>, or </tr> end tags. This is valid HTML
  and common on real web pages -- previously the parser would silently
  drop earlier cells and entire rows.
- Call _finish_cell()/_finish_row() from handle_starttag for <tr>/<td>/<th>,
  handle_endtag for </tr>/<td>/<th>/<table>, and flush_pending() so all
  three paths (normal close, implicit close, truncated HTML) use the same
  row-finalization logic including header separator emission.
- Add _in_inline_code flag so handle_data() preserves literal whitespace
  inside <code> spans instead of collapsing it. Source like
  <code>pip  install   unsloth</code> now correctly renders as
  `pip  install   unsloth` rather than `pip install unsloth`.
- Extract _finish_link() helper that normalizes accumulated link text with
  \s+ -> single space before building the Markdown link. Prevents block-
  level content inside <a> tags (e.g. <a><div>one</div><div>two</div></a>)
  from producing multiline [one\n\ntwo](href) link labels.
- Empty blockquotes now produce no output instead of a stray ">".
- Remove unused _bq_depth field (all routing uses _bq_stack).
- Flush open cells and rows in handle_endtag("table") for robustness.

* Support <ol start=N>, <dl>/<dt>/<dd>, and preserve code block whitespace

_html_to_md.py:
- Honor <ol start="N"> attribute so ordered lists preserve their original
  numbering instead of always restarting from 1. Important for docs/tutorials
  that continue numbering across sections.
- Add dl, dt, dd to _BLOCK_TAGS so definition lists (common on MDN, Python
  docs, Django docs) produce separated text instead of concatenated blobs.
- Rewrite _cleanup() to be fence-aware: content inside fenced code blocks
  is now preserved verbatim (intentional blank lines in <pre> content are
  no longer collapsed). Outside code blocks, blank runs are limited to one
  and trailing whitespace is stripped.
- Fix _prefix_blockquote() to strip trailing whitespace before collapsing
  blank lines, preventing the "\n\n \n\n" pattern from sneaking through.

* Suppress whitespace-only text nodes between table structural elements

Indented HTML tables (nearly all real-world pages) produce whitespace
text nodes between <table>, <tr>, </tr> etc. that land in the output
as leading spaces before table rows, breaking Markdown table alignment.

Skip whitespace-only text nodes when inside a table but not inside a
cell, so indentation from source HTML does not leak into the output.

* Revert dedup key to str(arguments) with explanatory comment

json.dumps(sort_keys=True) is unnecessary overhead here: arguments
always comes from json.loads on model output within a single request,
so dict insertion order is deterministic in Python 3.7+. A repeated
call from the model produces the same JSON, which parses to the same
dict repr. str() avoids re-serialization on every tool call.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-31 06:15:18 -07:00
Lee Jackson
9451bb1bac
fix(export): preserve selected/manual model on enter and blur (#4726) 2026-03-31 17:05:55 +04:00
Daniel Han
e159b93b97
studio: improve GGUF tool calling accuracy and reliability (#4700)
* studio: improve GGUF tool calling accuracy and reliability

- Add URL fetching to web_search tool so models can read full page
  content instead of only getting search snippets. Uses html2text for
  clean markdown conversion with regex fallback.
- Inject current date and behavioral guidance (URL fetch workflow,
  no repeated queries, use code for data processing) into the
  tool-use system prompt.
- Append error recovery nudge to tool results that indicate failure,
  helping small models avoid looping on the same broken call.
- Strip leaked <tool_call> XML from assistant messages in conversation
  history and from the outgoing SSE stream.
- Raise default max tool iterations from 10 to 25 across backend,
  model schema, and frontend defaults.
- Increase _MAX_PAGE_CHARS from 4k to 16k so fetched pages contain
  enough content for the model to extract useful information.
- Add "IMPORTANT: These are only short snippets" hint to search
  results so models know to fetch full pages when needed.

Tested with Qwen3.5-4B-GGUF (UD-Q4_K_XL), 10 runs before/after:
- XML leaks in responses: 10/10 -> 0/10
- URL fetch usage: 0 -> 4/10 runs
- Runs producing actual correct answers: 0/10 -> 2/10
- Average tool calls per query: 5.5 -> 3.8 (more efficient)
- Average response time: 12.3s -> 9.8s

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add tool calling benchmark results across model sizes and quants

Tested 16 configurations (4 models x 2 quants x 2 KV cache types)
with 10 runs each on NVIDIA B200.

Best config: 27B UD-Q4_K_XL + bf16 KV -- 6/10 runs found all 4
correct songs, 0 XML leaks, 131s average response time.

* Add duplicate tool-call detection and final-answer synthesis

When the model repeats the exact same tool call (same name + arguments)
twice in a row, skip execution and return a redirect message telling it
to try a different approach. This prevents the 8x-repeated-query loops
observed on 27B and 35B models.

When the tool iteration cap (25) is reached, inject a "provide your
final answer now" message before the final streaming pass. This lets
the model synthesize a useful answer from everything it gathered
instead of being silently cut off.

Tested on Qwen3.5-27B UD-Q4_K_XL (10 runs):
- Repeated query runs: 4/10 -> 2/10
- Cap hits: 1/10 -> 0/10
- All 4/4 accuracy: 5/10 -> 7/10

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix CodeQL alert: handle whitespace in script/style closing tags

The regex fallback for HTML stripping did not match closing tags
with whitespace before the angle bracket (e.g. </script >).
Use \s* before > in both script and style patterns.

* Address reviewer findings: SSRF, timeout crash, XML regex, dedup

- SSRF: resolve hostname via getaddrinfo and reject private, loopback,
  link-local, multicast, and reserved addresses before fetching
- Timeout: handle timeout=None (unlimited mode) in URL fetch path
  by defaulting to 60s instead of crashing on min(None, 60)
- Download cap: read at most max_chars*4+1 bytes instead of the
  full response body before truncating
- XML regex: match both <tool_call> and <function=...> markup in
  the history/stream cleanup (inference.py)
- CodeQL: use [^>]* in closing script/style tags to handle any
  whitespace or attributes before >
- Dedup: track whether each tool call failed so retries after
  transient errors are allowed; only block consecutive identical
  calls that both succeeded
- Final-answer synthesis: guard on max_tool_iterations > 0 so
  callers who disable tools do not get a false "used all calls" turn

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix redirect SSRF, SSE streaming regression, dedup off-by-one

- SSRF redirect bypass: disable auto-redirect in urllib, manually
  follow up to 5 hops with host validation at each step. Prevents
  public URLs from redirecting to loopback/private targets.
- SSE streaming: track prev_text on the raw cumulative and strip
  XML from the delta only, so completed tool_call tags do not cause
  the cumulative to shrink and drop trailing real text.
- Dedup off-by-one: check the immediately previous call (window=1)
  instead of requiring 2 matching history entries, so the second
  identical successful call is blocked rather than the third.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix redirect HTTPError handling and tighten error prefixes

- Redirect fix: urllib raises HTTPError (not a normal response) when
  the redirect handler returns None. Catch HTTPError for 3xx codes
  and extract the Location header from the exception object.
- Error prefixes: remove overly broad "No " prefix that matched
  "No results found." (a valid empty-search outcome, not an error).
  Replace with specific prefixes like "Blocked:", "No query provided",
  "Failed to resolve". This ensures empty search results are correctly
  classified as non-errors for duplicate-call tracking.

* Fix SSE cross-chunk XML leaks, cleanup review findings

- SSE streaming: sanitize the full cumulative text before diffing
  against the previous sanitized snapshot, so XML tags that span
  chunk boundaries are stripped correctly. The previous delta-based
  approach leaked split tags.
- DRAINING fallback: use _strip_tool_markup() helper instead of a
  manual regex that only handled <tool_call> but not <function=...>.
- Move hashlib import, _TOOL_XML_RE compile, and datetime import to
  module level per style guide.
- Remove unused _hit_tool_cap variable.

* Fix DNS rebinding, charset detection, HTTPError handling, dedup double-record

- DNS rebinding: resolve hostname once via getaddrinfo, pin the
  returned IP, rewrite the URL to connect to the pinned IP with
  a Host header. Each redirect hop re-resolves and re-validates.
  Closes the TOCTOU window between validation and connection.
- Charset: use resp.headers.get_content_charset() instead of
  hardcoding utf-8, so pages with other encodings decode correctly.
- HTTPError: return descriptive "HTTP {code} {reason}" instead of
  re-raising into a generic "Search failed" message.
- Dedup: remove redundant _record_tool_call in the duplicate branch;
  the single call at the end of the loop handles all cases.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-31 03:06:44 -07:00
Lee Jackson
815619d972
feat: add update instructions card with OS toggle and mobile expand flow (#4721)
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
2026-03-31 14:05:05 +04:00
Roland Tannous
cc5e4fbf17
fix: auto-retry stalled HF downloads with HF_HUB_DISABLE_XET=1 (#4712)
* fix: auto-retry stalled HF downloads with HF_HUB_DISABLE_XET=1

The heartbeat thread now monitors the HF Hub cache directory for
file-size growth. If no bytes are written for 3 minutes, it sends a
"stall" message to the orchestrator, which kills the subprocess and
retries with HF_HUB_DISABLE_XET=1 (falling back from Xet to standard
HTTPS). If the retry also stalls, it errors out with a clear message.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: include transport type (xet/https) in heartbeat and stall log messages

Makes it clear in backend logs whether the download is using xet or
https transport, and which transport stalled — helpful for debugging.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: monitor HF Hub .tmp dir to avoid false stall detections

huggingface_hub downloads into .tmp/ before atomically moving to
blobs/. Without monitoring .tmp, a large shard actively downloading
for several minutes would show zero blob growth and trigger a false
stall.

* fix: scope HF cache size check to specific model being loaded

Instead of scanning every models--*/blobs directory (O(N) with cached
models), only check the specific model's blobs dir plus the global
.tmp dir. Much faster on systems with many cached models.

* Fix false stall detection on cached/local models and cleanup issues

- Only fire stall if download activity was observed (cache size changed
  at least once). Previously, any model load taking >180s would trigger
  a false stall, even for already-cached or local models where no
  download is happening.
- Return -1 from _get_hf_cache_size on exception to distinguish
  "unable to measure" from "genuinely zero bytes". Skip stall logic
  when measurement fails.
- Add _shutdown_subprocess before raising on terminal stall path to
  prevent leaking a stuck subprocess.
- Detect pre-existing HF_HUB_DISABLE_XET=1 in the parent environment
  to avoid a redundant retry cycle when Xet is already disabled.
- Remove global .tmp directory scanning (not used by modern
  huggingface_hub; in-progress downloads use .incomplete files in
  blobs/ which are already captured by iterdir).
- Add f.is_file() guard in cache size calculation.
- Replace em dashes with ASCII dashes for Windows terminal compat.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden stall detection edge cases

- Guard -1 to valid value transition: when initial _get_hf_cache_size
  returns -1 (error) and later recovers to a real value, do not count
  that as download activity. Only set saw_download_activity when the
  previous measurement was also valid (>= 0).
- Move os import to top-level in orchestrator.py instead of inline
  import os as _os.
- Fix misleading comment about post-download protection.

* Use .incomplete files to detect active downloads for stall detection

Replace the saw_download_activity heuristic with direct .incomplete file
detection. huggingface_hub creates *.incomplete files in blobs/ during
active downloads and removes them on completion. This gives a reliable
signal for whether a download is actually in progress.

Benefits:
- Cached models: no .incomplete files -> no stall fired even after 180s
- Post-download init (quantization, GPU loading): .incomplete files gone
  so stall timer resets, long init phases are not killed
- Pre-download hangs (XET handshake stall): .incomplete files are
  created at download start, so zero-byte stalls are now detected
- No more false positives from -1 to valid measurement transitions

The _get_hf_download_state function now returns (total_bytes,
has_incomplete) tuple or None on error, replacing _get_hf_cache_size.

* Add debug logging to download state exception handler

Log the exception at debug level when _get_hf_download_state fails,
instead of silently returning None. Helps with troubleshooting cache
measurement issues.

* Watch both adapter and base model repos for LoRA stall detection

When loading a LoRA adapter, the actual download bottleneck is often
the base model, not the adapter itself. Update the heartbeat to watch
both mc.identifier and mc.base_model cache directories so stall
detection works for LoRA loads where the base model stalls on Xet.

Also update _get_hf_download_state to accept multiple model names and
skip names without "/" (local paths) since those do not have HF cache
directories.

* Fix model name filtering for official HF models without org prefix

Models like gpt2 and bert-base-uncased do not contain a slash but are
still valid HF Hub models with cache directories. Replace the "/" check
with a proper local-path detection that checks for path separators and
path-like prefixes instead.

Also fix the base_model watch list to not require "/" in the base model
name, so official models used as LoRA bases are also monitored.

* Fix local path detection that broke all org/model names on Linux

The os.path.sep check matched "/" in HF model IDs like "org/model" on
Linux, causing the stall detector to skip ALL standard HF models.

Replace with a check that only skips names starting with "/" (absolute
paths), "." (relative paths), "~" (home-relative), or containing "\"
(Windows paths). HF model IDs like "org/model" or "gpt2" pass through
correctly on all platforms.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-31 03:00:46 -07:00
Daniel Han
e164c930ff
fix(studio): correct default weight_decay and learning rate (#4695)
* fix(studio): change default weight_decay from 0.01 to 0.001

The default weight decay across Studio was 0.01 but should be 0.001.
Updated the default in all backend fallbacks, the Pydantic model, the
frontend config, and every YAML preset/model-default config.

* fix(studio): auto-set learning rate based on training method

Default LR should be 2e-4 for LoRA/QLoRA and 2e-5 for full fine-tuning.

Frontend: track whether the user has manually edited the LR field via a
_learningRateManuallySet flag (same pattern as trainOnCompletions).
When switching training method and the user has not touched the LR,
auto-set it to the appropriate default. Reset the flag on model load.

Backend: change trainer.py start_training default from 5e-5 to 2e-4,
update default.yaml fallback from 5e-5 to 2e-4, and fix
full_finetune.yaml from 0.0002 (2e-4) to 2e-5.

* refactor(studio): centralize weight_decay and learning rate defaults

Create studio/backend/core/training/constants.py as the single source of
truth for DEFAULT_WEIGHT_DECAY (0.001), DEFAULT_LEARNING_RATE (2e-4),
DEFAULT_LEARNING_RATE_FULL (2e-5), and DEFAULT_LEARNING_RATE_STR ("2e-4").

All backend modules (trainer.py, training.py, worker.py, models/training.py)
now import from constants.py instead of hardcoding values.

On the frontend, add LR_DEFAULT_LORA and LR_DEFAULT_FULL to
config/training.ts and use them in the store instead of magic numbers.
A comment cross-references the backend constants file.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix model-specific LR override, persist migration, and flag resets

- Preserve model-specific learning rates from YAML configs when the
  async autoSelectTrainingMethod callback fires (fixes Qwen2.5-1.5B
  getting 2e-4 instead of its configured 1e-5, etc.)
- Bump zustand persist version to 9 with migration so existing users
  with weightDecay=0.01 get updated to 0.001
- Clear _learningRateManuallySet in reset() and applyConfigPatch()
  for consistency with trainOnCompletions flag behavior
- Add DEFAULT_LEARNING_RATE_FULL_STR to constants.py

* Refine applyConfigPatch to only clear LR flag when patch includes LR

Only reset _learningRateManuallySet when the applied config patch
actually provides a learningRate value. This prevents unrelated config
patches from silently disarming the manual-edit guard, which would
cause a subsequent setTrainingMethod call to overwrite the user's
custom LR.

* Preserve model-specific LR when switching between qlora and lora

Only auto-switch the learning rate when the training category changes
(adapter <-> full fine-tuning). Switching between qlora and lora keeps
the current LR since both methods share the same learning rate range.
This preserves curated per-model defaults (e.g. 1e-5 for
Qwen2.5-1.5B-Instruct) when the user toggles between adapter methods.

* Remove constants.py, use YAML configs as the source of truth

The YAML config files (model-specific + default.yaml) are the intended
config layer for training defaults. The Python backend fallbacks now use
inline values that match the YAML configs, rather than importing from a
separate constants module. This keeps the config architecture simple:
YAML files are the single source of truth, and the inline Python
fallbacks are just safety nets that mirror them.

* fix(studio): preserve model-specific LR when switching training method

Stash YAML-provided learning rate and use it to restore the correct
value when switching between adapter and full fine-tune modes.

- qlora <-> lora no longer overwrites the model's LR
- full -> adapter restores the YAML LR instead of a hardcoded constant
- selecting a model while on full fine-tune uses LR_DEFAULT_FULL
  instead of applying the YAML adapter LR

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-03-31 13:50:25 +04:00
Wasim Yousef Said
28aaf849bf
fix: throttle and cache HuggingFace modelInfo API calls (#4696)
* fix: throttle and cache HuggingFace modelInfo API calls

The frontend was firing 40 to 60 parallel modelInfo requests on app
startup with zero caching or deduplication, causing HF rate limits.

Adds a caching layer (hf-cache.ts) with TTL cache, inflight request
dedup, and a concurrency limiter. Also debounces the HF token input
so typing a token no longer re-fires all model searches per keystroke.

* fix: only fetch VRAM info for visible models in chat selector

* Fix cache key isolation and VRAM badge stability for PR #4696

- Cache key now includes a token fingerprint (last 8 chars) instead of a
  boolean, so switching HF tokens gives separate cache entries instead of
  serving stale data from the previous token.
- Extract token via credentials?.accessToken to match the @huggingface/hub
  API surface.
- Extend CachedResult type with safetensors/tags fields so downstream
  consumers no longer need unsafe `as` casts.
- Merge VRAM param map with previous state on scroll instead of replacing
  it, preventing a brief flash of missing VRAM badges when new models
  become visible.

* Fix VRAM badges missing for search-filtered recommended models

When a user types a search query, filteredRecommendedIds can include
models beyond the currently visible page. These models had no VRAM data
because useRecommendedModelVram only received visibleRecommendedIds.

Now we pass the union of visibleRecommendedIds and filteredRecommendedIds
to the VRAM hook, so recommended models surfaced by search also show
their VRAM badges. The hf-cache layer ensures no duplicate network calls.

* Apply biome formatting to hf-cache.ts and use-recommended-model-vram.ts

Auto-formatted with biome check --write to match project lint rules:
- Block statements for single-line if/for bodies
- Import sorting (type imports first)
- Consistent line wrapping

* Fix extractToken to handle both current and deprecated HF auth forms

The @huggingface/hub CredentialsParams type is a union:
  - { accessToken: "hf_..." }               (current preferred form)
  - { credentials: { accessToken: "..." } }  (deprecated form)

Previously only checked params.credentials?.accessToken (deprecated path).
Now checks both forms so the cache key is correct regardless of which
calling convention is used.

* Simplify extractToken, map merge, and set construction

- extractToken: remove type assertions, use direct property access with
  truthiness checks for cleaner union type handling
- VRAM map merge: use Map spread constructor instead of manual for loop
- idsForVram: use Set spread construction for more concise dedup

* Add rationale comment for MAX_CONCURRENT=3 in hf-cache.ts

* Skip GGUF repos in VRAM fetch and pre-populate cache from listModels

Two changes to reduce redundant HF API calls:

1. Filter GGUF repos from idsForVram before passing to useRecommendedModelVram.
   GGUF repos have no safetensors metadata and the render layer already shows
   a static "GGUF" badge -- fetching modelInfo for them is a no-op that wastes
   a semaphore slot and a network round-trip.

2. Add primeCacheFromListing() to hf-cache.ts and call it from listModels
   yield sites in mergedModelIterator and priorityThenListingIterator.
   listModels returns the same type (ModelEntry & Pick<ApiModelInfo, T>) as
   modelInfo with the same additionalFields, so the data is interchangeable.
   Priming only writes if the key is not already fresh, so it never overwrites
   a recent modelInfo response.

   This means models discovered via listModels are already in cache when
   useRecommendedModelVram later calls cachedModelInfo for them, eliminating
   duplicate network requests.

* Fix cache key mismatch: prime both token and anonymous slots

The VRAM hook calls cachedModelInfo without credentials (anonymous key),
but listModels results were primed only under the authenticated key.
For authenticated users the priming was a no-op -- cache miss every time.

Fix: prime both the token-specific slot and the anonymous slot when an
access token is present. Public model metadata (safetensors, tags) is
identical regardless of auth so this is safe.

Also add a defensive guard in primeCacheFromListing for empty name.

* Auto-prime anonymous cache slot from authenticated modelInfo fetches

When cachedModelInfo is called with a token, the result was only stored
under the token-specific key (e.g. model::abc12345). The VRAM hook
calls cachedModelInfo without credentials and reads the anonymous slot
(model::anon), causing a cache miss and duplicate fetch for every
priority model.

Now cachedModelInfo also writes to the anonymous slot on success when
a token is present. Public model metadata (safetensors, tags) is
identical regardless of auth, so this is safe and eliminates ~10
duplicate API calls on first page load.

* Guard anonymous cache priming against gated/private models

Only prime the anonymous cache slot for non-gated, non-private models.
Previously, authenticated modelInfo responses and listing results were
unconditionally copied into the anonymous slot, which could briefly
expose gated/private model metadata after clearing the HF token.

Now checks result.gated and result.private before writing the anon slot.
Public unsloth/ models (the common case) still benefit from the
optimization; gated models like meta-llama/* require a fresh fetch
per auth context.

* Extract primeFromListing helper to deduplicate cache priming logic

The cache priming pattern (prime token slot + conditionally prime anon
slot for non-gated models) was duplicated in three places. Extracted
into a single primeFromListing() function for maintainability.

* Export CachedResult type, add isStale helper, simplify primeFromListing

- Export CachedResult so consumers can use it directly instead of
  the indirect Parameters<typeof ...> pattern.
- Extract isStale(key) helper to deduplicate the cache freshness
  check that was repeated in primeCacheFromListing, cachedModelInfo,
  and the anonymous-slot priming logic.
- Simplify primeFromListing to use CachedResult directly for both
  the data parameter and the gated/private guard, eliminating the
  double cast.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-31 02:21:17 -07:00
Datta Nimmaturi
3b5a49776b
[studio] multi gpu: revert to balanced for inference. (#4698)
* Revert to balanced for inference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused for_inference parameter from get_device_map

Since inference and training both use "balanced" now, the for_inference
flag is dead code. Remove it from the function signature, the call site
in inference.py, and simplify the tests accordingly.

* Remove redundant TestDeviceMapForInference test class

TestGpuAutoSelection already covers the same multi-gpu and single-gpu
device_map assertions. The TestDeviceMapForInference class was left
over from when for_inference had distinct behavior.

* Remove redundant test_get_device_map_multi_gpu_uses_balanced

Its assertions ([0,1] -> balanced, [0] -> sequential) are already
covered by test_get_device_map_uses_explicit_gpu_selection.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-31 01:24:41 -07:00
Daniel Han
fe6609a624
fix(studio): open tour ReadMore links in new tab (#4694)
* fix(studio): open tour ReadMore links in new tab

The quick tour "Read more" links navigate away from Studio instead of
opening in a separate tab. Add target="_blank" and rel="noopener
noreferrer" to the ReadMore component so external doc links open in a
new browser tab.

* fix(studio): only open external ReadMore links in new tab

Apply target="_blank" conditionally based on whether the href starts
with "http", so internal links still navigate in the same tab.

* Tighten external-link detection in ReadMore component

Use regex /^https?:\/\// instead of startsWith("http") so the check
requires the full protocol prefix and does not match non-URL strings
that happen to begin with "http".

* Hoist regex to module scope for ReadMore

Move EXTERNAL_URL_RE to top-level constant to satisfy the biome
useTopLevelRegex lint rule and avoid re-creating the RegExp on
every render.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-30 23:41:14 -07:00
Lee Jackson
308bb948d1
studio: prevent false multimodal warning during model loading (#4704)
* studio: gate multimodal incompatibility warning on settled model capabilities

* Also disable Start button during isCheckingVision fallback

When getModelConfig fails and the fallback checkVisionModel is still
in-flight, isLoadingModelDefaults clears before isCheckingVision does.
Without also gating on isCheckingVision the Start button briefly
re-enables with stale capability flags.

Add isCheckingVision to the disabled condition and show "Loading
model..." text while either flag is active.

* Show correct error message for audio dataset incompatibility

The incompatibility warning always said "switch to a vision model"
even when the actual issue was an audio dataset on a non-audio model.
Now shows an audio-specific message when the mismatch is audio.

* Extract isLoadingModel constant for clarity

Pull the combined model-loading condition into a single constant
reused by the settled check, the disabled prop, and the button label.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-30 23:11:20 -07:00
pre-commit-ci[bot]
66f250a614
[pre-commit.ci] pre-commit autoupdate (#4705)
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.7 → v0.15.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.7...v0.15.8)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-30 21:58:16 -07:00
Roland Tannous
d6d3f59984
fix: replace hard timeout with inactivity timeout for model loading (#4707)
The 180s wall-clock timeout would kill model loads on slow connections
even when the download was actively progressing. Now the worker sends
heartbeat status messages every 30s during loading, and the orchestrator
resets its 300s deadline on each one — so it only times out when the
subprocess goes truly silent.
2026-03-31 07:35:04 +04:00
Roland Tannous
7f353acfd4
fix: skip download progress polling for exported GGUF models (#4709)
* fix: skip download progress polling for exported GGUF models

* fix: revert isLocalGgufDir change — exported GGUFs are file paths, not dirs

* fix: set isDownloaded true for all adapters in LoraModelPicker
2026-03-31 07:21:23 +04:00
Etherll
34272a796f
Fix/bun windows bin detection (#4703)
* fix(studio): detect bun .exe shims in Windows binary check

* Update setup.sh

* add .bunx checking
2026-03-30 21:58:33 +04:00
Daniel Han
6d83ad9a28
fix(studio): avoid UnicodeEncodeError on Windows cp1252 consoles (#4699)
* fix(studio): replace unicode emoji in print() to avoid cp1252 crash on Windows

On Windows the default console encoding is cp1252 which cannot encode
unicode emoji like U+2705 or U+26A0. bare print() calls with these
characters cause a UnicodeEncodeError at runtime.

- run.py: replace emoji with ASCII status prefixes [OK] and [WARNING]
- format_conversion.py: remove duplicate print() that mirrors the
  logger.info() call on the next line, and drop the emoji from the
  log message since loggers handle encoding separately

* fix(studio): apply same emoji/print cleanup to parallel VLM conversion path

The parallel URL-based conversion logic has the same duplicate print()
with emoji that was fixed in the sequential path. Remove the bare
print() and drop the emoji from the logger.info() call.

* Treat install_python_stack.py failure as fatal in setup.ps1

On Linux/Mac, setup.sh runs under set -euo pipefail so a non-zero
exit from install_python_stack.py aborts the installer. On Windows,
setup.ps1 had no exit code check -- if the Python script crashed
(eg from the cp1252 UnicodeEncodeError), the installer silently
continued past the dependency loop and reported success. Studio
would then fail at launch with ModuleNotFoundError for structlog,
fastapi, and other deps that were never installed.

Capture $LASTEXITCODE and exit 1 if the dependency installer fails,
matching the error handling pattern already used for PyTorch install.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-30 06:40:47 -07:00
Daniel Han
a0bca759f3
Fix editable install scanning 6,500+ node_modules dirs (#4697)
* fix: scope packages.find to prevent node_modules namespace scanning

The packages.find section had no include filter, so setuptools'
find_namespace_packages discovered all directories as potential Python
packages -- including the 6,557 directories inside
studio/frontend/node_modules/ after the frontend build step.

This caused the editable install overlay step to run 20,000+ glob
operations across 6,619 "packages", which on fast NVMe takes ~5s but
on slower disks can take 7+ minutes.

Adding an explicit include filter scopes discovery to only the packages
we actually ship (unsloth, unsloth_cli, studio, studio.backend), dropping
from 6,619 to 58 discovered packages and the editable build time from
5.4s to 1.2s.

Also removes the broken kernels/moe exclude (used "/" instead of "."
notation so it never matched) and adds a node_modules exclude as a
safety net.

* fix: use precise node_modules exclude patterns

Use "*.node_modules" and "*.node_modules.*" instead of "*.node_modules*"
to avoid accidentally excluding valid packages that might contain
"node_modules" as a substring in their name.
2026-03-30 02:40:29 -07:00
Datta Nimmaturi
9311df2b29
[Studio] multi gpu finetuning/inference via "balanced_low0/sequential" device_map (#4602)
* [WIP] balanced device map for studio

* gpus as a request parameter

* API for multi GPU stuff

* return multi gpu util in new API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use balanced_low0 instead of balanced

* Use balanced_low0 instead of balanced

* Fix device_map typo, UUID parsing crash, set() filter bug, and broken tests

- balanced_low0 -> balanced_low_0 (transformers/accelerate rejects the old string)
- get_parent_visible_gpu_ids() now handles UUID/MIG CUDA_VISIBLE_DEVICES
  gracefully instead of crashing on int() parse
- _get_backend_visible_gpu_info() set() or None bug: empty set is falsy so
  CUDA_VISIBLE_DEVICES=-1 would disable filtering and report all GPUs
- test_gpu_selection.py: add missing get_visible_gpu_utilization import and
  add required job_id arg to start_training() calls

* Smart GPU determinism using estimates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disallow gpu selection for gguf for now

* cleanup

* Slightly larger baseline

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Treat empty list as auto

* Verbose logging/debug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup and revert unnecessary deletions

* Cleanup excessive logs and guard against disk/cpu offload

* auth for visibility API. cleanup redundant imports. Adjust QLoRA estimate

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support for non cuda gpus

* Fix multi-GPU auto-selection memory accounting

The multi_gpu_factor was applied uniformly to all GPUs including the
first one, which unfairly penalizes single-GPU capacity when
transitioning to multi-GPU. This created a discontinuity where a model
that barely fits 1 GPU would suddenly require 2 GPUs because the first
GPU's free memory was discounted by 20%.

Now the first GPU keeps its full free memory, and only additional GPUs
have an overhead factor (0.85) applied to account for inter-GPU
communication and sharding overhead. This gives more accurate
auto-selection and avoids unnecessary multi-GPU for models that
comfortably fit on one device.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add sandbox tests for multi-GPU selection logic

24 tests covering model size estimation, memory requirements, automatic
GPU selection, device map generation, GPU ID validation, and multi-GPU
overhead accounting. All tests use mocks so they run without GPUs on
Linux, macOS, and Windows.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix reviewer findings: 4bit inference estimate, fallback, GGUF gpu_ids, retry

1. 4-bit inference now uses reduced memory estimate (model_size/3 + buffer)
   instead of the FP16 1.3x multiplier. This prevents over-sharding
   quantized models across unnecessary GPUs.

2. When model size estimation fails, auto_select_gpu_ids now falls back to
   all visible GPUs instead of returning None (which could default to
   single-GPU loading for an unknown-size model).

3. GGUF inference route now treats gpu_ids=[] as auto-selection (same as
   None) instead of rejecting it as an unsupported explicit request.

4. Training retry path for "could not get source code" now preserves the
   gpu_ids parameter so the retry lands on the same GPUs.

5. Updated sandbox tests to cover the new 4-bit inference estimate branch.

* Remove accidentally added unsloth-zoo submodule

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix UUID/MIG visibility and update test expectations

1. nvidia.py: When CUDA_VISIBLE_DEVICES uses UUID/MIG tokens, the
   visibility APIs now return "unresolved" with empty device lists instead
   of exposing all physical GPUs. This prevents the UI from showing GPUs
   that the backend process cannot actually use.

2. test_gpu_selection.py: Updated test expectations to match the new
   multi-GPU overhead accounting (first GPU at full capacity, 0.85x for
   additional GPUs) and 4-bit inference memory estimation formula.
   All 60 tests now pass.

* Add CPU/disk offload guard to audio inference path

The audio model loading branch returned before the common
get_offloaded_device_map_entries() check, so audio models loaded with a
multi-GPU device_map that spilled layers to CPU/disk would be accepted
instead of rejected. Now audio loads also verify no modules are offloaded.

* Improve VRAM requirement estimates

* Replace balanced_low_0 with balanced

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine calculations for slightly easier nums

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjust estimates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use nums instead of obj to avoid seralisation error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden nvidia-smi parsing and fix fallback GPU list

1. nvidia.py: Wrap int() casts for GPU index and memory in try/except
   so MIG slices, N/A values, or unexpected nvidia-smi output skip the
   unparseable row instead of aborting the entire GPU list.

2. nvidia.py: Handle GPU names containing commas by using the last
   field as memory instead of a fixed positional index.

3. hardware.py: fallback_all now uses gpu_candidates (GPUs with verified
   VRAM data) instead of raw devices list, which could include GPUs
   with null VRAM that were excluded from the ranking.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* consolidate raise_if_offload

* Improve MoE support. Guard against nvidia-smi failures

* Improve MoE support. Guard against nvidia-smi failures

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix shared-expert LoRA undercount, torch VRAM fallback, and apply_gpu_ids edge case

1. vram_estimation.py: compute_lora_params now includes shared experts
   (n_shared_experts) alongside routed experts when computing MoE LoRA
   adapter parameters. Previously only n_experts were counted, causing
   the estimator to undercount adapter, optimizer, and gradient memory
   for DeepSeek/GLM-style models with shared experts.

2. hardware.py: _torch_get_per_device_info now uses mem_get_info (which
   reports system-wide VRAM usage) instead of memory_allocated (which
   only reports this process's PyTorch allocations). This prevents
   auto-selection from treating a GPU as mostly free when another
   process is consuming VRAM. Falls back to memory_allocated when
   mem_get_info is unavailable.

3. hardware.py: apply_gpu_ids([]) now returns early instead of setting
   CUDA_VISIBLE_DEVICES="" which would disable CUDA entirely. Empty
   list inherits the parent visibility, same as None.

4. hardware.py: Upgraded fallback_all GPU selection log from debug to
   warning so operators are notified when the model likely will not fit
   in available VRAM.

* Guard nvidia-smi subprocess calls against OSError and TimeoutExpired

get_visible_gpu_utilization and get_backend_visible_gpu_info now catch
OSError (nvidia-smi not found) and TimeoutExpired internally instead
of relying on callers to wrap every invocation. Returns the standard
available=False sentinel on failure so the torch-based fallback in
hardware.py can take over.

* Guard get_primary_gpu_utilization and reset GPU caches between tests

1. nvidia.py: get_primary_gpu_utilization now catches OSError and
   TimeoutExpired internally, matching the pattern already used in
   get_visible_gpu_utilization and get_backend_visible_gpu_info. All
   three nvidia-smi callers are now self-contained.

2. test_gpu_selection.py: Added _GpuCacheResetMixin that resets the
   module-level _physical_gpu_count and _visible_gpu_count caches in
   tearDown. Applied to all test classes that exercise GPU selection,
   device map, or visibility functions. This prevents stale cache
   values from leaking between tests and causing flaky results on
   machines with real GPUs.

* Fix nvidia-smi fallback regression and physical GPU count validation

1. hardware.py: get_gpu_utilization, get_visible_gpu_utilization, and
   get_backend_visible_gpu_info now check result.get("available") before
   returning the nvidia-smi result. When nvidia-smi is unavailable or
   returns no data (e.g., containers without nvidia-smi, UUID/MIG masks),
   the functions fall through to the torch-based fallback instead of
   returning an empty result. This fixes a regression where the internal
   exception handling in nvidia.py prevented the caller's except block
   from triggering the fallback.

2. hardware.py: resolve_requested_gpu_ids now separates negative-ID
   validation from physical upper-bound validation. The physical count
   check is only enforced when it is plausibly a true physical count
   (i.e., higher than the largest parent-visible ID), since
   torch.cuda.device_count() under CUDA_VISIBLE_DEVICES returns the
   visible count, not the physical total. The parent-visible-set check
   remains authoritative in all cases. This prevents valid physical IDs
   like [2, 3] from being rejected as "out of range" when nvidia-smi is
   unavailable and CUDA_VISIBLE_DEVICES="2,3" makes torch report only
   2 devices.

* Fix UUID/MIG torch fallback to enumerate devices by ordinal

When CUDA_VISIBLE_DEVICES uses UUID or MIG identifiers,
get_parent_visible_gpu_ids() returns [] because the tokens are
non-numeric. The torch fallback in get_visible_gpu_utilization() and
get_backend_visible_gpu_info() previously passed that empty list to
_torch_get_per_device_info(), getting nothing back.

Now both functions detect the empty-list case and fall back to
enumerating torch-visible ordinals (0..device_count-1) with
index_kind="relative". This means the UI and auto-selection still
see real device data in Kubernetes, MIG, and Slurm-style UUID
environments where nvidia-smi output cannot be mapped to physical
indices.

Updated test_uuid_parent_visibility to verify the new torch fallback
path returns available=True with relative ordinals.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add type hint for gpu_ids parameter in InferenceOrchestrator.load_model

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-30 02:33:15 -07:00
Michael Han
fbfcbc69f2
Update README.md 2026-03-30 01:34:36 -07:00
Michael Han
d2b8ed8def
Update install.md 2026-03-30 01:33:33 -07:00
Lee Jackson
2f0a5baa87
fix(studio): preserve GGUF context max after apply and refresh (#4691)
Fixes #4670

Separates the GGUF context slider ceiling from the currently active context length so lowering context via Chat Settings no longer locks the slider max to the reduced value.

- Backend: adds `max_context_length` to GGUF load/status responses, computed from the largest VRAM/KV-fit cap across all usable GPU subsets
- Frontend: stores `ggufMaxContextLength` and uses it for Context Length slider/input bounds; hydrates from both `/api/inference/load` and `/api/inference/status`
- Defaults UI ceiling to native context for CPU-only and fallback paths
- Seeds `effective_ctx` and `max_available_ctx` before GPU probing to prevent `UnboundLocalError` on probe failure
- Property fallback uses native `_context_length`, not effective `context_length`
2026-03-30 01:33:16 -07:00
Lee Jackson
5557e1fd27
studio: unify Windows installer/setup logging style, verbosity controls, and startup messaging (#4651)
* refactor(studio): unify setup terminal output style and add verbose setup mode

* studio(windows): align setup.ps1 banner/steps with setup.sh (ANSI, verbose)

* studio(setup): revert nvcc path reordering to match main

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio(setup): restore fail-fast llama.cpp setup flow

* studio(banner): use IPv6 loopback URL when binding :: or ::1

* Fix IPv6 URL bracketing, try_quiet stderr, _step label clamp

- Bracket IPv6 display_host in external_url to produce clickable URLs
- Redirect try_quiet failure log to stderr instead of stdout
- Clamp _step label to column width to prevent negative padding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add sandbox integration tests for PR #4494 UX fixes

Simulation harness (tests/simulate_pr4494.py) creates an isolated uv
venv, copies the real source files into it, and runs subprocess tests
for all three fixes with visual before/after demos and edge cases.

Standalone bash test (tests/test_try_quiet.sh) validates try_quiet
stderr redirect across 8 scenarios including broken-version contrast.

39 integration tests total (14 IPv6 + 15 try_quiet + 10 _step), all
existing 75 unit tests still pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Truncate step() labels in setup.sh to match PS1 and Python

The %-15s printf format pads short labels but does not truncate long
ones.  Change to %-15.15s so labels wider than 15 chars are clipped,
matching the PowerShell .Substring(0,15) and Python label[:15] logic.

* Remove sandbox integration tests from PR

These test files are not part of the styling fix and should not
ship with this PR.

* Show error output on failure instead of suppressing it

- install_python_stack.py: restore _red for patch_package_file
  warnings (was downgraded to _dim)
- setup.ps1: capture winget output and show on failure for CUDA,
  Node, Python, and OpenSSL installs (was piped to Out-Null)
- setup.ps1: always show git pull failure warning, not just in
  verbose mode

* Show winget error output for Git and CMake installs on failure

Same capture-and-print-on-failure pattern already used for
Node, Python, CUDA, and OpenSSL winget installs.

* fix: preserve stderr for _run_quiet error messages in setup.sh

The step() helper writes to stdout, but _run_quiet's error header
was originally sent to stderr (>&2). Without the redirect, callers
that separate stdout/stderr would miss the failure headline while
still seeing the log body on stderr. Add >&2 to both step calls
inside _run_quiet to match main's behavior.

* feat: add --verbose flag to setup and update commands

Wire UNSLOTH_VERBOSE=1 through _run_setup_script() so that
'unsloth studio update --verbose' (and the deprecated 'setup')
passes the flag to setup.sh / setup.ps1 / install_python_stack.py.

* fix(studio): honor verbose logging and keep llama.cpp failures non-blocking

* fix(studio): switch installer to 'studio update' and normalize Windows setup logs

* chore(studio): refine localhost tip and remove skip-base setup nois

* fix(studio): align Windows setup logs with Linux style and improve startup tips

* fix(studio): align Windows setup logs with Linux style

* refactor(windows-installer): align install/setup logs with Linux style and silence auto-launch output

* refactor(windows): align installer/setup output with Linux style and reduce default verbosity

* refactor(windows): match install.ps1 output style/colors to setup and quiet default logs

* fix(studio-banner): update personal-computer localhost tip

* fix(setup.sh): restore verbose llama.cpp build output while keeping default quiet mode

* fix(install.sh): align installer logging with setup style and restore POSIX-safe color output

* fix(install.sh): preserve installer reliability and launch visibility

Export verbose mode for child setup processes, harden install command handling under set -e, and keep first-run studio launch non-silent so users can always see URL and port fallback output.

* fix(windows installer): keep exit semantics and degrade status accurate

Use quiet command redirection that preserves native exit codes, keep startup output visible on first launch, and report limited install status when llama.cpp is unavailable.

* fix(setup.sh): improve log clarity and enforce GGUF degraded signaling

Restore clean default setup output, add verbose-only diagnostics, fail fast on Colab dependency install errors, and return non-zero when GGUF prerequisites or llama.cpp artifacts are unavailable.

* fix(installer): harden bash preflight and PowerShell GPU checks

Fail fast when bash is unavailable before invoking setup.sh, and replace remaining nvidia-smi pipeline checks with stream redirection patterns that preserve reliable native exit-code handling.

* fix(windows): keep verbose output visible while preserving exit codes

Ensure PowerShell wrapper helpers in install/update stream native command output to host without returning it as function output, so npm logs no longer corrupt exit-code checks in verbose mode.

* fix(windows): avoid sticky UNSLOTH_VERBOSE and gate studio update verbosity

* Fix degraded llama.cpp exit code, PS verbose stderr, banner URLs, npm verbose

- setup.sh: Do not exit non-zero when llama.cpp is unavailable; the footer
  already reports the limitation, and install.sh runs under set -e so a
  non-zero exit aborts the entire install including PATH/shortcuts/launch.
- setup.ps1: Remove $? check in Invoke-SetupCommand verbose path; PS 5.1
  sets $? = $false when native commands write to stderr even with exit 0.
  Merge stderr into stdout with 2>&1 and rely solely on $LASTEXITCODE.
- startup_banner.py: Show the actual bound address when Studio is bound to
  a non-loopback interface instead of always showing 127.0.0.1/localhost.
- setup.sh: Use run_quiet_no_exit instead of run_quiet_no_exit_always for
  npm install steps so --verbose correctly surfaces npm output.

* Fix install.ps1 verbose stderr, propagate UNSLOTH_VERBOSE, fix git clone verbose

- install.ps1: Apply same Invoke-InstallCommand fix as setup.ps1 -- merge
  stderr into stdout with 2>&1 and drop the $? check that misclassifies
  successful native commands on PS 5.1.
- install.ps1 + setup.ps1: Export UNSLOTH_VERBOSE=1 to the process env
  when --verbose is passed so child processes like install_python_stack.py
  also run in verbose mode.
- setup.sh: Use run_quiet_no_exit for git clone llama.cpp so --verbose
  correctly surfaces clone diagnostics during source-build fallback.

* Surface prebuilt llama.cpp output in verbose mode, remove dead code, fix banner

- setup.sh: Use tee in verbose mode for prebuilt llama.cpp installer so
  users can see download/validation progress while still capturing the log
  for structured error reporting on failure.
- setup.ps1: Same fix for Windows -- use Tee-Object in verbose mode.
- setup.sh: Remove run_quiet_no_exit_always() which has no remaining callers.
- startup_banner.py: Avoid printing the same URL twice when Studio is
  bound to a specific non-loopback address that matches the display host.

* Fix run_install_cmd exit code after failed if-statement

The previous pattern 'if "$@"; then return 0; fi; _rc=$?' always captured
$? = 0 because $? reflects the if-statement result, not the command's exit
code. Switch to '"$@" && return 0; _rc=$?' which preserves the actual
command exit code on failure. Applies to both verbose and quiet branches.

* Fix _run_quiet exit code, double uv install, missing --local flag

- setup.sh: Fix _run_quiet verbose path that always captured exit code 0
  due to $? resetting after if-then-fi with no else. Switch to the same
  '"$@" && return 0; exit_code=$?' pattern used in install.sh.
- setup.sh: Consolidate the two uv install branches (verbose + quiet)
  into a single attempt with conditional output. Previously, when verbose
  mode was on and the install failed, a second silent attempt was made.
- install.ps1: Pass --local flag to 'unsloth studio update' when
  $StudioLocalInstall is true. Without this, studio.py's update() command
  overwrites STUDIO_LOCAL_INSTALL to "0", which could cause issues if
  setup.ps1 or install_python_stack.py later checks that variable.

* Revert SKIP_STUDIO_BASE change for --no-torch, restore install banners

- Revert SKIP_STUDIO_BASE from 0 to 1 for --no-torch. install.sh already
  installs unsloth+unsloth-zoo and no-torch-runtime.txt before calling
  setup.sh, so letting install_python_stack.py redo it was redundant and
  slowed down --no-torch installs for no benefit.
- Restore the "Unsloth Studio installed!" success banner and "starting
  Unsloth Studio..." launch message so users get clear install completion
  feedback before the server starts.

* Make llama.cpp build failure a hard error with proper cleanup

- setup.sh: Restore exit 1 when _LLAMA_CPP_DEGRADED is true. GGUF
  inference requires a working llama.cpp build, so this should be a
  hard failure, not a silent degradation.
- install.sh: Catch setup.sh's non-zero exit with '|| _SETUP_EXIT=$?'
  instead of letting set -e abort immediately. This ensures PATH setup,
  symlinks, and shortcuts still get created so the user can fix the
  build deps and retry with 'unsloth studio update'. After post-install
  steps, propagate the failure with a clear error message.

* Revert install.ps1 to 'studio setup' to preserve SKIP_STUDIO_BASE

'studio update' pops SKIP_STUDIO_BASE from the environment, which
defeats the fast-path version check added in PR #4667. When called
from install.ps1 (which already installed packages), SKIP_STUDIO_BASE=1
must survive into setup.ps1 so it skips the redundant PyPI check and
package reinstallation. 'studio setup' does not modify env vars.

* Remove deprecation message from 'studio setup' command

install.ps1 uses 'studio setup' (not 'studio update') to preserve
SKIP_STUDIO_BASE. The deprecation message was confusing during first
install since the user never typed the command.

* Fix stale env vars, scope degraded exit, generic error message for PR #4651

- install.ps1: Always set STUDIO_LOCAL_INSTALL and clear STUDIO_LOCAL_REPO
  when not using --local, to prevent stale values from a previous --local
  run in the same PowerShell session. Fix log messages to say 'setup' not
  'update' since we call 'studio setup'.
- setup.sh: Only exit non-zero for degraded llama.cpp when called from the
  installer (SKIP_STUDIO_BASE=1). Direct 'unsloth studio update' keeps
  degraded installs successful since Studio is still usable for non-GGUF
  workflows and the footer already reports the limitation.
- install.sh: Make the setup failure error message generic instead of
  GGUF-specific, so unrelated failures (npm, Python deps) do not show
  misleading cmake/git recovery advice.

* Show captured output on failure in quiet mode for PR #4651

Both Invoke-InstallCommand (install.ps1) and Invoke-SetupCommand
(setup.ps1) now capture command output in quiet mode and display it
in red when the command fails. This matches the behavior of
run_install_cmd in install.sh where failure output is surfaced even
in quiet mode, making cross-platform error debugging consistent.

* Match degraded llama.cpp exit on Windows, fix --local recovery hint for PR #4651

- setup.ps1: Exit non-zero for degraded llama.cpp when called from
  install.ps1 (SKIP_STUDIO_BASE=1), matching setup.sh behavior. Direct
  'unsloth studio update' keeps degraded installs successful.
- install.sh: Show 'unsloth studio update --local' in the recovery
  message when the install was run with --local, so users retry with
  the correct flag instead of losing local checkout context.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-30 00:53:23 -07:00
Roland Tannous
5bbfabb151
fix: [Studio] setup.ps1 update-flow for windows (#4667)
* fix: add PyPI version check to setup.ps1 for fast update path

Port the update-flow logic from setup.sh to setup.ps1 so that
`unsloth studio update` on Windows skips Python dependency reinstall
when the installed version already matches PyPI latest.

* fix: clear SKIP_STUDIO_BASE in update command

install.ps1 sets SKIP_STUDIO_BASE=1 which persists in the PowerShell
session. If the user runs `unsloth studio update` in the same terminal,
the env var causes the version check to be skipped. Clear it explicitly
in the update command.

* fix: harden version check and clear stale env vars in update flow

- Normalize $InstalledVer with Out-String + Trim() to avoid array/whitespace
  comparison issues in PowerShell 5.1 (python output can be captured as
  string[] instead of scalar string)
- Move Fast-Install --upgrade pip inside if (-not $SkipPythonDeps) so the
  fast path avoids unnecessary network round-trips
- Clear STUDIO_LOCAL_REPO when --local is not passed to prevent a previous
  --local session from leaking into a plain update

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-29 21:14:36 -07:00
Roland Tannous
a6c1f893fc
Fix blank page on Windows due to broken .js MIME type (#4674)
* Fix blank page on Windows due to broken .js MIME type in registry

* Update studio/backend/main.py

adding defensive suggestion by gemini where we make the mimetypes specific to windows platforms

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-28 22:26:49 +04:00
Lee Jackson
5d2dca801c
studio: add HF/local model selection UI for GGUF export (#4365)
* feat(studio): add HF/local model selection UI for GGUF export

* fix(studio):fix selector ring clipping

* fix(studio): export page trust_remote_code control and label styling

* fix(studio): accept hf_token in load_checkpoint orchestrator method

The route was passing hf_token to load_checkpoint() but the method
didn't accept it, causing a TypeError on every /api/export/load-checkpoint
request.

* fix(studio): clear HF model selection when input is edited

Previously selectedSourceModel was only cleared when the input became
empty, so editing to a different repo ID after selecting a model would
silently keep the old selection.

---------

Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
2026-03-28 22:18:25 +04:00
Daniel Han
362ad3606b Update _utils.py 2026-03-27 08:42:00 -07:00
Daniel Han
82d14b44d3
fix: preserve Windows drive-letter paths on native Windows (#4665)
normalize_path() unconditionally converted Windows paths like
C:\Users\... to WSL format /mnt/c/Users/..., which breaks path
resolution on native Windows. This caused LM Studio GGUF models
to fail detection (detect_gguf_model returned None for the invalid
path), falling through to the Unsloth import path which requires
a GPU.

Now only performs the /mnt/ mapping when actually running under WSL.
On native Windows, drive letters are preserved and backslashes are
normalized to forward slashes.
2026-03-27 08:19:41 -07:00
Daniel Han
9477e7c43f
Bump minimum unsloth version to 2026.3.16 in install scripts (#4663)
Update install.sh and install.ps1 to require unsloth>=2026.3.16,
matching the latest PyPI release.
2026-03-27 07:47:08 -07:00
Daniel Han
df3b18c579 Update _utils.py 2026-03-27 07:24:39 -07:00
Daniel Han
844a816ed0 Update pyproject.toml 2026-03-27 07:14:03 -07:00
Roland Tannous
562e54fc6e
Fix HF cache default and show LM Studio models in chat/inference (#4653)
* fix: default HF cache to standard platform path instead of legacy Unsloth cache

* feat: show LM Studio and local models in chat Fine-tuned tab

* feat: show LM Studio models in Hub models tab

* fix: fetch local models after auth refresh completes

* Revert "fix: fetch local models after auth refresh completes"

This reverts commit cfd61f0ac7.

* fix: increase llama-server health check timeout to 600s for large models

* feat: expandable GGUF variant picker for LM Studio local models

* fix: show GGUF variant label for locally loaded LM Studio models

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: show publisher name in LM Studio model labels

* fix: set model_id for loose GGUF files in LM Studio publisher dirs

* fix: show publisher prefix in Fine-tuned tab LM Studio models

* fix: only use model_id for lmstudio source models

* fix: only show LM Studio models in Hub tab on Mac/chat-only mode

* fix: respect XDG_CACHE_HOME, handle Windows paths in isLocalPath, refresh LM Studio on remount

- _setup_cache_env now reads XDG_CACHE_HOME (falls back to ~/.cache)
  instead of hard-coding ~/.cache/huggingface. This follows the standard
  HF cache resolution chain and respects distro/container overrides.

- isLocalPath in GgufVariantExpander uses a regex that covers Windows
  drive letters (C:\, D:/), UNC paths (\\server\share), relative paths
  (./, ../), and tilde (~/) -- not just startsWith("/").

- HubModelPicker.useEffect now calls listLocalModels() before the
  alreadyCached early-return gate so LM Studio models are always
  refreshed on remount. Also seeds useState from _lmStudioCache for
  instant display on re-open.

* fix: add comment explaining isLocalPath regex for Windows/cross-platform paths

* fix: prioritize unsloth publisher in LM Studio model list

* fix: scope unsloth-first sort to LM Studio models on all platforms

* fix: add missing _lmStudioCache module-level declaration

* fix: prioritize unsloth publisher before timestamp sort in LM Studio group

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-27 06:59:27 -07:00
Wasim Yousef Said
73969a1e4f
fix: disable OCR in pymupdf4llm PDF extraction (#4659) 2026-03-27 06:53:33 -07:00
Daniel Han
c4e34c88c8
Fall back to parsing model name when HF API has no param count (#4656)
Some models like unsloth/Qwen3-0.6B have no safetensors metadata
on Hugging Face, so the training model selector showed no parameter
size badge. The chat model picker already had extractParamLabel()
as a fallback that parses sizes like "0.6B" from the model name.

Add the same fallback to the training model selector and the
onboarding model selection step.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-27 05:57:49 -07:00
Wasim Yousef Said
4ab7fb1f7b
fix: replace navbar shutdown text button with icon-only button (#4655) 2026-03-27 05:44:59 -07:00
Daniel Han
e36f72c685
Detect always-on reasoning models and show Think button as locked-on (#4654)
* Detect always-on reasoning models and show Think button as locked-on

Models with hardcoded <think>/<think> tags or reasoning_content in
their chat template (e.g. distilled reasoning models) always produce
thinking output regardless of any toggle. Previously these models
were not detected as reasoning-capable at all, so the Think button
was grayed out even though the model was actively reasoning.

Backend:
- Detect <think>/<think> and reasoning_content in GGUF chat templates
  as a fallback when enable_thinking is not present
- Add reasoning_always_on flag to LoadResponse and InferenceStatusResponse
- Pass the flag through all GGUF load and status response paths

Frontend:
- Add reasoningAlwaysOn to the chat runtime store and API types
- When reasoning_always_on is true, show the Think button as lit
  (active) but not clickable, with a tooltip explaining the model
  always uses thinking
- Force reasoningEnabled=true when the model always reasons

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use pointer-events-none instead of disabled for always-on Think button

The HTML disabled attribute was not fully blocking clicks on the Think
button for always-on reasoning models. Switch to pointer-events-none
CSS class which prevents all mouse interaction at the CSS level.

* Use a static span instead of disabled button for always-on Think

Replace the button element with a plain span when reasoning is
always on. This makes it physically impossible to toggle since
there is no clickable element at all, avoiding any CSS or
disabled-attribute edge cases.

* Simplify always-on Think button to stay lit and remain toggleable

Keep the Think button as a normal toggleable button but ensure it
shows as lit when reasoning_always_on is true. The model always
reasons regardless of the toggle state so there is no need to
block interaction.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 05:42:26 -07:00
Daniel Han
eacaf6827c
fix: no-torch install deps without pulling torch transitively (#4650)
Use --no-deps for ALL packages (unsloth, unsloth-zoo, and runtime deps)
since the current PyPI metadata for unsloth still declares torch as a
hard dependency. Runtime deps (typer, pydantic, safetensors,
transformers, etc.) are installed from no-torch-runtime.txt with
--no-deps to prevent transitive torch resolution from accelerate, peft,
trl, and sentence-transformers.

no-torch-runtime.txt now includes unsloth's own direct deps (typer,
pydantic, pyyaml, nest-asyncio) since --no-deps skips those too.

install.sh installs no-torch-runtime.txt directly (via helper function
_find_no_torch_runtime). install.ps1 does the same via
Find-NoTorchRuntimeFile. SKIP_STUDIO_BASE stays at 1 to avoid setup.sh
fast-path issues.

install_python_stack.py NO_TORCH branch does the same for unsloth
studio update, using package_name instead of hardcoded "unsloth".
2026-03-27 05:19:26 -07:00
Daniel Han
a7c43bc46d
Fix inference failing for transformers 5.x models (trust_remote_code) (#4652)
* Fix inference failing for transformers 5.x models (trust_remote_code)

The training worker in core/training/worker.py auto-enables
trust_remote_code for unsloth/* models that need transformers 5.x
(e.g. NVIDIA-Nemotron-3-Nano-4B). The inference worker did not have
the same logic, so loading these models for chat would fail with
"No config file found" while training worked fine.

Add the same auto-detection to the inference worker so
trust_remote_code is set automatically when needed.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 04:51:30 -07:00
Wasim Yousef Said
887b8cb1c2
fix: add auth + UX improvements to shutdown button (#4642)
* Studio shutdown button

* fix: add auth to shutdown endpoint and improve UX

- Add JWT auth (Depends(get_current_subject)) to POST /api/shutdown
- Use authFetch instead of bare fetch in shutdown dialog
- Only show beforeunload prompt when training is running
- Remove Ctrl+W/Cmd+W interception (browsers don't allow it)
- Store shutdown task on app.state to prevent GC

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-27 04:36:08 -07:00
Daniel Han
1fb9fe3304
Fix orphan server cleanup killing user's own llama-server (#4622)
* fix: only kill studio-managed llama-server processes, not user's own servers

_kill_orphaned_servers() checked for "unsloth" anywhere in the process
cmdline, which matched the user's own llama-server when serving models
from unsloth/ HF repos (the model path in -m contains "unsloth"). This
caused the user's server to get SIGKILLed on Studio startup, destroying
their prompt cache and forcing full model re-loads.

Narrow the check to only match processes whose binary path lives under
~/.unsloth/llama.cpp/ (the Studio install directory).

* Address review: cover env var paths, move Path.home() inside try block

- Also check LLAMA_SERVER_PATH and UNSLOTH_LLAMA_CPP_PATH so orphans
  from custom install locations are still cleaned up.
- Move studio_dirs construction inside the try/except so a Path.home()
  failure (containers without HOME) does not crash the constructor.

* Address reviewer feedback: proper path ancestry, /proc/pid/exe, legacy paths

Changes based on 10-reviewer consensus:

- Use Path.is_relative_to() instead of substring matching to prevent
  false positives on sibling paths like ~/.unsloth/llama.cpp-backup/.
- Use /proc/<pid>/exe (symlink to real binary) instead of parsing the
  first cmdline token, which breaks on paths with spaces. Falls back
  to cmdline parsing on non-Linux or when /proc is unavailable.
- Add legacy in-tree install paths (project_root/llama.cpp/ and
  project_root/bin/) so orphans from older setup.sh are still cleaned.
- Treat LLAMA_SERVER_PATH as an exact binary match rather than widening
  it to its parent directory, which could match unrelated servers in
  shared locations like /usr/local/bin/.
- Keep everything inside the try/except so Path.home() failures in
  containers do not crash the constructor.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: add Linux platform guard and log cleanup errors

- Guard pgrep fallback with sys.platform check so it does not crash
  on Windows/macOS when psutil is unavailable.
- Replace silent except-pass with logger.warning for observability.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 04:33:04 -07:00
Daniel Han
b1c3a1e857
fix: replace [huggingfacenotorch] with no-torch-runtime.txt requirements (#4649)
The [huggingfacenotorch] extras only exist in pyproject.toml but are
NOT published on PyPI, so uv pip install "unsloth[huggingfacenotorch]"
fails on fresh installs from the registry.

Fix: add studio/backend/requirements/no-torch-runtime.txt with the
runtime deps (safetensors, transformers, datasets, accelerate, etc.)
that mirror [huggingfacenotorch] from pyproject.toml. In no-torch mode:
1. install.sh/ps1 install unsloth + unsloth-zoo with --no-deps
2. SKIP_STUDIO_BASE=0 so install_python_stack.py's NO_TORCH branch runs
3. install_python_stack.py installs no-torch-runtime.txt
2026-03-27 03:58:51 -07:00
Daniel Han
9d68621614
Streaming tool detection: guard late tool_calls, filter incomplete fragments (#4648)
* Guard against late tool_calls after visible content, filter incomplete fragments

1. If visible content was already emitted (_last_emitted is non-empty)
   when delta.tool_calls arrives, ignore the tool_calls instead of
   reclassifying the turn as a tool call. llama-server never
   interleaves content and tool_calls (they are mutually exclusive),
   but this guard is defensive for other OpenAI-compatible backends.

2. Filter out incomplete structured tool_calls fragments before
   execution. Entries with empty function.name (from truncation by
   max_tokens, disconnect, or interruption) are skipped instead of
   being passed to execute_tool().

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 03:40:14 -07:00
Wasim Yousef Said
5c7c3883cb
feat: update app icons to rounded logo (#4640)
Replace favicon.png, unsloth-gem.png, and unsloth.ico with rounded.png.
Update install.sh to source rounded.png for Linux/macOS shortcuts.
2026-03-27 03:18:20 -07:00
Daniel Han
79d9bf0c9a
Fix GGUF GPU fit check to account for KV cache VRAM (#4623)
* fix: account for KV cache in GGUF GPU fit check and auto-cap context length

The GPU fit check only compared GGUF file size against free VRAM,
ignoring KV cache memory. Models with large native context lengths
(e.g. Qwen3.5-9B at 262k) would pass the fit check since the GGUF
is only 5.6 GB, but the KV cache at 262k context needs ~40 GB at
f16. This caused llama-server to silently fall back to CPU inference.

Changes:
- Parse block_count, head_count_kv, head_count, and embedding_length
  from GGUF metadata alongside context_length
- Add KV cache VRAM estimation based on architecture params and the
  selected cache quantization type (f16, q8_0, q4_0, etc.)
- Auto-reduce context length to the maximum that fits in available
  GPU VRAM when the native context would exceed it
- Include estimated KV cache size in the _select_gpus total so the
  fit decision reflects actual runtime memory, not just file size

For the reported scenario (Qwen3.5-9B on RTX 3090 with 22415 MiB
free), context is auto-reduced from 262144 to ~63k with f16 KV cache,
keeping the model fully on GPU. With q4_0 KV cache quantization the
context can reach ~226k.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: resolve 6 bugs in KV cache VRAM estimation and add test harness

- Fix q8_0 BPE constant: 1.125 -> 34/32 (1.0625) to match llama.cpp block size
- Fix _fit_context_to_vram returning min_ctx when weights exceed budget
  (should return requested_ctx unchanged, let --fit handle it)
- Fix binary search inflating below-2048 requests (lo=min_ctx=2048 > hi)
- Fix n_ctx=0 regressing to 4096 when metadata unavailable (preserve sentinel)
- Fix multi-GPU auto-cap using single-GPU budget instead of aggregate
- Fix _context_length being overwritten with capped effective value

Add tests/test_gguf_kv_vram.py: 43 cross-platform pytest tests covering
pure logic, integration (monkeypatched load_model), and real GGUF parsing.
Runs in an isolated uv venv with only pytest -- no GPU/torch/structlog needed.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: complete _effective_context_length lifecycle

- Initialize _effective_context_length in __init__ (prevents AttributeError)
- Reset _effective_context_length in unload_model (prevents stale values)
- Update context_length property to return effective (capped) value for
  the UI/API, falling back to native _context_length if not set

* fix: multi-GPU selection tries smallest subset first

The previous approach summed all GPUs' memory to cap context, then
selected GPUs afterward. This was overly optimistic for heterogeneous
setups (e.g., 48 GiB + 4 GiB): the context was inflated by the tiny
GPU's contribution, then both GPUs were dragged in.

Now we try GPU subsets from smallest (1 GPU) to largest, capping
context for each. We pick the smallest subset where the model+KV
fits. This prefers single-GPU when possible (simpler, no tensor
split overhead) and avoids pulling in GPUs that barely help.

Add tests: test_multi_gpu_prefers_fewer_gpus,
test_multi_gpu_heterogeneous.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: prefer fewer GPUs over higher context in GPU selection

Multi-GPU inference is slower due to tensor-split overhead, so we
should prefer fewer GPUs with reduced context over more GPUs with
full context. Now the loop stops at the first GPU subset where the
model fits, rather than continuing to find subsets that allow higher
context. Only if the model can't fit on N GPUs do we try N+1.

This preserves the original behavior: use multi-GPU only when the
model doesn't fit on a single GPU.

* fix: make _kill_orphaned_servers cross-platform via psutil

Replace pgrep + os.kill(SIGKILL) with psutil.process_iter() and
proc.kill(), which work on Linux, macOS, and Windows. Build an
allowlist of install roots matching _find_llama_server_binary so
only studio-managed servers are killed.

* fix: skip KV estimation loop when effective context is unknown

When n_ctx=0 and GGUF metadata lacks context_length, effective_ctx
stays 0. _estimate_kv_cache_bytes(0) returns 0, so a GPU could be
selected with no KV headroom. Guard the loop with effective_ctx > 0
to fall back to file-size-only GPU selection in this case.

* chore: temporarily remove test harness (will add back separately)

* refactor: deduplicate UINT32/UINT64 handling in GGUF parser

Replace duplicated if/elif chains for vtype 4 and 10 with a single
block using setattr. No behavioral change.

* fix: honor explicit n_ctx by using multi-GPU before capping

When the user explicitly sets n_ctx, try to fit the full requested
context using _select_gpus (which adds GPUs as needed). Only cap
context if it doesn't fit on any GPU combination.

When n_ctx=0 (auto/native context), keep the existing behavior:
prefer fewer GPUs with reduced context, since multi-GPU is slower
and the user didn't ask for a specific context length.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: context_length property returns native value for frontend slider

The frontend uses context_length as the slider max. Returning the
capped effective value prevented users from requesting higher context
on reload (e.g., after switching to q4_0 KV cache). Revert to
returning the native GGUF metadata value -- the backend auto-caps
at load time regardless.

* revert: context_length returns effective (capped) value

The UI slider should show what the server is actually running at,
not the theoretical maximum. Revert to returning the effective
context length.

* fix: raise minimum context floor from 2048 to 4096

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 03:14:42 -07:00
Daniel Han
e318da21a7
Fix ~1.2s TTFT penalty when tools are enabled in Studio (#4639)
* Fix ~1.2s TTFT penalty when tools are enabled in Studio

When users enable web search, Python execution, or terminal tools,
every message gets a ~1.2s delay before any text appears -- even when
the model does not call any tool. This happens because
generate_chat_completion_with_tools() does a non-streaming detection
pass (stream: False) first, waits for the complete response, then
checks for tool calls. For the ~90% of messages that don't trigger a
tool call, this blocking wait is entirely wasted.

Root cause: the detection pass payload uses stream: False, forcing
llama-server to generate the entire response before returning any
tokens.

Fix: replace the non-streaming detection pass with a streaming pass
(stream: True) and a speculative buffer state machine that detects
tool signals in the first 1-2 SSE chunks:

- BUFFERING: accumulate content tokens, check first chars for tool
  signal prefixes (<tool_call>, <function=)
- STREAMING: no tool detected, yield tokens to caller immediately
- DRAINING: tool signal found, silently accumulate rest of stream

Three detection paths:
1. Structured delta.tool_calls -- detected instantly, transition to
   DRAINING, accumulate fragments, assemble at stream end.
2. XML tool markup in content -- buffer holds up to 32 chars checking
   for <tool_call> or <function= prefix, then transitions to DRAINING.
3. No tool signal -- first non-whitespace, non-XML char triggers
   immediate transition to STREAMING (fast path, ~90% of requests).

Safety net: after any stream ends in STREAMING state, check accumulated
content for XML tool signals. Handles rare "content before tool call"
edge case.

Additional supporting changes:
- Add headers parameter to _stream_with_retry for auth forwarding
- Share _strip_tool_markup and regex patterns between the detection
  pass and the final streaming pass (removes duplication)
- Remove the iteration==0 non-streaming content shortcut (no longer
  needed since all iterations stream directly)
- Keep the final streaming pass as fallback for max_tool_iterations
  exhaustion

Benchmarked on Qwen3.5-4B Q4_K_XL:
- No tools:              TTFT ~112ms (unchanged)
- Tools enabled, no call: TTFT ~112ms (was ~1207ms)
- Decode TPS:            226 (unchanged in all cases)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add unit tests for streaming tool detection state machine

16 tests covering every tool call parsing path:
- Plain text (no tool call) streaming
- Structured delta.tool_calls detection and fragment assembly
- XML <tool_call>JSON</tool_call> detection via buffer
- XML <function=name> tag detection via buffer
- Whitespace before tool XML
- Safety net (content then tool XML)
- Parallel multi-tool calls
- Reasoning token bypass (thinking models)
- Reasoning then tool call
- Empty response handling
- Buffer prefix timeout (HTML not mistaken for tool)
- Non-XML first char instant streaming
- False positive rejection (<tool_tip> vs <tool_call>)
- Arguments split across multiple chunks
- auto_heal_tool_calls=False respects the flag
- Metrics accumulation across tool iterations

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix reasoning-only BUFFERING, pre-tool content emission, and code duplication

Addresses review feedback on the streaming tool detection:

1. Reasoning tokens are no longer yielded during BUFFERING/DRAINING
   states. The consumer in routes/inference.py tracks prev_text across
   tool iterations without resetting it, so yielding reasoning during
   a detection pass that resolves to a tool call would corrupt the
   delta computation for subsequent iterations. Reasoning is now
   silently accumulated during detection (matching the old non-streaming
   behavior) and flushed together with content when the buffer resolves
   to STREAMING.

2. Handle reasoning-only responses in the BUFFERING resolver. When a
   thinking model emits only reasoning_content with no content tokens,
   the stream ends while still in BUFFERING state. The resolver now
   detects this case and yields reasoning as plain text (without
   <think> wrapper), matching the final streaming pass behavior for
   models like Qwen3 in always-think mode.

3. Replace duplicated re.sub calls for stripping tool markup with
   the existing _strip_tool_markup(content_text, final=True) helper,
   removing ~40 lines of redundant regex code.

4. Update tests: adjust reasoning test expectations to match the new
   behavior (reasoning batched with content, not streamed individually
   during BUFFERING). Add test_reasoning_only_no_content for the
   reasoning-only edge case. 17/17 tests pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address remaining reviewer findings: late tool_call IDs and XML speculation

1. Late-arriving tool_calls.id: when a provider sends the real ID on a
   later delta chunk (after the initial one with index and function
   name), the accumulator now updates the ID instead of keeping the
   synthetic "call_{idx}" placeholder. (P2, 2/10 reviewers)

2. XML speculation respects auto_heal_tool_calls: when auto_heal is
   explicitly disabled, _TOOL_XML_SIGNALS is empty so the BUFFERING
   state never speculatively holds content for XML prefix detection.
   Content starting with literal "<tool_call>" or "<function=" text
   flows straight through without delay. (P2, 1/10 reviewers)

Skipped: finish_reason="tool_calls" without delta.tool_calls fallback
(P1, 1/10 reviewers). llama-server always sends delta.tool_calls
fragments in streaming mode. A non-streaming fallback for this edge
case would add complexity for a scenario that does not occur in
practice with the supported backend.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Check request.is_disconnected() every 20 tokens instead of every token

The disconnect check is an async round-trip that adds overhead on every
loop iteration. Since the cancel watcher in llama_cpp.py already
handles connection teardown (closes the streaming response on cancel),
this route-layer check is a secondary safety net that does not need to
run on every single token.

Check every 20 tokens across all 4 streaming paths:
- gguf_tool_stream (tool-enabled GGUF)
- gguf_stream_chunks (standard GGUF)
- audio_input_generate (audio/whisper input)
- generic backend stream (non-GGUF fallback)

* Fix safety net, DRAINING metadata, and test import path

1. Safety net no longer retroactively executes tools after visible
   content was already emitted to the user. Once _last_emitted is
   non-empty, the stream is committed to normal content mode.
   Retroactive tool execution after visible output would violate the
   streaming contract and corrupt the route-layer cumulative delta
   tracker (prev_text). The tool XML is still stripped by
   _strip_tool_markup so the user sees clean content.

2. DRAINING false-positive path now merges accumulated metrics from
   prior tool iterations instead of dropping them. Uses the same
   merge formula as the STREAMING path.

3. Test import path fixed to use repo root instead of hardcoded
   sibling directory. Works in clean checkouts and CI.

4. Renamed test_content_then_tool_xml_safety_net to
   test_content_then_tool_xml_no_retroactive_execution to reflect
   the corrected behavior.

17/17 tests pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Redact --api-key value from llama-server startup log

When UNSLOTH_DIRECT_STREAM=1, the generated bearer token was logged
verbatim in the startup command. Replace the secret with <redacted>
before logging.

* Remove test file temporarily

* Revert disconnect throttle, reset prev_text on tool_start, restore XML safety net

Addresses all P1 findings from reviewer round 3 (10 reviewers):

1. Revert disconnect check to every iteration (was every 20th).
   All 10 reviewers flagged this as a correctness regression for
   short streams and sparse tool event loops. The cancel watcher in
   llama_cpp.py is the primary mechanism but the route-layer check
   must remain per-iteration for completeness. [10/10]

2. Reset prev_text on tool_start in gguf_tool_stream. When a tool
   cycle begins after visible content was already streamed, the
   route-layer cumulative delta tracker (prev_text) must be reset
   so the post-tool synthesis response is not truncated or dropped.
   [9/10]

3. Remove the _last_emitted gate from the XML safety net. The gate
   was added to prevent retroactive tool execution after visible
   content, but with prev_text now reset on tool_start (#2), the
   root cause is fixed and the safety net can correctly handle
   content-then-tool-XML responses (matching pre-PR behavior).
   [8/10]

* Use None instead of {} for empty auth headers in TTS methods

* Include accumulated metrics in STREAMING metadata check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 03:13:38 -07:00
Lee Jackson
0233fe7f9c
studio: setup log styling (#4494)
* refactor(studio): unify setup terminal output style and add verbose setup mode

* studio(windows): align setup.ps1 banner/steps with setup.sh (ANSI, verbose)

* studio(setup): revert nvcc path reordering to match main

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio(setup): restore fail-fast llama.cpp setup flow

* studio(banner): use IPv6 loopback URL when binding :: or ::1

* Fix IPv6 URL bracketing, try_quiet stderr, _step label clamp

- Bracket IPv6 display_host in external_url to produce clickable URLs
- Redirect try_quiet failure log to stderr instead of stdout
- Clamp _step label to column width to prevent negative padding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add sandbox integration tests for PR #4494 UX fixes

Simulation harness (tests/simulate_pr4494.py) creates an isolated uv
venv, copies the real source files into it, and runs subprocess tests
for all three fixes with visual before/after demos and edge cases.

Standalone bash test (tests/test_try_quiet.sh) validates try_quiet
stderr redirect across 8 scenarios including broken-version contrast.

39 integration tests total (14 IPv6 + 15 try_quiet + 10 _step), all
existing 75 unit tests still pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Truncate step() labels in setup.sh to match PS1 and Python

The %-15s printf format pads short labels but does not truncate long
ones.  Change to %-15.15s so labels wider than 15 chars are clipped,
matching the PowerShell .Substring(0,15) and Python label[:15] logic.

* Remove sandbox integration tests from PR

These test files are not part of the styling fix and should not
ship with this PR.

* Show error output on failure instead of suppressing it

- install_python_stack.py: restore _red for patch_package_file
  warnings (was downgraded to _dim)
- setup.ps1: capture winget output and show on failure for CUDA,
  Node, Python, and OpenSSL installs (was piped to Out-Null)
- setup.ps1: always show git pull failure warning, not just in
  verbose mode

* Show winget error output for Git and CMake installs on failure

Same capture-and-print-on-failure pattern already used for
Node, Python, CUDA, and OpenSSL winget installs.

* fix: preserve stderr for _run_quiet error messages in setup.sh

The step() helper writes to stdout, but _run_quiet's error header
was originally sent to stderr (>&2). Without the redirect, callers
that separate stdout/stderr would miss the failure headline while
still seeing the log body on stderr. Add >&2 to both step calls
inside _run_quiet to match main's behavior.

* feat: add --verbose flag to setup and update commands

Wire UNSLOTH_VERBOSE=1 through _run_setup_script() so that
'unsloth studio update --verbose' (and the deprecated 'setup')
passes the flag to setup.sh / setup.ps1 / install_python_stack.py.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-27 03:12:48 -07:00
Daniel Han
3a5e3bbd6d
Make Studio shortcuts launch in a visible terminal (#4638)
* Make Studio shortcuts launch in a visible terminal

Studio shortcuts (Desktop/Start Menu) previously launched the server as a
hidden background process. Closing the browser tab did not stop the server,
leaving users with no obvious way to shut it down. This change makes shortcuts
open a visible terminal window so users can see server output and close the
terminal to stop Studio.

Launcher changes (install.sh):
- Add TTY detection in the launcher's main section. When a TTY is present
  (foreground mode), the launcher spawns a background browser-opener and then
  exec's the studio process directly. This means closing the terminal sends
  SIGHUP to studio, stopping it cleanly. When no TTY is present (background
  mode, e.g. macOS .app or headless), the existing _spawn_terminal behavior
  is preserved.
- Add _open_browser_when_ready helper that polls health on the specific
  launch port and opens the browser once ready.
- Add WSL fallback in _open_browser: uses powershell.exe Start-Process or
  cmd.exe /c start instead of unreliable xdg-open under WSL.

Linux .desktop shortcut:
- Change Terminal=false to Terminal=true so the desktop environment opens
  the user's default terminal emulator for the launcher.

WSL support:
- Remove the early-return that skipped WSL entirely. WSL now gets the
  launcher script and studio.conf written.
- Add WSL shortcut creation: generates Windows Desktop and Start Menu .lnk
  files via a temp PowerShell script. Targets wt.exe (Windows Terminal) with
  automatic fallback to wsl.exe. Uses WSL_DISTRO_NAME for multi-distro setups.

Windows launcher (install.ps1):
- Add Find-FreeLaunchPort function that mirrors the Unix _find_launch_port
  logic, scanning Get-NetTCPConnection for busy ports and returning the first
  free port in the configured range.
- Replace the hardcoded $basePort with the dynamic port result, with a
  MessageBox error dialog if no free port is found.

* Fix review findings: lock race, WSL quoting, Windows port fallback

Foreground lock race (10/10 reviewers):
The foreground mode released the single-instance lock before exec,
allowing a second launcher to acquire the lock and race for the same
port during startup. Move lock release into the background subshell
so it only happens after the health check passes.

WSL shortcut quoting (10/10 reviewers):
WSL_DISTRO_NAME values with spaces (e.g. "Ubuntu Preview", "Fedora
Remix for WSL") were not quoted, causing the distro name to be split
across multiple arguments. Add double-quoting around the distro name
and launcher path in the generated shortcut arguments.

Windows port fallback (3/10 reviewers):
Find-FreeLaunchPort silently assumed no ports were listening when
Get-NetTCPConnection was unavailable, which could return 8888 even
when busy. Add a Test-PortBusy fallback that probes ports with
TcpListener when Get-NetTCPConnection fails. Also scope the
Get-NetTCPConnection query to only the port range we care about.

* Skip powershell.exe shortcut creation if wslpath fails

If wslpath -w fails (returns empty), do not attempt to pass a Linux-style
path to powershell.exe -- it would always fail. Only run powershell.exe
when we have a valid Windows path for the temp PS1 script.

* Remove dead code and fix background health poll target

- Remove unused _open_browser_when_ready function
- Background mode now polls only the specific _launch_port instead of
  scanning all ports via _find_healthy_port, matching foreground behavior
- Add launcher test harness (22 unit + 19 integration tests)

* Fix port probe scope, lock ownership, and T4 test coverage

- Test-PortBusy: bind on Any instead of Loopback to match Studio's
  0.0.0.0 bind scope (prevents false-free in fallback path)
- _release_lock: verify PID ownership before removing lock dir
  (prevents a timed-out subshell from deleting another launcher's lock)
- T4 test: fail first curl call so the test actually exercises the
  lock-contention wait path instead of short-circuiting via fast path

* Temporarily remove launcher test scripts

Tests will be re-added in a follow-up PR to keep this diff focused
on the launcher changes.
2026-03-27 03:12:26 -07:00
Daniel Han
6b5da2ea0f
Fix missing num_items_in_batch in unsloth_prediction_step (#4616)
* Fix missing num_items_in_batch in unsloth_prediction_step

unsloth_prediction_step calls compute_loss without num_items_in_batch
during evaluation. This causes _unsloth_pre_compute_loss to see
num_items_in_batch=None, which triggers a spurious warning for every
model when gradient_accumulation_steps > 1:

  "Unsloth: Not an error, but {model} does not accept num_items_in_batch.
   Using gradient accumulation will be very slightly less accurate."

The standard transformers prediction_step computes num_items_in_batch
via _get_num_items_in_batch before passing it to compute_loss. This
patch does the same in unsloth_prediction_step.

Tested on Llama-3.2-1B-Instruct and Olmo-3-7B-Instruct with
gradient_accumulation_steps=3 and eval_steps=3. Warning is gone and
eval loss is computed correctly for both.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard _get_num_items_in_batch for older transformers versions

_get_num_items_in_batch was added in transformers 4.46. Wrap the call
in try/except so older versions fall back to num_items_in_batch=None,
which preserves the original behavior of not passing it.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 03:06:59 -07:00
Michael Han
0ffac92cf4
Update Install instructions.md 2026-03-27 03:04:07 -07:00
Michael Han
19298a0b41
Update Uninstall instructions.md 2026-03-27 02:56:34 -07:00
Daniel Han
5c9a22b816
Fix Gemma3N audio training stride assertion with non-reentrant checkpointing (#4629)
* Fix Gemma3N audio training stride assertion with non-reentrant checkpointing

Gemma3N audio conformer processes variable-length audio tensors
that cause stride mismatches in AOT autograd compiled backward
when non-reentrant gradient checkpointing is used. The error
manifests as:

    AssertionError: expected size 2==2, stride 1928==1936 at dim=0

This happens because the audio conformer's conv/norm layers produce
tensors whose strides vary with audio clip duration, but AOT autograd
traces the backward graph assuming fixed strides from the first batch.

The notebook sets gradient_checkpointing_kwargs={"use_reentrant": False}
and TRL 0.27.0+ also forces this. Both override Unsloth's own
use_reentrant=True set during prepare_model_for_training.

Fix: intercept gradient_checkpointing_enable on Gemma3N models to
always force use_reentrant=True, regardless of what the notebook
or TRL passes.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 02:53:21 -07:00
Daniel Han
3c9f0ed149
fix: use unsloth[huggingfacenotorch] instead of --no-deps in no-torch mode (#4647)
The previous --no-deps approach skipped ALL dependencies, not just
torch. This left safetensors, transformers, datasets, accelerate, etc.
missing, causing PackageNotFoundError at runtime.

Fix: in no-torch mode, install unsloth[huggingfacenotorch] (which pulls
all runtime deps except torch), then install unsloth-zoo with --no-deps
(since zoo's published metadata still declares torch as a hard dep).
This gives a working no-torch environment with all non-torch packages.

Applied to all three installer files: install.sh, install.ps1, and
studio/install_python_stack.py.
2026-03-27 02:38:11 -07:00
Daniel Han
2ffc8d2cea
tests: add no-torch / Intel Mac test suite (#4646)
* tests: add no-torch / Intel Mac test suite

Add comprehensive test coverage for the no-torch / --no-torch installer
and Studio backend changes introduced in #4624.

Shell tests (tests/sh/test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection + Python version resolution (4 tests)
- get_torch_index_url on Darwin (2 tests)
- UNSLOTH_NO_TORCH propagation via SKIP_TORCH (5 tests)
- E2E uv venv creation at Python 3.12 (3 tests)
- E2E torch skip with mock uv shim (4 tests)
- UNSLOTH_NO_TORCH env propagation (4 tests)
- --python override flag parsing + resolution (11 tests)
- --no-torch flag parsing (4 tests)
- SKIP_TORCH unification (3 tests)
- CPU hint printing (2 tests)

Python tests (tests/python/test_no_torch_filtering.py):
- _filter_requirements unit tests with synthetic + real requirements files
- NO_TORCH / IS_MACOS constant parsing
- Subprocess mock of install_python_stack() across platform configs
- install.sh --no-torch flag structural + subprocess tests

Python tests (tests/python/test_studio_import_no_torch.py):
- AST checks for data_collators.py, chat_templates.py, format_conversion.py
- Parametrized venv tests (Python 3.12 + 3.13) for no-torch exec
- Dataclass instantiation without torch
- format_conversion convert functions without torch
- Negative controls (import torch fails, torchao fails)

Python tests (tests/python/test_e2e_no_torch_sandbox.py):
- Before/after import chain tests
- Edge cases (broken torch, fake torch, lazy import)
- Hardware detection without torch
- install.sh logic tests (flag parsing, version resolution)
- install_python_stack filtering tests
- Live server startup tests (opt-in via @server marker)

* fix: address review comments on test suite

- Fix always-true assertion in test_studio_import_no_torch.py (or True)
- Make IS_MACOS test platform-aware instead of hardcoding Linux
- Restore torchvision + torchaudio in server test cleanup (not just torch)
- Include server stderr in skip message for easier debugging

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 02:33:45 -07:00
Daniel Han
e9ac785346
fix: install.sh Mac Intel compatibility + Studio no-torch support (#4624)
* fix: install.sh Mac Intel compatibility + Studio no-torch support (#4621)

On Intel Macs (x86_64), PyTorch has no wheels for torch >= 2.3, so the
installer crashes. Even when torch is absent, Studio crashes on startup
because two files have bare top-level torch imports.

Studio's GGUF inference (llama.cpp) does not need PyTorch. Training and
HF-inference already isolate torch to subprocesses. Only 2 files in the
server startup chain had top-level torch imports preventing startup.

Changes:
- install.sh: detect architecture, default to Python 3.12 on Intel Mac,
  skip torch install, add Python 3.13.8 guard for arm64, pass
  UNSLOTH_NO_TORCH env var to setup.sh
- data_collators.py: remove unused `import torch` (no torch.* refs)
- chat_templates.py: lazy-import IterableDataset into function bodies
- install_python_stack.py: add IS_MACOS/NO_TORCH constants, skip
  torch-dependent packages, skip overrides.txt, skip triton on macOS

No existing working flow changes. Linux/WSL and macOS arm64 behavior is
identical.

* tests: add test suite for Mac Intel compat + no-torch mode

Shell tests (test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection for Darwin x86_64/arm64, Linux x86_64/aarch64
- get_torch_index_url returns cpu on simulated Darwin
- UNSLOTH_NO_TORCH propagation to both setup.sh branches

Python unit tests (test_no_torch_filtering.py):
- _filter_requirements with NO_TORCH_SKIP_PACKAGES
- NO_TORCH env var parsing (true/1/TRUE/false/0/unset)
- IS_MACOS constant check
- Overrides skip and triton macOS skip guards

Python import tests (test_studio_import_no_torch.py):
- data_collators.py loads in isolated no-torch venv
- chat_templates.py has no top-level torch imports
- Negative control confirms import torch fails without torch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests: add E2E sandbox tests for Mac Intel no-torch mode

Replace static/synthetic test stubs with real sandbox tests:

- Shell: E2E uv venv creation at Python 3.12, mock uv shim to verify
  torch install is skipped when MAC_INTEL=true, dynamic env propagation
  test for UNSLOTH_NO_TORCH in both local and non-local install paths
- Python filtering: test real extras.txt and extras-no-deps.txt with
  NO_TORCH_SKIP_PACKAGES, subprocess mock of install_python_stack() for
  5 platform configs (NO_TORCH+macOS, Windows+NO_TORCH, normal Linux,
  Windows-only, macOS-only), VCS URL and env marker edge cases
- Python imports: parametrized Python 3.12+3.13 venv fixture, dataclass
  instantiation for all 3 collator classes, chat_templates.py exec with
  stubs, negative controls proving import torch and torchao install fail
  in no-torch venvs

91 total tests, all passing.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: address reviewer findings for Intel Mac no-torch mode

P1 fixes:
- Auto-infer NO_TORCH in install_python_stack.py via platform.machine()
  so `unsloth studio update` preserves GGUF-only mode without needing
  the UNSLOTH_NO_TORCH env var (6/10 reviewers)
- Add openai-whisper and transformers-cfg to NO_TORCH_SKIP_PACKAGES
  since both have unconditional torch dependencies (4/10 reviewers)
- Skip unsloth-zoo on Intel Mac --local installs (depends on torch)
  in both migrated and fresh install paths (1/10)
- Recreate stale 3.13 venvs as 3.12 on Intel Mac re-runs (1/10)
- Detect Apple Silicon under Rosetta via sysctl hw.optional.arm64
  and warn user to use native arm64 terminal (1/10)

P2 fixes:
- Wire new test files into tests/run_all.sh (4/10 reviewers)
- Add update-path tests (skip_base=False) for Intel Mac
- Add _infer_no_torch tests for platform auto-detection

P3 fixes:
- Fix macOS progress bar total (triton step skipped but was counted)
- Fix temp file leak when Windows + NO_TORCH filters stack

All tests pass: 30 shell, 66 Python (96 total).

* feat: add --python override flag to install.sh

Lets users force a specific Python version, e.g. ./install.sh --python 3.12.
Addresses M2 Mac users whose systems resolve to a problematic 3.13.x patch.
When --python is set, the Intel Mac stale-venv guard and 3.13.8 auto-downgrade
are skipped so the user's choice is respected.

* tests: add comprehensive E2E sandbox tests for no-torch mode

Add test_e2e_no_torch_sandbox.py with 7 test groups (43 tests total)
covering the full no-torch import chain, edge cases, and install logic:

- Group 1: BEFORE vs AFTER import chain comparison (proves the bug
  existed and the fix works by synthetically prepending top-level torch
  imports)
- Group 2: Dataclass instantiation without torch
- Group 3: Edge cases with broken/fake torch modules on sys.path
- Group 4: Hardware detection fallback to CPU without torch
- Group 5: install.sh flag parsing, version resolution, arch detection
- Group 6: install_python_stack.py NO_TORCH filtering
- Group 7: Live server startup without torch (marked @server, skipped
  when studio venv is unavailable)

All 43 tests pass on both Python 3.12 and 3.13 isolated venvs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: add --no-torch flag to install.sh/ps1, fix lazy import bug in dataset formatting

- Fix chat_templates.py: narrow torch IterableDataset import into inner
  try/except ImportError so dataset.map() works without torch installed
- Fix format_conversion.py: same lazy import fix for convert_chatml_to_alpaca
  and convert_alpaca_to_chatml
- Add --no-torch flag to install.sh with unified SKIP_TORCH variable
  (driven by --no-torch flag OR MAC_INTEL auto-detection)
- Add --no-torch flag to install.ps1 with $SkipTorch variable
- Print CPU hint when no GPU detected and --no-torch not set
- Replace MAC_INTEL guards with SKIP_TORCH in torch install sections
- Update shell tests (40 pass) and Python tests (90 pass)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: address reviewer findings for --no-torch installer paths

- Fix migrated-env branch in install.sh and install.ps1: check
  SKIP_TORCH first, then branch on STUDIO_LOCAL_INSTALL. Previously
  SKIP_TORCH+non-local fell into else and installed unsloth-zoo (which
  depends on torch), defeating --no-torch mode.
- Fix $env:UNSLOTH_NO_TORCH leak in install.ps1: always set to "true"
  or "false" instead of only setting on the true branch. Prevents stale
  no-torch state from leaking across runs in the same PS session.
- Fix install_python_stack.py update path: add NO_TORCH guard around
  base.txt install so unsloth studio update does not reinstall
  unsloth-zoo (which depends on torch) in no-torch mode.

* fix: install unsloth + unsloth-zoo with --no-deps in no-torch mode

Instead of skipping unsloth-zoo entirely (which breaks unsloth's
dependency on it), install both packages with --no-deps so they are
present but torch is not pulled in transitively. Applied consistently
across all no-torch paths: migrated-env, fresh-local, fresh-non-local
in install.sh, install.ps1, and install_python_stack.py.

* chore: temporarily remove test files (will be added in a follow-up)

* refactor: deduplicate SKIP_TORCH conditional branches in installers

Collapse if/else blocks that differ only by --no-deps into a single
branch with a conditional flag variable. Applied to migrated-env and
fresh-local paths in install.sh, install.ps1, and install_python_stack.py.

* fix: apply --no-deps to fresh non-local --no-torch install path

The non-local else branch was missing $_no_deps_arg/$noDepsArg, so
uv pip install unsloth would resolve torch from PyPI metadata (the
published unsloth package still declares torch as a hard dep). Now
--no-deps is applied consistently to all SKIP_TORCH code paths.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-27 02:09:21 -07:00
Daniel Han
d57a4d993d studio: fix chat CPU spike (#4632)
Inline querier identity changed every render, forcing useLiveQuery to
resubscribe continuously causing CPU spikes. Store querier in a ref and
only re-subscribe when explicit deps change.
2026-03-27 06:20:26 +00:00
Daniel Han
e62085a3d6
Fix repetition_penalty default causing 24% TPS drop in GGUF inference (#4634)
The ChatCompletionRequest Pydantic model defaulted repetition_penalty
to 1.1 when clients omitted the field. This silently forced
llama-server to perform per-token repetition scanning, dropping
streaming throughput from ~225 TPS to ~172 TPS (a 24% penalty).

The Studio frontend always sends repetition_penalty=1.0 explicitly,
so UI users were unaffected. But any API client hitting
/v1/chat/completions without setting the field (curl, third-party
integrations, Open WebUI, etc.) would get the slow path.

Benchmarked on Qwen3.5-4B Q4_K_XL, GPU 0:
- repeat_penalty=1.0: 225.2 TPS
- repeat_penalty=1.1: 172.7 TPS (24% slower)
- LM Studio (which applies rp internally): 170.8 TPS

This aligns the Pydantic default with the frontend default (1.0),
generate_chat_completion's function signature default (1.0), and
llama-server's own default (1.0).
2026-03-26 20:20:53 -07:00
Roland Tannous
e79a178200
Allow install_python_stack to run on Colab (#4633)
* Allow install_python_stack to run on Colab

The _COLAB_NO_VENV flag was setting _SKIP_PYTHON_DEPS=true, which
skipped both the PyPI version check (needs $VENV_DIR/bin/python) and
install_python_stack (uses sys.executable, works without a venv).

Introduce a separate _SKIP_VERSION_CHECK flag for the version check,
so install_python_stack still runs on Colab. The _SKIP_PYTHON_DEPS
flag remains available for the "versions match" fast path.

* Remove colab.py workarounds that broke transformers/hf-hub compatibility

PR #4601 added _pip_install_backend_deps(), _bootstrap_studio_venv(),
and _is_colab() to colab.py as workarounds for install_python_stack
being skipped on Colab. These workarounds:
- Stripped version constraints from studio.txt and installed into system Python
- Upgraded huggingface-hub to >=1.0, breaking Colab's pre-installed
  transformers which requires huggingface-hub<1.0

With install_python_stack now running on Colab (previous commit), these
workarounds are unnecessary — all deps are properly installed by setup.sh.
Restore colab.py to its original PR #4237 structure: just get_colab_url(),
show_link(), and start().

* Remove --local flag from setup.sh in Colab notebook

The --local flag is not needed for the standard Colab flow since
install_python_stack now runs on Colab and installs deps from PyPI.
2026-03-27 00:29:27 +04:00
Wasim Yousef Said
71781272dd
fix: add python-json-logger dependency to data-designer-deps (#4627) 2026-03-26 09:50:51 -07:00
Radouane Elhajali
a6fe743ebe
studio: humanize ETA display for long training runs (#4608)
* studio: humanize ETA display for long training runs

When training takes hours or days, the ETA displayed raw minutes
(e.g. '560m 50s'). This changes the format to:
- Under 1 hour: Xm Ys (unchanged)
- 1-24 hours: Xh Ym Zs
- Over 24 hours: Xd Xh Xm

* Fix formatDuration edge cases and consolidate duplicate for PR #4608

- Guard NaN/Infinity inputs with Number.isFinite() (matches formatNumber in same file)
- Add sub-minute branch so 30s displays as "30s" instead of "0m 30s"
- Accept undefined in type signature to match formatNumber pattern
- Remove duplicate formatDuration from history-card-grid.tsx and import the shared one

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-26 06:55:54 -07:00
Michael Han
937da02f6c
Update Unsloth_Studio_Colab.ipynb 2026-03-26 05:45:30 -07:00
Etherll
b3a3435ac3
fix: Windows installer fails on _yaml.pyd Access Denied (os error 5) (#4617)
* fix: avoid _yaml.pyd lock on Windows during dependency overrides

* fix: move pytorch_tokenizers and kernels to no-deps install to avoid Windows _yaml.pyd loc
2026-03-26 05:15:19 -07:00
Lee Jackson
352455610b
studio: align Dataset/Parameters/Training cards, fix expandable height, animate LoRA settings (#4614)
* fix(studio): align config cards, dynamic height for expanders, LoRA collapsible

* Fix clipping regressions in training, dataset, and params section cards

- training-section: Add hasMessage conditional so the card expands
  (min-h) when startError, vision/audio incompatibility, or config
  validation messages are present instead of always using fixed height
- dataset-section: Expand card when a local dataset is selected via
  upload (datasetSource === "upload" && selectedLocalDataset), not only
  when the Advanced panel is open
- params-section: Guard loraOpen behind isLora so switching to full
  fine-tune collapses the card instead of staying expanded from stale
  React useState

* Fix dataset card clipping for direct file uploads

Use uploadedFile instead of selectedLocalDataset in the card height
condition. selectedLocalDataset is derived from localDatasets.find()
which only resolves for Data Recipe entries, not direct file uploads
(.jsonl, .csv, .parquet, .arrow). The card already renders the Eval
Dataset panel based on uploadedFile (line 750), so the height gate
should match.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-26 04:05:30 -07:00
Wasim Yousef Said
07abcb46de
fix: normalize search matching for recommended models and LoRA picker (#4615)
Recommended models matching the query were filtered from HF results but the Recommended section was hidden during search, causing them to vanish entirely.

- Show filtered recommended models during search by introducing `filteredRecommendedIds`
- Switch `recommendedSet` to use filtered IDs when searching so dedup against HF results is correct
- Hide empty "Hugging Face" label when recommended matches cover the query
- Add `normalizeForSearch` helper to strip separators (spaces, hyphens, underscores, dots) so queries like "llama 3" match "Llama-3.2-1B" and "qwen 2.5" matches "Qwen2.5-7B" in both the recommended model filter and the LoRA adapter filter
2026-03-26 03:40:11 -07:00
Roland Tannous
6b3eb504b2
Fix Colab setup skipping llama.cpp installation (#4618)
* Fix Colab setup skipping llama.cpp installation

The early exit 0 in the Colab no-venv path prevented setup.sh from
ever reaching the llama.cpp install section. Remove the early exit
and instead guard only the venv-dependent Python deps section, so
execution continues through to the llama.cpp prebuilt/source install.

* Simplify _SKIP_PYTHON_DEPS initialization

* Add --local flag to setup.sh in Colab notebook
2026-03-26 13:55:46 +04:00
Abhinav
74ddef1402
fix: skip flex_attention for models with non-zero attention_dropout (#4605) 2026-03-26 01:12:23 -07:00
Michael Han
d4e9b708bb
Update Install instructions.md 2026-03-25 19:55:30 -07:00
Michael Han
d3049db427
Update install instructions.md 2026-03-25 19:04:10 -07:00
Roland Tannous
88a6dfc5cd Revert "Update README.md"
This reverts commit c30e1d2029.
2026-03-25 19:54:12 +00:00
Roland Tannous
c30e1d2029
Update README.md
remove newline from windows command
2026-03-25 23:26:37 +04:00
Daniel Han
9fa67809e6 Update README.md 2026-03-25 09:43:55 -07:00
Roland Tannous
c23c3a17e9
Update README.md (#4604)
Update install instructions for studio
2026-03-25 09:42:32 -07:00
Daniel Han
55db24fc31 Update _utils.py 2026-03-25 09:40:17 -07:00
Daniel Han
baabfa0a6e
Fix Colab huggingface-hub conflict, ensurepip fallback, bump to 2026.3.14 (#4603)
* Fix Colab huggingface-hub conflict, ensurepip fallback, bump to 2026.3.14

- colab.py / setup.sh: relax == pins to >= when installing studio.txt
  on Colab so huggingface-hub does not clobber Colab's bundled version
  (breaks transformers is_offline_mode import)
- install_python_stack.py: when uv is unavailable and pip is missing
  (uv-created venvs), bootstrap via ensurepip before attempting upgrade
- Bump version to 2026.3.14
- Bump installer min version pins to 2026.3.14

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 09:38:02 -07:00
Daniel Han
9cb698c774 Update _utils.py 2026-03-25 09:04:23 -07:00
Daniel Han
23eb7fc0a7
Fix Colab Studio launch and setup.ps1 box alignment (#4601)
* Fix Colab Studio launch and setup.ps1 box alignment

- colab.py: when the Studio venv is missing on Colab, pip-install
  backend dependencies (structlog, fastapi, etc.) from studio.txt
  into the current Python instead of failing with ModuleNotFoundError
- setup.sh: on Colab without a venv, install backend deps into system
  Python and skip venv-dependent sections (Python stack update,
  llama.cpp build) that would otherwise fail
- setup.ps1: use PadRight(47) for the done-line so "Setup Complete!"
  and "Update Complete!" both align with the box border

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 09:00:08 -07:00
Daniel Han
b713a5085a
Bump installer min version to 2026.3.12 (#4600) 2026-03-25 08:40:53 -07:00
Daniel Han
55d24d7c49
feat(studio): editable context length with Apply/Reset for GGUF settings (#4592)
* feat(studio): editable context length with Apply/Reset for GGUF model settings

Previously the Context Length field was read-only and the backend
hardcoded `-c 0`, ignoring custom values entirely. KV Cache Dtype also
triggered an immediate model reload with no way to cancel.

Backend:
- llama_cpp.py: pass the actual n_ctx value to `-c` instead of always 0
- models/inference.py: relax max_seq_length to 0..1048576 (0 = model
  default) so GGUF models with large context windows are supported

Frontend:
- chat-runtime-store: add customContextLength and loadedKvCacheDtype
  state fields for dirty tracking
- chat-settings-sheet: make Context Length an editable number input,
  stop KV Cache Dtype from auto-reloading, show Apply/Reset buttons
  when either setting has been changed
- use-chat-model-runtime: send customContextLength as max_seq_length
  in the load request, reset after successful load

* fix: preserve maxSeqLength for non-GGUF models in load request

customContextLength ?? 0 sent max_seq_length=0 for non-GGUF models,
breaking the finetuning/inference path that needs the slider value.

Now uses a three-way branch:
- customContextLength set: use it (user edited GGUF context)
- GGUF without custom: 0 (model's native context)
- Non-GGUF: maxSeqLength from the sampling slider

* fix: keep max_seq_length default at 4096 for non-GGUF callers

Only relax the bounds (ge=0 for GGUF's "model default" mode,
le=1048576 for large context windows). The default stays at 4096
so API callers that omit max_seq_length still get a sane value
for non-GGUF models.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): rename trust remote code toggle and hide when no model selected

- Rename "Trust remote code" to "Enable custom code"
- Shorten subtitle to "Only enable if sure"
- Hide the toggle when no model is loaded (already hidden for GGUFs)

* fix: restore ge=128 for max_seq_length validation

Keep the minimum at 128 so the API rejects nonsensical values.
GGUF path now sends the model's native context length (from
ggufContextLength) instead of 0 when the user has not customized it.
The upper bound stays at 1048576 for large-context GGUF models.

* feat(studio): replace Context Length input with slider

Use a ParamSlider (512 to model's native context, step 512) instead
of a small number input. Shows "Max" when at the model's native
context length. Consistent with the other slider controls in the
settings panel.

* feat(studio): add editable number input alongside Context Length slider

The slider and number input stay synced -- dragging the slider updates
the number, typing a number moves the slider. The input also accepts
values beyond the slider range for power users who need custom context
lengths larger than the model default.

* fix(studio): widen context length input and use 1024 step for slider

Make the number input wider (100px) so large values like 262144 are
fully visible. Change slider step from 512 to 1024 and min from 512
to 1024.

* fix(studio): context length number input increments by 1024

* fix(studio): cap context length input at model's native max

Adds max attribute and clamps typed/incremented values so the context
length cannot exceed the GGUF model's reported context window.

* fix(studio): point "What's new" link to changelog page

Changed from /blog to /docs/new/changelog.

* fix(studio): preserve custom context length after Apply, remove stale subtitle

- After a reload with a custom context length, keep the user's value
  in the UI instead of snapping back to the model's native max.
  ggufContextLength always reports the model's native metadata value
  regardless of what -c was passed, so we need to preserve
  customContextLength when it differs from native.
- Remove "Reload to apply." from KV Cache Dtype subtitle since the
  Apply/Reset buttons now handle this.

* feat(studio): auto-enable Search and Code tools when model supports them

Previously toolsEnabled and codeToolsEnabled stayed false after loading
a model even if it reported supports_tools=true. Now both toggles are
automatically enabled when the loaded model supports tool calling,
matching the existing behavior for reasoning.

* fix(studio): auto-enable tools in autoLoadSmallestModel path

The suggestion cards trigger autoLoadSmallestModel which bypasses
selectModel entirely. It was hardcoding toolsEnabled: false and
codeToolsEnabled: false even when the model supports tool calling.
Now both are set from the load response, matching the selectModel
behavior. Also sets kvCacheDtype/loadedKvCacheDtype for dirty
tracking consistency.

* fix(studio): re-read tool flags after auto-loading model

The runtime state was captured once at the start of the chat adapter's
run(), before autoLoadSmallestModel() executes. After auto-load enables
tools in the store, the request was still built with the stale snapshot
that had toolsEnabled=false. Now re-reads the store after auto-load so
the first message includes tools.

* fix(studio): re-read entire runtime state after auto-load, not just tools

The runtime snapshot (including params.checkpoint, model id, and all
tool/reasoning flags) was captured once before auto-load. After
autoLoadSmallestModel sets the checkpoint and enables tools, the
request was still built with stale params (empty checkpoint, tools
disabled). Now re-reads the full store state after auto-load so the
first message has the correct model, tools, and reasoning flags.

* feat(studio): add Hugging Face token field in Preferences

Adds a password input under Configuration > Preferences for users to
enter their HF token. The token is persisted in localStorage and
passed to all model validate/load/download calls, replacing the
previously hardcoded null. This enables downloading gated and private
models.

* fix(studio): use model native context for GGUF auto-load, show friendly errors

The auto-load paths and selectModel for GGUF were sending
max_seq_length=4096 which now actually limits the context window
(since we fixed the backend to respect n_ctx). Changed to send 0
for GGUF, which means "use model's native context size".

Also replaced generic "An internal error occurred" messages with
user-friendly descriptions for known errors like context size
exceeded and lost connections.

LoadRequest validation changed to ge=0 to allow the GGUF "model
default" signal. The frontend slider still enforces min=128 for
non-GGUF models.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): filter out FP8 models from model search results

Hide models matching *-FP8-* or *FP8-Dynamic* from both the
recommended list and HF search results. These models are not
yet supported in the inference UI.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 08:32:38 -07:00
Daniel Han
6d6008a1ef
Add PID file tracking and unsloth studio stop command (#4598)
* Add PID file tracking and `unsloth studio stop` command

On macOS the .app shortcut launches Studio via osascript into a
Terminal window, then the launcher script exits. The server process
runs outside of the launcher's context with no PID file, so there
is no straightforward way to find or stop it.

This adds:
- PID file at ~/.unsloth/studio/studio.pid, written after the
  server starts and removed on graceful shutdown or via atexit
- `unsloth studio stop` command that reads the PID file and sends
  SIGTERM (or taskkill on Windows) to shut down the server

The PID file is only removed if it still contains the current
process ID, avoiding races when a new server instance replaces
a crashed one.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move atexit PID cleanup into run_server()

The atexit registration was only in the __main__ block, so it
did not cover the `unsloth studio` CLI path that calls
run_server() directly via studio_default(). Moving it into
run_server() ensures the PID file is cleaned up on unexpected
exit regardless of entry point.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 08:27:27 -07:00
Daniel Han
561f0f39be Fix install.ps1 --local: pass script args to Install-UnslothStudio
The function was called with no arguments, so $args inside the function
was always empty. Script-level args (--local, --package) were never
forwarded. Use @args splatting to pass them through.
2026-03-25 15:14:51 +00:00
Daniel Han
289c7dd7bb Add --local and --package flags to install.ps1
Windows install.ps1 had no way to install from a local repo checkout,
unlike install.sh which supports ./install.sh --local. This adds:

- --local: install from the local repo via editable install (-e . --no-deps)
  after installing deps from PyPI, mirroring install.sh behavior
- --package: install a different package name for testing

The --local flag:
1. Validates pyproject.toml exists at the script's directory
2. Installs torch + unsloth deps normally
3. Overlays the local checkout with uv pip install -e <repo> --no-deps
4. Passes STUDIO_LOCAL_INSTALL and STUDIO_LOCAL_REPO to setup.ps1
2026-03-25 15:12:56 +00:00
Daniel Han
2683c2ab58
Add unsloth to User PATH on Windows after install (#4597)
After installation, `unsloth studio` only works if the user
activates the Studio venv first or uses the full absolute path.
The Desktop/Start Menu shortcuts work fine, but typing `unsloth
studio` in a fresh terminal does not.

This adds the venv Scripts dir to the persistent User PATH env
var (if not already present) so `unsloth studio` works from any
new terminal window. The current session is also updated via the
existing Refresh-SessionPath helper.
2026-03-25 08:00:44 -07:00
Roland Tannous
48a7884584
feat: multi-source model discovery (HF default, legacy cache, LM Studio) (#4591)
* feat: multi-source model discovery (HF default, legacy cache, LM Studio)

* Fix multi-source model discovery bugs

- Fix lmstudio_model_dirs: add ~/.lmstudio/models as default path,
  remove dead sys.platform branch, add dedup via seen set
- Fix _setup_cache_env: preserve legacy HF cache env vars when the
  legacy hub directory exists and is non-empty
- Fix _scan_lmstudio_dir: use absolute path for id field so
  is_local_path() returns True
- Remove LM Studio dirs from allowed_roots (scanned unconditionally)
- Replace bare except passes with logger.warning in legacy cache blocks
- Fix delete_cached_model to search both default and legacy HF caches
- Make lmstudio_dirs non-optional in TS interface (matches Python schema)
- Exclude lmstudio source from trainable model filter
- Remove unused import sys

* Scan HF default cache alongside legacy and active caches

When _setup_cache_env overrides HF_HUB_CACHE to the legacy Unsloth
path, the standard HF default cache (~/.cache/huggingface/hub) was
never scanned, hiding models downloaded before Unsloth Studio was
installed.

Add hf_default_cache_dir() and _all_hf_cache_scans() helper that
deduplicates and scans all three HF cache locations (active, legacy,
default). Used in list_local_models, list_cached_gguf,
list_cached_models, and delete_cached_model.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 07:48:04 -07:00
Daniel Han
ebe22c1e9e Update _utils.py 2026-03-25 07:30:40 -07:00
Daniel Han
366fb048d4
fix(studio): add bun cache validation to Windows setup.ps1 (#4596)
Port the bun cache corruption fix from setup.sh to setup.ps1.

bun's package cache can become corrupt, storing only package metadata
without actual content. This causes bun install to exit 0 but leave
binaries like tsc missing from node_modules/.bin/.

Changes:
- After bun install, verify tsc and vite exist in node_modules\.bin\
- Check for both bare names and .cmd wrappers (Windows creates both)
- If missing, clear the bun cache and retry once
- Only fall back to npm if the retry also fails
2026-03-25 07:27:08 -07:00
Daniel Han
3efea63e2f
fix(studio): source-build fallback prefers Unsloth's tested tag over upstream latest (#4593)
* fix(studio): source-build fallback prefers Unsloth's tested tag over upstream latest

When the prebuilt install fails and falls back to source build,
--resolve-llama-tag now queries the Unsloth release repo
(unslothai/llama.cpp) first to get the latest tested/approved tag
(e.g. b8508), instead of going straight to ggml-org/llama.cpp which
may return a newer untested tag (e.g. b8514).

This ensures the source-build fallback compiles the same version that
the prebuilt path would have installed, rather than a potentially
incompatible bleeding-edge release.

Resolution order for "latest":
  1. Unsloth release repo (tested/approved)
  2. ggml-org upstream (bleeding-edge)
  3. Raw requested tag string (last resort)

Changes:
- resolve_requested_llama_tag() accepts optional published_repo param
  with docstring explaining the resolution order
- CLI --resolve-llama-tag passes --published-repo through
- setup.sh and setup.ps1 pass --published-repo to --resolve-llama-tag
  with inline comments explaining the preference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-25 07:25:47 -07:00
Daniel Han
bc9cf31478
Pin torch>=2.4,<2.11.0 in Studio installers (#4595)
torch 2.11.0 has a torch.compile/dynamo bug that causes a
StopIteration crash in dict_keys_getitem when compiling MoE
router functions (e.g. GptOssTopKRouter_forward). Pin to
<2.11.0 until the upstream fix lands.

Applies to both install.sh (Linux/macOS) and install.ps1
(Windows) fresh install paths.
2026-03-25 07:20:55 -07:00
Daniel Han
2e4569e06a
fix(studio): clear bun cache on failure and retry before falling back to npm (#4594)
bun's package cache can become corrupt, storing only package metadata
(package.json, README) without actual content (bin/, lib/). When this
happens, bun install exits 0 and reports packages as installed, but
binaries like tsc are missing from node_modules/.bin/.

For example, a corrupt typescript cache entry is 64KB (metadata only)
vs 23MB when correctly downloaded.

Changes:
- After bun install, verify tsc and vite exist in node_modules/.bin/
- If missing, clear the bun cache with bun pm cache rm and retry once
- Only fall back to npm if the retry also fails
- Revert bun installation to npm install -g bun (the binary is fine,
  the cache was the problem)
2026-03-25 07:05:02 -07:00
Daniel Han
457c42964f
fix(studio): validate bun install and retry from official source on failure (#4589)
bun install (specifically the npm "bun" shim v1.3.x installed via
npm install -g bun) can exit 0 while silently failing to install
packages. This causes the frontend build to fail with "tsc: not found"
or missing type declarations, since the fallback to npm only triggers
on a non-zero exit code.

Changes:

1. Initial bun install now tries the official bun.sh installer first
   (which gives a real bun runtime), falling back to npm install -g bun
   only if that fails.

2. After bun install reports success, verify that critical binaries
   (tsc, vite) actually exist in node_modules/.bin/. If they are
   missing, reinstall bun from the official source and retry once
   before falling back to npm.

3. Extract the bun install + validation logic into _try_bun_install()
   to avoid duplicating the check/cleanup across both attempts.
2026-03-25 06:38:32 -07:00
Roland Tannous
1f498a73e6 Revert "feat: multi-source model discovery (HF default, legacy cache, LM Studio)"
This reverts commit d56b115bb4.
2026-03-25 13:35:03 +00:00
Roland Tannous
d56b115bb4 feat: multi-source model discovery (HF default, legacy cache, LM Studio) 2026-03-25 13:24:46 +00:00
Daniel Han
ae2b1b97ba
fix(studio): add pip-installed nvidia CUDA libs to LD_LIBRARY_PATH for llama-server (#4590)
The prebuilt llama.cpp binary (cuda13-newer) links against
libcudart.so.13 and libcublas.so.13. When torch is installed via pip,
these libraries live in the venv's site-packages under
nvidia/cu13/lib/, not in /usr/local/cuda/.

The existing LD_LIBRARY_PATH logic only searched /usr/local/cuda*
paths (which have CUDA 12.x), so the CUDA backend failed to load
silently and llama-server fell back to CPU -- even with -ngl -1.

This adds a glob scan of the venv's nvidia package directories
(cu*, cudnn, nvjitlink) to LD_LIBRARY_PATH before launching
llama-server, matching where pip puts the CUDA runtime.

Tested on Colab with RTX PRO 6000 Blackwell (CUDA 13.0, pip torch):
before -- 3 MiB GPU, 0% util, CPU inference
after  -- 13317 MiB GPU, 77% util, full GPU inference

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-25 06:24:40 -07:00
Daniel Han
d87c21aebf
fix(studio): add -ngl -1 when model fits on GPU to enable GPU offloading (#4588)
When _select_gpus determines that a GGUF model fits on the selected
GPU(s), the code sets CUDA_VISIBLE_DEVICES but never passes -ngl
(number of GPU layers) to llama-server. Without -ngl or --fit,
llama-server defaults to 0 GPU layers and runs entirely on CPU.

This adds -ngl -1 (offload all layers) in the elif branch where
gpu_indices is set and use_fit is False, so models that fit in VRAM
actually use the GPU for inference.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-25 06:14:33 -07:00
DoubleMathew
f4d8a246bf
Use prebuilt llama.cpp for unsloth studio setup (#4562)
* Use prebuilt llama.cpp for unsloth studio setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix 3 issues that cause unnecessary fallback to source build

1. Make filelock import optional -- environments without filelock
   (e.g. minimal installs) crashed at import time instead of
   gracefully skipping the lock.

2. Use already-verified converter script from the hydrated source
   tree instead of re-downloading from raw.githubusercontent.com
   with no checksum. Adds symlink with copy fallback for the
   legacy filename.

3. Initialize $SkipPrebuiltInstall in setup.ps1 before first use
   to prevent potential uninitialized variable errors.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Keep network fallback in ensure_converter_scripts

Prefer the local verified copy from the hydrated source tree, but
retain the original network download as a fallback if the file is
missing. Create the legacy hyphenated filename as a symlink with a
copy fallback instead of writing a second full copy.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix 4 bugs in source-build fallback and binary_env paths

- setup.ps1: Replace git pull + checkout FETCH_HEAD with fetch + checkout -B
  to avoid detached HEAD state that breaks re-runs. Use pinned tag in both
  fetch and clone paths.
- setup.sh: Move rm -rf after cmake/git prerequisite checks so a missing
  tool no longer deletes the existing install. Add --branch tag to clone.
- install_llama_prebuilt.py: Add binary_path.parent to Linux LD_LIBRARY_PATH
  in binary_env() so bundled .so files in build/bin are found even without
  RPATH, matching the existing Windows PATH logic.
- Add test for binary_env LD_LIBRARY_PATH on Linux.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Handle unresolved "latest" tag in source-build fallback clone

When tag resolution fails and the requested tag is "latest", both
setup scripts now omit --branch from git clone so the default branch
is cloned instead of failing on a nonexistent "latest" branch/tag.
Similarly, the PS1 fetch path fetches the default ref when the tag
is "latest".

* Resolve actual latest ggml-org tag instead of using literal "latest"

When both Python tag resolution attempts fail and the requested tag
is "latest", query the GitHub API for the actual latest release tag
from ggml-org/llama.cpp (e.g. b8508) instead of passing the literal
string "latest" to git clone --branch, which would fail since no
such branch/tag exists.

setup.sh uses curl + python json parsing; setup.ps1 uses
Invoke-RestMethod. Both fall back to the raw requested tag if the
API call also fails.

* Try Unsloth release repo before ggml-org when resolving latest tag

When falling back to the GitHub API to resolve "latest", query the
Unsloth release repo (unslothai/llama.cpp) first since it has the
prebuilt binaries pinned to tested tags. Only fall back to
ggml-org/llama.cpp if the Unsloth repo query fails.

* Add comprehensive sandbox tests for PR #4562 bug fixes

35 tests covering all fixes across platforms:
- binary_env cross-platform (Linux LD_LIBRARY_PATH, Windows PATH,
  macOS DYLD_LIBRARY_PATH) with edge cases (dedup, ordering, existing paths)
- resolve_requested_llama_tag (concrete, latest, None, empty)
- setup.sh logic via subprocess: prereq check ordering (cmake/git missing
  preserves install), pinned tag in clone, fetch+checkout -B pattern,
  fetch failure warns instead of aborting
- "latest" tag resolution fallback chain (Unsloth API -> ggml-org ->
  raw) with mock curl: success, failure, malformed JSON, empty body,
  empty tag_name, env overrides
- Source code pattern verification for both .sh and .ps1 files

All 138 tests pass in isolated uv venv.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add binary_path.parent to macOS DYLD_LIBRARY_PATH in binary_env

macOS prebuilt .dylib files are overlaid into build/bin (same as
Linux), but binary_env only added install_dir to DYLD_LIBRARY_PATH.
Add binary_path.parent so the loader can find sibling dylibs even
without embedded loader paths.

Mirrors the existing fix for Linux LD_LIBRARY_PATH and the Windows
PATH pattern.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard --branch when resolved tag is "latest"; fix broken test assertion

When all API fallbacks fail and the tag stays as literal "latest",
omit --branch from git clone (clones default branch instead of
failing). Both setup.sh and setup.ps1 now check for "latest" before
passing --branch to git clone/fetch.

Also fix test_setup_ps1_clone_uses_branch_tag which used Python
tuple syntax (assert "x", "y" in z) that always passes. Changed to
assert "x" in z and "y" in z.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix macOS DYLD trailing colon, install_lock no-op, and debug log

- binary_env macOS: use dedupe_existing_dirs instead of raw string
  concatenation. Eliminates trailing colon in DYLD_LIBRARY_PATH
  (which causes dyld to search CWD for libraries) and deduplicates
  when binary_path.parent == install_dir. Now consistent with the
  Linux and Windows branches.
- install_lock: when filelock is not installed, use os.O_CREAT|O_EXCL
  as a fallback exclusive file lock with timeout, instead of yielding
  with no locking. Prevents concurrent installs from corrupting each
  other's staging directories.
- setup.ps1: remove [DEBUG] log line that printed to every user on
  every Windows setup run.

* Add stale-lock detection and atomic clone-then-swap

install_lock fallback (no filelock): write PID to lock file and
check if the holder process is still alive on contention. Dead PIDs
(ProcessLookupError) and unreadable lock files trigger immediate
cleanup. Live processes owned by other users (PermissionError) are
correctly recognized as alive -- the lock is not removed.

setup.sh/setup.ps1 source-build: clone into a temporary directory
first, then swap into place only on success. If git clone fails,
the existing install is preserved instead of being deleted by the
premature rm -rf.

* Remove redundant upstream_tag != release_tag check

load_approved_release_checksums compared checksums.upstream_tag
against the Unsloth release_tag, which are different namespaces
(upstream ggml-org tag vs Unsloth published tag). This only worked
because both happened to be "b8508" by convention. Would break if
Unsloth ever uses a different release naming scheme.

The existing check at parse_approved_release_checksums (line 950)
already validates the release_tag field correctly.

* Fix lock TOCTOU race and build-in-temp-dir swap

install_lock fallback: add os.fsync(fd) after writing PID to ensure
the PID is visible to racing processes before they check. Treat
empty lock files (PID not yet written) as "wait and retry" instead
of stale, closing the window where two processes could both see an
empty file, both unlink it, and both acquire the lock.

setup.sh/setup.ps1 source-build: clone AND build in a temp directory
(LLAMA_CPP_DIR.build.$$). Only swap into the final LLAMA_CPP_DIR
after the build succeeds. If clone or cmake or build fails, the temp
dir is cleaned up and the existing working install is preserved.
Previously, rm -rf ran after clone but before build, destroying the
existing install even if the build later failed.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 05:42:43 -07:00
Lee Jackson
cc1be75621
studio: stabilize reasoning panel scroll behavior and prevent composer overlap (#4587)
* fix(studio): reasoning panel scroll and thread footer overlap

* refactor(studio): dedupe reasoning scroll lock teardown
2026-03-25 05:32:31 -07:00
Roland Tannous
19e9c60a8e
Consolidate dual venvs and separate install from update (#4530)
* refactor: consolidate dual venvs into single ~/.unsloth/studio/unsloth_studio

* refactor: separate install.sh (first-time) from setup.sh (smart update with PyPI version check)

* fix: install.sh calls setup.sh directly, keep both setup and update CLI commands

* fix: use importlib.resources.files() directly without _path attribute

* fix: bootstrap uv before pip upgrade to handle uv venvs without pip

* fix: frontend 404 when launched via CLI, add global symlink to ~/.local/bin

* feat: add --local flag to install.sh and unsloth studio update for branch testing

* fix: resolve repo root from script location for --local installs

* feat: add --package flag to install.sh for testing with custom package names

* feat: add --package flag to unsloth studio update

* fix: always nuke venv in install.sh for clean installs

* revert: remove Windows changes, will handle in separate PR

* fix: error when --package is passed without an argument

* revert: restore Windows scripts to current main

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: always explicitly set STUDIO_LOCAL_INSTALL and STUDIO_PACKAGE_NAME env vars

* fix: pass explicit STUDIO_LOCAL_REPO env var for --local installs

* fix: align banner box for Setup vs Update labels

* deprecate: hide 'unsloth studio setup' command, point users to update/install.sh

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: check stdout not stdin for auto-launch detection (curl pipe fix)

* fix: update install URL to unsloth.ai/install.sh

* fix: update install.sh usage comments to unsloth.ai/install.sh

* fix: use --upgrade-package for base deps to preserve existing torch/CUDA installs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: --local install now also installs unsloth-zoo via base.txt before editable overlay

* fix: don't skip base packages for --local installs (editable needs unsloth-zoo)

* refactor: move --local full dep install to install.sh, keep SKIP_STUDIO_BASE for all paths

* feat: add migration support for old .venv and CWD-based installs in setup.sh

* Revert "feat: add migration support for old .venv and CWD-based installs in setup.sh"

This reverts commit 301291d002.

* feat: migrate old .venv layout in install.sh instead of always nuking

* feat: validate old .venv with torch CUDA test before migration, recovery message on launch failure

* fix: try CUDA then fall back to CPU for migration validation

* fix: upgrade unsloth/unsloth-zoo with --reinstall-package on migration to preserve torch

* remove: delete unused unsloth ui command (use unsloth studio instead)

* Fix Windows venv path mismatch between install.ps1, setup.ps1, and studio.py

install.ps1 was creating the venv CWD-relative ($VenvName = "unsloth_studio"),
setup.ps1 was using an absolute path to ".unsloth\studio\.venv", and studio.py
looks for ".unsloth\studio\unsloth_studio". All three paths were different, so
the Windows installer would never produce a working Studio setup.

install.ps1:
- Use absolute $StudioHome + $VenvDir matching the Linux install.sh layout
- Add 3-way migration: old .venv at STUDIO_HOME, CWD-relative ~/unsloth_studio
  from the previous install.ps1, or fresh creation with torch validation
- For migrated envs, upgrade unsloth while preserving existing torch/CUDA wheels
- Set SKIP_STUDIO_BASE=1 before calling setup.ps1 (matches install.sh behavior)
- Fix launch instructions to use the absolute venv path

setup.ps1:
- Change $VenvDir from ".unsloth\studio\.venv" to ".unsloth\studio\unsloth_studio"
- Add SKIP_STUDIO_BASE guard: error out if venv is missing when called from
  install.ps1 (which should have already created it)
- Differentiate "Setup" vs "Update" in banners based on SKIP_STUDIO_BASE

* setup.ps1: unconditionally error if venv missing, matching setup.sh

setup.sh always errors out if the venv does not exist (line 224-228),
telling the user to run install.sh first. setup.ps1 was conditionally
creating a bare venv with python -m venv when SKIP_STUDIO_BASE was not
set, which would produce an empty venv with no torch or unsloth. Now
setup.ps1 matches setup.sh: always error, always point to install.ps1.

* Fix --torch-backend=auto CPU solver dead-end on Linux, macOS, and Windows

On CPU-only machines, `uv pip install unsloth --torch-backend=auto`
falls back to unsloth==2024.8 because the CPU solver cannot satisfy
newer unsloth's dependencies. install.ps1 already solved this with a
two-step approach; this applies the same fix to install.sh and
install_python_stack.py.

install.sh: add get_torch_index_url() that detects GPU via nvidia-smi
and maps CUDA versions to PyTorch index URLs (matching install.ps1's
Get-TorchIndexUrl). Fresh installs now install torch first via explicit
--index-url, then install unsloth with --upgrade-package to preserve
the pre-installed torch. All 5 --torch-backend=auto removed from
primary paths.

install.ps1: add fallback else-branch when TorchIndexUrl is empty,
using --torch-backend=auto as last resort (matching install.sh).

install_python_stack.py: remove unconditional --torch-backend=auto
from _build_uv_cmd. Torch is pre-installed by install.sh/setup.ps1
by the time this runs. Callers that need it can set UV_TORCH_BACKEND.

Both install.sh and install.ps1 now share the same three-branch logic:
migrated env (upgrade-package only), normal (torch-first + index-url),
and fallback (--torch-backend=auto if URL detection fails).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use --reinstall-package for migrated envs on both Linux and Windows

For migrated environments (moved from legacy venv location),
--reinstall-package is better than --upgrade-package because it forces
a clean reinstall even if the same version is already installed. This
ensures proper .dist-info and .pyc state in the new venv location.

--upgrade-package remains correct for the fresh install path where
torch is already installed and we just want to add unsloth without
re-resolving torch.

* Address review findings: portability, parity, and stale comments

- Replace grep -oP (GNU Perl regex) with POSIX sed in
  get_torch_index_url() so the script works on BSD grep (macOS is
  already guarded by the Darwin early-return, but Alpine/BusyBox
  would silently get the wrong CUDA tag)
- Add LC_ALL=C before nvidia-smi invocation to prevent locale-dependent
  output parsing issues
- Add warning on stderr when nvidia-smi output is unparseable, matching
  install.ps1's [WARN] message
- Add explicit unsloth-zoo positional arg to install.ps1 migrated path,
  matching install.sh (--reinstall-package alone won't install it if it
  was never present in the migrated env)
- Fix stale comment in install_python_stack.py line 392 that still
  claimed --torch-backend=auto is added by _build_uv_cmd
- Add sed to test tools directory (function now uses sed instead of grep)

* Add --index-url to migrated env path to prevent CPU torch resolution

The migrated path runs uv pip install with --reinstall-package for
unsloth/unsloth-zoo. While uv should keep existing torch as satisfied,
the resolver could still re-resolve torch as a transitive dependency.
Without --index-url pointing at the correct CUDA wheel index, the
resolver would fall back to plain PyPI and potentially pull CPU-only
torch. Adding --index-url $TORCH_INDEX_URL ensures CUDA wheels are
available if the resolver needs them.

Applied to both install.sh and install.ps1.

* Revert --index-url on migrated env path

The original install.ps1 on main already handles the migrated path
without --index-url and it works correctly. --reinstall-package only
forces reinstall of the named packages while uv keeps existing torch
as satisfied. No need for the extra flag.

* Fix unsloth studio update --local not installing local checkout

studio.py sets STUDIO_LOCAL_REPO when --local is passed, but
install_python_stack.py never read it. The update path always
installed from PyPI regardless of the --local flag.

Add a local_repo branch that first updates deps from base.txt
(with --upgrade-package to preserve torch), then overlays the
local checkout as an editable install with --no-deps.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 05:24:21 -07:00
Daniel Han
3446e0c489
Add ROCm (AMD GPU) support to studio setup (#4585)
* Add support for ROCm in studio setup

* Fix ROCm detection bugs: ROCM_PATH resolution, CUDA guard, compiler selection

- Set GPU_BACKEND="cuda" when nvcc is found (CUDA path was unreachable)
- Guard ROCm detection with `if [ -z "$GPU_BACKEND" ]` so CUDA takes
  priority on mixed-toolchain hosts
- Rename ROCM_PATH to ROCM_HIPCC for the hipcc binary; resolve the
  actual ROCm root via readlink -f and hipconfig -R into ROCM_ROOT
- Export both ROCM_PATH and HIP_PATH as the resolved root directory
- Use HIPCXX via hipconfig -l instead of legacy CMAKE_C_COMPILER=hipcc
- Switch grep -oP to grep -oE for portability across Linux distros
- Use GPU_TARGETS (upstream cmake variable) instead of AMDGPU_TARGETS
- Remove stale hardcoded fallback targets; let cmake auto-detect instead

* Fix gfx regex to match gfx90a (MI210/MI250/MI250X)

The grep and bash regex used {3,4} digits after 'gfx', which silently
excluded gfx90a (2 digits + letter 'a') -- the architecture for AMD
Instinct MI210, MI250, and MI250X data-center GPUs. Change to {2,4}
so all real gfx targets from gfx90a through gfx1200 are matched.

---------

Co-authored-by: edamamez <eda.zhou@amd.com>
2026-03-25 04:50:23 -07:00
cz-03
7eb48512bc
feat(tokenizer): add get_tokenizer_info() diagnostic helper (#4436)
* feat(tokenizer): add get_tokenizer_info() diagnostic helper

Adds get_tokenizer_info(tokenizer) to tokenizer_utils.py returning a concise dict of key tokenizer properties class name, is_fast, vocab size, added token count, model_max_length, padding side, special tokens (bos, eos, pad, unk), chat template presence, and total special token count. All fields use getattr(..., None) fallbacks so the function never raises on unusual or partially initialized tokenizers. Exported via __all__ alongside the existing public helpers. Useful for logging, debugging, and surfacing tokenizer state in the Unsloth Studio UI.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix docstring, remove artifact, restore valuable comments in tokenizer_utils.py

- Fix get_tokenizer_info() docstring example: correct tokenizer_class to
  PreTrainedTokenizerFast, vocab_size to 128000, swap added_tokens_count (256)
  and special_tokens_count (3) to match actual Llama-3.2-1B-Instruct output
- Remove accidentally committed "# ... (rest of file unchanged)" diff artifact
- Restore fix_sentencepiece_gguf() docstring with llama.cpp upstream link
- Restore 10 comments containing upstream URLs, model-specific workarounds,
  and non-obvious context (issue #292, sentencepiece#121, Starling hack,
  Kaggle /tmp limit, Deepseek slow tokenizer, twitter/danielhanchen references)

* Revert "Fix docstring, remove artifact, restore valuable comments in tokenizer_utils.py"

This reverts commit 4e525b734b.

* Revert all deletions, keep only get_tokenizer_info() addition

Restore tokenizer_utils.py to main and add only the new
get_tokenizer_info() function and its __all__ entry.
All comment removals, dead code cleanup, and formatting
changes from the original PR are reverted.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 04:29:01 -07:00
Etherll
d69d60ff19
perf(studio): upgrade to Vite 8 + auto-install bun for faster frontend builds (#4522)
* perf(studio): upgrade to Vite 8 + auto-install bun for 3x faster frontend builds

* fix(studio): make bun-to-npm fallback actually reachable

setup.sh used run_quiet() for the bun install attempt, but run_quiet
calls exit on failure. This killed the script before the npm fallback
could run, making the "falling back to npm" branch dead code.

Replace the run_quiet call with a direct bun invocation that captures
output to a temp file (same pattern, but returns instead of exiting).

Also clean up partial node_modules left by a failed bun install before
falling back to npm, in both setup.sh and build.sh. Without this, npm
inherits a corrupted node_modules tree from the failed bun run.

* fix(studio): restore commonjsOptions for dagre CJS interop

The previous commit removed build.commonjsOptions, assuming Vite 8's
Rolldown handles CJS natively. While optimizeDeps.include covers the
dev server (pre-bundling), it does NOT apply to production builds.

The resolve.alias still points @dagrejs/dagre to its .cjs.js entry,
so without commonjsOptions the production bundle fails to resolve
the CJS default export. This causes "TypeError: e is not a function"
on /chat after build (while dev mode works fine).

Restore the original commonjsOptions block to fix production builds.

* fix(studio): use motion/react instead of legacy framer-motion import

* fix(studio): address PR review findings for Vite 8 + bun upgrade

Fixes:
  - Remove bun.lock from repo and add to .gitignore (npm is source of truth)
  - Use & bun install *> $null pattern in setup.ps1 for reliable $LASTEXITCODE
  - Add Remove-Item node_modules before npm fallback in setup.ps1
  - Print bun install failure log in setup.sh before discarding
  - Add Refresh-Environment after npm install -g bun in setup.ps1
  - Tighten Node version check to ^20.19.0 || >=22.12.0 (Vite 8 requirement)
  - Add engines field to package.json
  - Use string comparison for _install_ok in build.sh
  - Remove explicit framer-motion ^11.18.2 from package.json (motion pulls
    framer-motion ^12.38.0 as its own dependency — the old pin caused a
    version conflict)

* Fix Colab Node bypass and bun.lock stale-build trigger

Gate the Colab Node shortcut on NODE_OK=true so Colab
environments with a Node version too old for Vite 8 fall
through to the nvm install path instead of silently proceeding.

Exclude bun.lock from the stale-build probe in both setup.sh
and setup.ps1 so it does not force unnecessary frontend rebuilds
on every run.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Shine1i <wasimysdev@gmail.com>
2026-03-25 04:27:41 -07:00
Daniel Han
be2cd7087a
Add macOS and Linux desktop shortcuts to install.sh (#4568)
* Add macOS and Linux desktop shortcuts to install.sh

Adds create_studio_shortcuts() function that creates platform-native
shortcuts after `unsloth studio setup` completes, mirroring the Windows
shortcut behavior from PR #4558.

Linux: .desktop file in ~/.local/share/applications/ and ~/Desktop/
macOS: .app bundle in ~/Applications/ with Info.plist, exec stub, and
       optional .icns icon built from unsloth-gem.png via sips+iconutil

Both platforms share a Bash launcher script at
~/.local/share/unsloth/launch-studio.sh that provides:
- Health check with service fingerprint verification
- Port scanning (8888-8908) via ss/lsof
- PID-file single-instance guard (no flock dependency)
- Terminal spawning (macOS: Terminal.app; Linux: gnome-terminal etc.)
- Browser open after health poll with 60s timeout

WSL is skipped (no native desktop environment).

* Fix 6 issues found by 10 parallel reviewers

1. [10/10] Health check now supports wget as fallback to curl via
   _http_get() helper, matching the installer's own download() pattern.
   Previously wget-only systems would time out on every launch.

2. [9/10] Exe path substitution now escapes sed metacharacters (&, \, |)
   and shell single-quotes before injection, preventing launcher
   corruption for paths like /opt/R&D/bin/unsloth.

3. [4/10] Linux .desktop Exec= field now quotes the launcher path,
   fixing launches from home directories containing spaces.

4. [3/10] macOS AppleScript command now escapes backslashes and
   double-quotes before interpolation into do script "...", fixing
   Terminal.app launch failures.

5. [3/10] Single-instance guard now uses atomic mkdir instead of
   racy check-then-write PID file, preventing duplicate concurrent
   launches on rapid double-click.

6. [1/10] Launcher now scans for a free port via _find_launch_port()
   instead of always hardcoding -p 8888, so Studio starts correctly
   when another service already occupies port 8888.

Also fixed: `open` command on Linux (openvt) no longer incorrectly
triggers the macOS browser-open path -- now gated on uname=Darwin.

* Fix mktemp guard and exe path escaping from PR review comments

Two real issues identified from automated review comments:

1. Guard mktemp -d failure in macOS icns generation. If mktemp -d
   returned empty, dirname would resolve to / and rm -rf would attempt
   to delete the root directory. Now checks that the temp dir was
   actually created before proceeding.

2. Replace sed-based exe path substitution with a conf file approach.
   The previous sed escaping broke paths containing apostrophes
   (e.g. /home/O'Connor/) because the '\'' escape introduced
   backslashes that were then double-escaped by the metacharacter
   pass. Now writes UNSLOTH_EXE to a separate studio.conf file that
   the launcher sources at runtime, eliminating all sed metacharacter
   and shell quoting interaction issues.

   This also addresses the sed -i.bak portability concern (now moot
   since sed is no longer used on the launcher file).

* Fix unbound variable crash and per-user lock in launcher

- Use ${UNSLOTH_EXE:-} so set -u does not crash before the friendly
  error message when studio.conf is missing or empty.
- Append $(id -u) to the fallback lock path so each user gets their
  own lock directory when XDG_RUNTIME_DIR is unset.

* Mark desktop shortcut as trusted for GNOME/Nautilus

On modern GNOME desktops, chmod +x alone is not sufficient to make
a .desktop file launchable by double-click on ~/Desktop. Nautilus
requires the metadata::trusted attribute to be set via gio, otherwise
it shows a warning dialog instead of launching the application.
2026-03-25 03:37:37 -07:00
Daniel Han
6872c6e850
Remove advanced CodeQL workflow in favor of default setup (#4584)
The repo has both the CodeQL "default setup" (configured in repo
settings) and this advanced workflow file enabled. GitHub does not
allow both simultaneously, causing all PR CI runs to fail with:

  "CodeQL analyses from advanced configurations cannot be processed
   when the default setup is enabled"

Since the default setup already covers the same languages (Python,
JavaScript/TypeScript) with the same build-mode (none), remove the
redundant advanced workflow file.
2026-03-25 03:34:21 -07:00
dependabot[bot]
38405cc18c
build(deps): bump oxc-parser (#4571)
Bumps the npm-oxc-validator group in /studio/backend/core/data_recipe/oxc-validator with 1 update: [oxc-parser](https://github.com/oxc-project/oxc/tree/HEAD/napi/parser).


Updates `oxc-parser` from 0.116.0 to 0.121.0
- [Release notes](https://github.com/oxc-project/oxc/releases)
- [Changelog](https://github.com/oxc-project/oxc/blob/main/napi/parser/CHANGELOG.md)
- [Commits](https://github.com/oxc-project/oxc/commits/crates_v0.121.0/napi/parser)

---
updated-dependencies:
- dependency-name: oxc-parser
  dependency-version: 0.121.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: npm-oxc-validator
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-25 02:44:38 -07:00
dependabot[bot]
f294161e26
build(deps): bump the actions group with 2 updates (#4570)
Bumps the actions group with 2 updates: [actions/checkout](https://github.com/actions/checkout) and [github/codeql-action](https://github.com/github/codeql-action).


Updates `actions/checkout` from 4 to 6
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

Updates `github/codeql-action` from 3 to 4
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/github/codeql-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: github/codeql-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-25 02:44:22 -07:00
Pete Kloehn
efedbe9740
Feature/add dependabot and codeql security checks (#4479)
* Add CodeQL analysis workflow configuration

* Add Dependabot configuration for package updates

Configure Dependabot to check for updates in various ecosystems weekly.

* Fix dependabot.yml: bun ecosystem, missing dir, grouping for PR #4479

1. studio/frontend uses bun.lock not package-lock.json, so change npm to bun
2. Add missing studio/backend/requirements/ pip entry (consumed by studio/setup.sh)
3. Add groups with patterns ["*"] to all pip/bun/npm entries to batch updates
   and avoid 30+ individual Dependabot PRs on the first run

* Consolidate pip blocks to fix overlapping directory violation

GitHub Dependabot forbids multiple same-ecosystem entries with
overlapping directories on the same branch. The root "/" directory
overlapped the 3 nested pip dirs. Merge all 4 pip blocks into one
using the `directories:` (plural) key.

Also remove redundant open-pull-requests-limit from the bun block
since grouping with patterns: ["*"] already limits PR count.

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
2026-03-25 02:41:33 -07:00
Datta Nimmaturi
04359be333
[Studio] Try installing causal-conv1d from prebuilt wheels if avialable (#4547)
* Try installing causal-conv1d from prebuilt wheels if avialable

* Prefer installing mamba-ssm from wheel to speed up things

* undo python stack install changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "undo python stack install changes"

This reverts commit d943551092.

* add comments

* Fix wheel installer: model detection, platform tags, torch pin, error handling

- Add nemotron-h (hyphen) and granite-4.0-h / granitemoehybrid to model
  detection for both causal-conv1d and mamba-ssm. These hybrid Mamba models
  were silently skipped since nemotron_h (underscore) never matches real
  HF model IDs like nvidia/Nemotron-H-8B-Base, and granite was missing
  entirely despite being a supported model in model_config.py and loader.py.
- Fix _causal_conv1d_platform_tag to detect linux_aarch64 via
  platform.machine() instead of hardcoding linux_x86_64. Both upstream
  releases publish aarch64 wheels. Drop win_amd64 since neither repo
  publishes Windows wheels (avoids a wasted HTTP probe on every run).
- Pin torch to >=2.6.0,<2.11.0 instead of <=2.10.0 to add a version floor
  and document the wheel coverage range with upstream release links.
- Strip non-numeric suffixes from torch minor version so nightly builds
  like 2.7a0 correctly resolve to wheel tag torch2.7 instead of torch2.7a0.
- Use stderr=_sp.PIPE instead of stderr=_sp.STDOUT in the env probe so
  torch import warnings do not corrupt the JSON output.
- Add timeout=30 to the env probe subprocess to prevent indefinite hangs.
- Catch Exception (not just ImportError) on the existing-install check so
  ABI-broken installs with OSError/RuntimeError are retried rather than
  silently accepted.
- Guard uv invocation with shutil.which("uv") to prevent FileNotFoundError
  crash when uv is not on PATH. Wrap the top-level ensure calls in
  try/except so failures do not kill the training worker.
- Hoist _SSM_MODEL_SUBSTRINGS to module level.
- Remove redundant --torch-backend=auto flag from direct wheel URL install.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add LFM2 to causal-conv1d detection; stop training on install failure

- Add "lfm2" to _model_wants_causal_conv1d so Studio picks up the
  fast kernel path for Liquid Foundation Model 2.
- Replace silent logger.warning on SSM dependency install failure
  with an error event that tells the user to choose another model
  and stops the training job immediately.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Catch subprocess timeout in torch probe; narrow import guard to ImportError

- _probe_causal_conv1d_env: wrap subprocess.run in try/except for
  TimeoutExpired so a slow torch import returns None (falls back to
  PyPI) instead of killing the training job.
- _install_package_wheel_first: narrow except Exception to except
  ImportError on the __import__ check so unexpected errors from a
  broken module still propagate.

* Remove unconditional torch pin from install_python_stack

The torch>=2.6.0,<2.11.0 pin was added to ensure prebuilt
causal-conv1d / mamba-ssm wheels exist, but it runs at install
time for all users regardless of model choice. This can downgrade
or unnecessarily upgrade torch. The worker already handles wheel
compatibility at training time by probing the environment and
falling back to PyPI, so the install-time pin is not needed.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 02:22:26 -07:00
Wasim Yousef Said
926e74509d
feat(chat): cleaner tool UI, inline LaTeX, clickable links (#4561)
* feat(chat): ghost-style tool containers

Remove borders and card styling from tool call UI. ToolFallback
uses minimal padding with indented content. ToolGroup defaults
to ghost variant with subtle background for multi-tool grouping.

* feat(chat): compact web search source pills

Switch sources from vertical full-width badges to horizontal
wrapping pills with smaller icons.

* feat(chat): left-accent code and terminal tool UI

Replace bordered card layout with a left border accent for
Python and Terminal tool output. Add timer cleanup on unmount
for the copy button in both components.

* feat(chat): inline latex and clickable links

Enable single-dollar $...$ math rendering via createMathPlugin.
Add styled link component with target=_blank for external links.

* fix(chat): inline generating indicator, static tailwind classes, misc fixes

Move generating indicator from viewport footer into assistant
message using AnimatedShinyText shimmer. Only shows when message
content is empty, hides once tool calls or text appear.

Use static size class map in SourceIcon for Tailwind v4 compat.
Use unique keys for web search sources. Remove px-3 from ghost
tool group variant.

* fix(chat): only show generating indicator while message is running

Hide the shimmer when message is cancelled or errored with no
content, preventing stale loading UI on empty completed messages.

* fix: escape currency dollar signs in LaTeX math rendering and fix TS build error

- Add preprocessLaTeX() in lib/latex.ts to escape currency patterns ($5, $1,000, $5.99, $100K)
  before they reach the math parser, preventing false positives when singleDollarTextMath is enabled.
  Code blocks and already-escaped dollars are left untouched.
- Use preprocessLaTeX via useMemo in markdown-text.tsx so Streamdown receives clean input.
- Fix TS18048 in thread.tsx: message.status?.type (optional chaining) since status can be undefined.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 02:06:03 -07:00
Daniel Han
3998f67680
Bump Data Designer to 0.5.4 (removes litellm dependency) (#4569)
* Bump Data Designer to 0.5.4 (removes litellm dependency)

NVIDIA Data Designer v0.5.4 removes litellm entirely and replaces it
with native OpenAI and Anthropic adapters. This follows the litellm
supply chain incident where versions 1.82.7 and 1.82.8 were compromised
with a credential stealer.

Release notes: https://github.com/NVIDIA-NeMo/DataDesigner/releases/tag/v0.5.4

Changes:
- Bump data-designer, data-designer-config, data-designer-engine to 0.5.4
- Sync data-designer-deps.txt with 0.5.4 engine requirements:
  - Added: chardet, fsspec, mcp
  - Removed: python-json-logger, pymupdf, pymupdf4llm, mammoth
    (these remain in the unstructured-seed plugin which still needs them)
  - duckdb constraint relaxed from <1.5 to <2 (upstream fixed record_batch)
- Bump plugin lower bound to >=0.5.4

* Keep pymupdf, pymupdf4llm, mammoth in data-designer-deps

The unstructured-seed plugin is installed with --no-deps, so its
pyproject.toml dependencies are not auto-resolved. These three
packages are needed by the seed route (studio/backend/routes/
data_recipe/seed.py) and must remain in the explicit deps list.
2026-03-25 02:01:43 -07:00
Avaya Aggarwal
45d0a343b5
feat: Implement Q-GaLore optimizer and custom embedding learning rate… (#4511)
* feat: Implement Q-GaLore optimizer and custom embedding learning rate in the Unsloth trainer.

* feat: Implement QGaLoreAdamW8bit optimizer with 8-bit states, GaLore low-rank gradient projection, and optional INT8 weight quantization, along with supporting projector and tests.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: Introduce Q-GaLore AdamW optimizer with low-rank quantized gradient projection and integrate into the trainer, along with dedicated tests.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: Implement Q-GaLore AdamW optimizer with gradient projection and quantization, including trainer integration and corresponding tests.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix 3 bugs in Q-GaLore optimizer and add weight_quant forward hooks

1. Fix use-after-delete crash: move `del p._saved_data` after the
   weight decay block so decoupled weight decay can reference the
   current weights correctly (p.data).

2. Fix substring matching in make_q_galore_param_groups: split
   parameter names on "." and check exact component matches to
   prevent false positives (e.g. "not_q_proj" matching "q_proj").

3. Implement forward pre-hooks for weight_quant: after the optimizer
   quantizes weights to INT8, replace p.data with a 1-element
   placeholder to free float memory. A register_forward_pre_hook
   dequantizes back to float before each forward pass. The trainer
   calls install_weight_quant_hooks() when weight_quant is enabled.

4. Update test_weight_decay_uses_saved_data to match the fixed code
   path (decoupled decay uses p.data, expected value 2.7). Add
   test_weight_quant_hook_restores_float to verify the INT8-to-float
   hook round-trip.

All 24/24 Q-GaLore tests pass. Benchmarked on Llama-3.2-1B-Instruct
FFT: Q-GaLore saves 32% VRAM (10.63 -> 7.24 GB) with better loss
convergence (1.3 vs 2.0 at step 100). No regressions in 31-notebook
sweep across Llama, Qwen, Mistral, Phi, Gemma, vision, and GRPO.

* Default weight_quant to False in QGaloreConfig

Benchmarks show weight_quant=True adds ~1 GB on Llama-3.2-1B due to
INT8 copy/scale overhead exceeding savings from the placeholder trick.
Users can still opt in explicitly. The optimizer logic is unchanged.

* Optimize Q-GaLore projector and optimizer step performance

Projector (q_galore_projector.py):
- Use torch.svd_lowrank with oversampling p=10 (Halko et al. 2009) instead
  of full SVD for large matrices. Falls back to full SVD when min(m,n) <= 2*rank.
  SVD steps are 6-8x faster on Llama-3.2-1B (22s -> 3s for first step).
- Cache the dequantized ortho matrix between project() and project_back() to
  avoid redundant dequantization when quant=True.
- Replace F.cosine_similarity with torch.dot for 1-D unit vectors in the
  adaptive schedule. Remove unused torch.nn.functional import.
- Use collections.deque(maxlen=queue_size) instead of list with manual pop(0).

Optimizer (q_galore_adamw.py):
- Remove redundant .clone() on dequantized weights (line 151) and on float
  data before re-quantization (line 211). _dequantize already returns a fresh
  tensor and _quantize/_quantize_stochastic only reads its input.
- Consolidate per-group torch.cuda.synchronize() into a single call after
  all param groups complete.
- Use torch.empty instead of torch.zeros for the scalar placeholder tensor
  that is never read.

Verified: 24/24 unit tests pass. Llama-3.2-1B 61-step training produces
losses within 0.24% relative diff (correlation >0.9999) of the original.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 01:03:10 -07:00
Krishna Chaitanya
11606c5025
fix: remove auto wandb.finish() after train() to allow post-training evaluate() (#4564)
* fix: remove auto wandb.finish() after train() to allow post-training evaluate()

The prepare_for_training_mode wrapper unconditionally called wandb.finish()
after trainer.train() completed. This terminated the active W&B run, causing
trainer.evaluate() to fail with "You must call wandb.init() before wandb.log()".

Users who need multiple training runs in one session can call wandb.finish()
manually between runs to avoid data overwriting.

Fixes #3954

* fix: defer wandb.finish() to next train() call instead of removing it

Instead of calling wandb.finish() at the end of train() (which breaks
evaluate/log) or removing it entirely (which causes data overwriting on
multiple train() calls), defer it to the start of the next train() call.

This way:
- train() + evaluate() works (run stays open after train)
- train() + train() gets separate W&B runs (previous run finished first)
- train() + evaluate() + train() also works correctly

Also resets HF's WandbCallback._initialized flag so it re-calls
wandb.init() for the new run.

Fixes #3954

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 01:00:12 -07:00
Wasim Yousef Said
208862218d
feat(studio): training history persistence and past runs viewer (#4501)
* feat(db): add SQLite storage layer for training history

* feat(api): add training history endpoints and response models

* feat(training): integrate DB persistence into training event loop

* feat(ui): add training history views and card grid

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): address review issues in training history persistence

- Strip hf_token/wandb_token from config before SQLite storage
- Add UUID suffix to job_id for collision resistance
- Use isfinite() for 0.0 metric handling throughout
- Respect _should_stop in error event finalization
- Run schema DDL once per process, not per connection
- Close connection on schema init failure
- Guard cleanup_orphaned_runs at startup
- Cap _metric_buffer at 500 entries
- Make FLUSH_THRESHOLD a class constant
- Map 'running' to 'training' phase in historical view
- Derive LR/GradNorm from history arrays in historical view
- Fix nested button with div[role=button] in history cards
- Guard String(value) against null/undefined in config popover
- Clear selectedHistoryRunId on auto tab switch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): address round-2 review findings across training backend and frontend

Backend (training.py):
- Move state mutation after proc.start() so a failed spawn does not wedge
  the backend with is_training=True
- Create DB run row eagerly after proc.start() so runs appear in history
  during model loading, not after first metric event
- Rewrite _flush_metrics_to_db() with snapshot-before-insert pattern to
  preserve metrics arriving during the write and retain buffer on failure
- Guard eval_loss with float() coercion and math.isfinite(), matching the
  existing grad_norm guard
- Increase pump thread join timeout from 3s to 8s to cover SQLite's
  default 5s lock timeout

Frontend (studio-page.tsx):
- Fix history navigation: check isTrainingRunning instead of
  showTrainingView in onSelectRun so completed runs are not misrouted
- Replace activeTab state + auto-switch useEffect with derived tab to
  eliminate react-hooks/set-state-in-effect lint violation

Frontend (historical-training-view.tsx):
- Add explicit "running" branch to message ternary so running runs no
  longer fall through to "Training errored"
- Derive loading from detail/error state and move cleanup to effect
  return to eliminate react-hooks/set-state-in-effect lint violation

Frontend (progress-section.tsx):
- Derive stopRequested from isTrainingRunning && stopRequestedLocal to
  eliminate react-hooks/set-state-in-effect lint violation and remove
  unused useEffect import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(studio): resolve 3 remaining bugs from round-2 review

1. Stuck on Current Run tab [12/20]: Only force "current-run" tab when
   isTrainingRunning is true, not when stale completed-run data exists.
   After training ends, users can freely navigate to Configure.

2. Incomplete metric sanitization [7/20]: Apply float() coercion and
   isfinite() guards to loss and learning_rate, matching the existing
   pattern used by grad_norm and eval_loss. Prevents TypeError from
   string values and NaN leaks into history arrays.

3. Stop button state leak across runs [10/20]: Add key={runtime.jobId}
   to ProgressSection so React remounts it when a new run starts,
   resetting stopRequestedLocal state.

* fix(studio): deduplicate loss/lr sanitization in training event handler

Reuse _safe_loss/_safe_lr from the progress update block instead of
re-sanitizing the same raw event values for metric history.

* fix(studio): restore loss > 0 guard to prevent eval steps injecting 0.0 into metric histories

Round-2/3 fixes relaxed the history append guard from `loss > 0` to
`loss is not None`, which let eval-only log events (where loss defaults
to 0.0) append fake zeros into loss_history and lr_history. Restore the
`loss > 0` check to match the worker's own has_train_loss gate. The
float() coercion and isfinite() sanitization from round-3 remain intact.

* fix(studio): resolve training history bugs — nullable loss/lr, tab nav, sparkline

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-25 00:58:55 -07:00
Daniel Han
3108750bb0
Remove duplicate frontend assets from wheel to reduce package size (#4567)
The wheel currently ships frontend/public/, frontend/src/, and
frontend/*.lock alongside frontend/dist/. These are build-time inputs
that Vite already copies into dist/ during the build step:

- public/ is copied verbatim into dist/ by vite build (28.6 MB duplicate)
- src/ is TSX source compiled into dist/assets/*.js (2.1 MB, not used at runtime)
- *.lock files are package manager lockfiles (0.9 MB, not used at runtime)

The backend only serves from frontend/dist/ (see main.py setup_frontend
and run.py frontend_path). Nothing references public/ or src/ at runtime.

This drops the wheel from ~62.7 MB to ~31 MB.
2026-03-24 23:48:49 -07:00
Lee Jackson
557743f027
studio: windows desktop shortcut launcher (#4558)
* feat(windows): add Studio desktop/Start shortcuts with health-check launcher

* chore(windows): bundle sloth.ico and set shortcut icons when valid

* chore(windows):add images/sloth.ico

* fix(windows): guard PSScriptRoot for Studio shortcut icon in iex installs

* fix(install): high-DPI sloth.ico and relocate to studio/frontend/publi

* chore(studio): update sloth.ico for clearer desktop and shell icons

* chore(studio): use unsloth.ico for Studio shortcut icon

* feat(windows): improve Studio shortcut launcher (fast health + browser UX)

* fix(windows): stable unsloth.ico URL and Unicode-safe Studio launcher scripts

* fix(windows): escape $ in exe path and write launcher UTF-8 with BOM

* fix(windows): skip shortcuts when Desktop or APPDATA paths are missing

* fix(install): log shortcut/icon/port failures and warn early on missing paths

* fix(install): guard missing LOCALAPPDATA before shortcut paths

* fix(install): harden New-StudioShortcuts and improve success messaging

* fix(install): include port 8908 in studio health check

* fix(install): fix launch-studio.ps1  quoting

* Fix launcher edge cases and normalize indentation in install.ps1

- Handle silent timeout: show a message when Studio is still starting
  but did not become healthy within the timeout, instead of exiting
  with no feedback
- Add -NoProfile to the visible PowerShell terminal launch so the
  user profile cannot hang or error before Studio runs
- Add a named mutex (Local\UnslothStudioLauncher) to prevent
  double-click from spawning duplicate terminals; second instance
  polls for health and opens the browser when ready
- Normalize indentation inside New-StudioShortcuts outer try block
  from mixed 8/12-space to consistent 12-space

* Simplify Get-CandidatePorts port dedup with Sort-Object -Unique

Replace the foreach/-notcontains loop with a single pipeline:
  $ports = (@($basePort) + $listening) | Sort-Object -Unique

* Harden health probe and handle abandoned mutex in launcher

- Test-StudioHealth now checks resp.service == 'Unsloth UI Backend' to
  avoid fingerprinting collisions with other local services on the same
  port range.
- Wrap the mutex WaitOne(0) call in a try/catch for
  AbandonedMutexException so the launcher recovers gracefully when a
  previous instance was killed while holding the mutex.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 23:41:02 -07:00
Krishna Chaitanya
9b989ee898
fix: prevent UnicodeEncodeError on Windows CP1252 consoles in studio setup (#4563)
* fix: prevent UnicodeEncodeError on Windows CP1252 consoles in studio setup

On Windows, `unsloth studio setup` crashes with a UnicodeEncodeError
when install_python_stack.py tries to print Unicode status glyphs
(, , ⚠️) to a console that uses a legacy code page like CP1252.

Add a _safe_print() helper that catches UnicodeEncodeError and
gracefully degrades emoji to ASCII equivalents ([OK], [FAIL], [!]).
Replace all print() calls that emit Unicode glyphs with _safe_print().

Fixes #4509

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace Unicode dashes with ASCII in install_python_stack.py

Box-drawing (U+2500) and em dash (U+2014) chars in section dividers
and comments are themselves not representable on CP1252 -- replace
with plain ASCII dashes for consistency with the fix.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 22:04:09 -07:00
TR-3B
8c94b461fb
Add GRPO resume vLLM cleanup guard (#4411)
* Add GRPO resume vLLM cleanup guard

* Guard GRPO resume sleep on vLLM sleep mode

* Harden GRPO resume vLLM cleanup guard

- Wrap llm.sleep(1) in try/except so a failed sleep does not block
  training resume (best-effort cleanup)
- Also check kwargs["model_path"] which transformers.Trainer.train()
  still accepts and normalizes to resume_from_checkpoint internally

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 21:37:45 -07:00
Wasim Yousef Said
085f9529b6
Regroup chat settings sidebar into focused sections (#4551)
* feat(chat): regroup settings sidebar into Model, Sampling, Tools, and Preferences sections

Split the monolithic Settings collapsible into focused sections with
icons. Model section shows context length and KV cache dtype for GGUF
models, trust remote code for non GGUF. Tools section groups auto heal,
max tool calls, and tool call timeout. Preferences section holds auto
title toggle.

* feat(chat): persist collapsible section open/closed state in localStorage

Remember which sections the user expanded or collapsed across sidebar
toggles, mobile sheet reopens, and browser sessions.

* fix(chat): harden collapsible state persistence and restore defaultOpen

- Validate localStorage values are booleans before using them, preventing
  corrupted entries like string "false" from being treated as truthy
- Use Object.hasOwn() instead of `in` operator to avoid prototype chain
  matches on keys like "constructor" or "toString"
- Restore defaultOpen={true} on Model and Preferences sections so they
  are expanded on first visit, matching the old Settings section behavior
- Fix misleading Context Length description to reflect it is read-only
- Downgrade console.error to console.warn for non-critical localStorage
  parse failures

* fix(chat): remove redundant disabled styles on Context Length input

The Input component already applies opacity-50 and cursor-not-allowed
via its disabled: variants. Specifying them unconditionally in the
className is redundant.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 19:39:27 -07:00
Daniel Han
acc881452f
fix: pin unsloth>=2026.3.11 in install.sh and install.ps1 (#4556)
Ensures both install scripts always pull a version that has the
litellm removal fix. Without the pin, stale uv/pip caches could
resolve the older 2026.3.10 which still had litellm in
data-designer-deps.txt, causing setup to fail at step 8/11
while PyPI has litellm quarantined.
2026-03-24 07:44:07 -07:00
Daniel Han
76a2f17470
fix(studio): remove litellm dep (quarantined on PyPI) (#4553)
litellm has been quarantined on PyPI due to a supply chain attack
in version 1.82.8 (malicious credential-stealing .pth file).
No versions are currently installable, which blocks
`unsloth studio setup` at step 8/11 (data-designer deps).

Remove litellm from the single-env data-designer requirements
so setup completes. litellm can be re-added once PyPI lifts the
quarantine.

Ref: https://github.com/BerriAI/litellm/issues/24512
2026-03-24 07:10:26 -07:00
Daniel Han
fac6f7887e Versioning 2026-03-24 06:50:36 -07:00
Daniel Han
95d2748278
fix: give @0xKushwaha git history credit for completion_only_loss fix (#4552)
* Revert "fix: handle prompt/completion datasets in slow-path BOS detection (#4548)"

This reverts commit fca83182af.

* fix: support completion_only_loss=True with prompt/completion dataset columns

When completion_only_loss=True, TRL rejects formatting_func but Unsloth's
patched _prepare_dataset/_prepare_non_packed_dataloader assumed either
formatting_func or dataset_text_field was always set, causing a catch-22.

Now handles prompt/completion columns as a third case for BOS token
detection, with a safe None fallback for all other cases.

(cherry picked from commit 978f78c6f1)

* fix: handle prompt/completion datasets in slow-path BOS detection

The slow-path check_text blocks in rl_replacements.py and
tokenizer_utils.py crash when a prompt/completion dataset is used
because they unconditionally access dataset[0][dataset_text_field]
even when the dataset does not have a text field.

This fixes both files to:
- Default dataset_text_field to None instead of raising when undefined
- Detect prompt/completion columns and concatenate them for BOS check
- Guard with isinstance(str) on both prompt and completion to handle
  conversational format (list of dicts) by setting test_text to None
- Add test_text is not None guard on has_bos_token_already to prevent
  AttributeError on NoneType.startswith()

This is the slow-path complement to unslothai/unsloth-zoo#560 which
fixes the fast-path in sft_prepare_dataset.

Closes #4486

(cherry picked from commit b6ce5786d0)

* fix: preserve chat_template BOS check when test_text is None

The has_bos_token_already guard wrapped both test_text.startswith()
and bos_token in chat_template with test_text is not None, which
disabled the chat_template BOS detection for conversational datasets
where test_text is set to None.

Split the guard so test_text is not None only applies to the
startswith() call, while bos_token in chat_template is always checked.

(cherry picked from commit 40bd8b8917)

---------

Co-authored-by: Ayush Kushwaha <148432773+ayushkushwaha240@users.noreply.github.com>
2026-03-24 06:38:57 -07:00
Daniel Han
fca83182af
fix: handle prompt/completion datasets in slow-path BOS detection (#4548)
* fix: handle prompt/completion datasets in slow-path BOS detection

The slow-path check_text blocks in rl_replacements.py and
tokenizer_utils.py crash when a prompt/completion dataset is used
because they unconditionally access dataset[0][dataset_text_field]
even when the dataset does not have a text field.

This fixes both files to:
- Default dataset_text_field to None instead of raising when undefined
- Detect prompt/completion columns and concatenate them for BOS check
- Guard with isinstance(str) on both prompt and completion to handle
  conversational format (list of dicts) by setting test_text to None
- Add test_text is not None guard on has_bos_token_already to prevent
  AttributeError on NoneType.startswith()

This is the slow-path complement to unslothai/unsloth-zoo#560 which
fixes the fast-path in sft_prepare_dataset.

Closes #4486

* fix: preserve chat_template BOS check when test_text is None

The has_bos_token_already guard wrapped both test_text.startswith()
and bos_token in chat_template with test_text is not None, which
disabled the chat_template BOS detection for conversational datasets
where test_text is set to None.

Split the guard so test_text is not None only applies to the
startswith() call, while bos_token in chat_template is always checked.
2026-03-24 05:27:59 -07:00
Michael Han
a41dbb6ab2
Add r/unsloth Reddit.md 2026-03-24 04:13:38 -07:00
Michael Han
381f509695
Adding Qwen3.5 RL.md 2026-03-24 04:06:23 -07:00
Wasim Yousef Said
c8057d911b
fix: system prompt ignored in unsloth inference (#4528)
* fix: system prompt was dropped in unsloth text and vision inference

* refactor: simplify system prompt message construction

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: use multimodal typed content parts for vision system message and add fallback

The system message content must use typed content parts
([{"type": "text", "text": ...}]) instead of a plain string to match
the multimodal processor contract (consistent with the audio path).
Plain strings cause some processors (e.g. LLaVA) to silently drop the
system prompt.

Also wraps processor.apply_chat_template in try/except so models that
reject the system role gracefully fall back to no system message with
a warning log.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: capture and log original exception in vision system prompt fallback

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 04:01:33 -07:00
Wasim Yousef Said
3dc212e218
fix: always show chat tool icons (#4525)
* fix: always show chat tool icons, gray out when model doesn't support them

Tool icons (Think, Search, Code) were hidden unless a model was loaded
and supported those features. Now they're always visible so users can
see and pre-select them. If a loaded model doesn't support a feature,
the button gets grayed out and disabled instead of being removed.

* refactor: centralize Qwen thinking params in store

* fix: disable tool buttons when no model is loaded

Change disabled condition from `modelLoaded && !supportsX` to
`!modelLoaded || !supportsX` so buttons are grayed out both when
no model is loaded and when the loaded model lacks the capability.

* Fix Qwen3 param clobbering and restore SuggestionItem capability guards

- Revert setReasoningEnabled() in the store to a pure boolean setter.
  Moving the Qwen3 param logic into it caused reconnect/load/refresh
  paths (which also call setReasoningEnabled) to silently overwrite
  user-customized or server-provided temperature/topP/topK/minP.

- Restore applyQwenThinkingParams() as a standalone function called
  only from explicit user toggle click handlers in thread.tsx and
  shared-composer.tsx, matching the pre-PR behavior.

- Re-add supportsReasoning/supportsTools guards in the SuggestionItem
  click handler so that clicking a suggestion card only activates
  tool toggles the loaded model actually supports.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-24 03:26:56 -07:00
Daniel Han
77b21333fb
fix(studio): restore scroll lock on reasoning panel collapse (#4545)
PR #4543 removed useScrollLock from ReasoningRoot, causing the thread
viewport to jump when a user collapses a reasoning panel. Restore the
hook to freeze scrollTop during the 200ms collapse animation, matching
the pattern used by tool-fallback.tsx and tool-group.tsx.
2026-03-24 02:27:06 -07:00
Wasim Yousef Said
1129ea44bc
fix(studio): show Windows-specific reset-password command on login error (#4529) 2026-03-23 23:04:00 -07:00
Daniel Han
5916bcb2e3
Fix Studio port conflict detection for loopback addresses (#4532)
* Fix port conflict detection when loopback address is held by another process

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use getaddrinfo for IPv6 host support, restore emojis in terminal output

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard against conn.pid being None in _get_pid_on_port

psutil.net_connections() can return entries with pid=None when the
current user lacks privileges to see the owning process (common on
macOS without root, Windows without admin, and some Linux configs).

psutil.Process(None) does not raise -- it silently returns the
current process, which would make the warning incorrectly blame
Unsloth Studio itself for blocking the port.

Skip entries with pid=None so the caller falls back to the generic
"port is already in use" message instead.

* Update studio/backend/run.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 22:34:47 -07:00
Lee Jackson
45e4a0473a
studio: stop scroll hijack during generation and fix thinking panel layout shift (#4543)
* fix(chat): stabilize thinking panel and thread scroll during generation

* fix: match ChatGPT scroll and thinking panel behavior

- Remove autoScroll={false} from thread viewport to restore default
  follow-scroll during streaming (pauses when user scrolls up, resumes
  at bottom)
- Rewrite reasoning panel state: auto-opens on stream start, user can
  close during streaming, auto-collapses when reasoning ends, user can
  re-expand after collapse

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-23 22:33:46 -07:00
Lee Jackson
01d7dce3f4
studio: persist system prompt and preset settings across navigation (#4538)
* fix(studio): harden system prompt persistence and storage fallback

* Exclude checkpoint from localStorage persistence for PR #4538

checkpoint is backend-owned state -- refresh() already syncs it from
getInferenceStatus() on every page load. Persisting it to localStorage
causes a stale model ID to survive across backend restarts, which
prevents auto-load from triggering when no model is actually loaded.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-03-23 22:21:04 -07:00
金黄色葡萄球君君
2b330e2f24
fix: store embedding_learning_rate on self in UnslothTrainingArguments (#4531)
Fixes #4492

The embedding_learning_rate parameter was assigned to a local variable
instead of self.embedding_learning_rate, causing UnslothTrainer.create_optimizer()
to always get None via getattr and silently fall back to a single param group.

Bug: embedding_learning_rate = embedding_learning_rate (no-op)
Fix: self.embedding_learning_rate = embedding_learning_rate
2026-03-23 21:08:29 -07:00
pre-commit-ci[bot]
a5be6904a6
[pre-commit.ci] pre-commit autoupdate (#4542)
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.6 → v0.15.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.6...v0.15.7)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-03-23 14:55:27 -07:00
Datta Nimmaturi
cd65584f19
Update issue template 2026-03-23 10:10:15 +05:30
359 changed files with 67867 additions and 8384 deletions

View file

@ -6,7 +6,7 @@ labels: bug
assignees: ''
---
Note: Please do not remove the questions. Answer beside them.
1. Did you update? `pip install --upgrade unsloth unsloth_zoo`
2. `Colab` or `Kaggle` or local / cloud
3. Number GPUs used, use `nvidia-smi`
@ -16,6 +16,7 @@ assignees: ''
```python
Put Minimal code to reproduce error here ###Remove Hugging Face token###
###Please make sure to check formatting properly, edit if needed.###
```
🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/

27
.github/dependabot.yml vendored Normal file
View file

@ -0,0 +1,27 @@
---
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
groups:
actions:
patterns: ["*"]
- package-ecosystem: "bun"
directory: "/studio/frontend"
schedule:
interval: "weekly"
groups:
bun-frontend:
patterns: ["*"]
- package-ecosystem: "npm"
directory: "/studio/backend/core/data_recipe/oxc-validator"
schedule:
interval: "weekly"
groups:
npm-oxc-validator:
patterns: ["*"]
...

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.6
rev: v0.15.10
hooks:
- id: ruff
args:

212
README.md
View file

@ -1,28 +1,43 @@
<h1 align="center" style="margin:0;">
<a href="https://unsloth.ai/docs"><picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20WHITE%20LOGO.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png">
<img alt="Unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/STUDIO%20BLACK%20LOGO.png" height="60" style="max-width:100%;">
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
<img alt="Unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="80" style="max-width:100%;">
</picture></a>
</h1>
<h3 align="center" style="margin: 0; margin-top: 0;">
Run and train AI models with a unified local interface.
Unsloth Studio lets you run and train models locally.
</h3>
<p align="center">
<a href="#-features">Features</a>
<a href="#-quickstart">Quickstart</a>
<a href="#-install">Quickstart</a>
<a href="#-free-notebooks">Notebooks</a>
<a href="https://unsloth.ai/docs">Documentation</a>
<a href="https://discord.com/invite/unsloth">Discord</a>
<a href="https://unsloth.ai/docs">Documentation</a>
</p>
<a href="https://unsloth.ai/docs/new/studio">
<img alt="unsloth studio ui homepage" src="https://raw.githubusercontent.com/unslothai/unsloth/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png" style="max-width: 100%; margin-bottom: 0;"></a>
<br>
<a href="https://unsloth.ai/docs/new/studio">
<img alt="unsloth studio ui homepage" src="https://github.com/user-attachments/assets/53ae17a9-d975-44ef-9686-efb4ebd0454d" style="max-width: 100%; margin-bottom: 0;"></a>
Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS.
## ⚡ Get started
#### macOS, Linux, WSL:
```bash
curl -fsSL https://unsloth.ai/install.sh | sh
```
#### Windows:
```powershell
irm https://unsloth.ai/install.ps1 | iex
```
#### Community:
- [Discord](https://discord.gg/unsloth)
- [𝕏 (Twitter)](https://x.com/UnslothAI)
- [Reddit](https://reddit.com/r/unsloth)
## ⭐ Features
Unsloth provides several key features for both inference and training:
Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS.
### Inference
* **Search + download + run models** including GGUF, LoRA adapters, safetensors
* **Export models**: [Save or export](https://unsloth.ai/docs/new/studio/export) models to GGUF, 16-bit safetensors and other formats.
@ -32,15 +47,15 @@ Unsloth provides several key features for both inference and training:
* We work directly with teams behind [gpt-oss](https://docs.unsloth.ai/new/gpt-oss-how-to-run-and-fine-tune#unsloth-fixes-for-gpt-oss), [Qwen3](https://www.reddit.com/r/LocalLLaMA/comments/1kaodxu/qwen3_unsloth_dynamic_ggufs_128k_context_bug_fixes/), [Llama 4](https://github.com/ggml-org/llama.cpp/pull/12889), [Mistral](models/tutorials/devstral-how-to-run-and-fine-tune.md), [Gemma 1-3](https://news.ycombinator.com/item?id=39671146), and [Phi-4](https://unsloth.ai/blog/phi4), where weve fixed bugs that improve model accuracy.
* Upload images, audio, PDFs, code, DOCX and more file types to chat with.
### Training
* Train **500+ models** up to **2x faster** with up to **70% less VRAM**, with no accuracy loss.
* Train and RL **500+ models** up to **2x faster** with up to **70% less VRAM**, with no accuracy loss.
* Custom Triton and mathematical **kernels**. See some collabs we did with [PyTorch](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) and [Hugging Face](https://unsloth.ai/docs/new/faster-moe).
* **Data Recipes**: [Auto-create datasets](https://unsloth.ai/docs/new/studio/data-recipe) from **PDF, CSV, DOCX** etc. Edit data in a visual-node workflow.
* Supports full fine-tuning, pretraining, 4-bit, 16-bit and, FP8 training.
* **[Reinforcement Learning](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide)** (RL): The most efficient [RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) library, using **80% less VRAM** for GRPO, [FP8](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) etc.
* Supports full fine-tuning, RL, pretraining, 4-bit, 16-bit and, FP8 training.
* **Observability**: Monitor training live, track loss and GPU usage and customize graphs.
* **Reinforcement Learning**: The most efficient [RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) library, using **80% less VRAM** for GRPO, [FP8](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) etc.
* [Multi-GPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) training is supported, with major improvements coming soon.
## ⚡ Quickstart
## 📥 Install
Unsloth can be used in two ways: through **[Unsloth Studio](https://unsloth.ai/docs/new/studio/)**, the web UI, or through **Unsloth Core**, the code-based version. Each has different requirements.
### Unsloth Studio (web UI)
@ -49,7 +64,7 @@ Unsloth Studio (Beta) works on **Windows, Linux, WSL** and **macOS**.
* **CPU:** Supported for Chat and Data Recipes currently
* **NVIDIA:** Training works on RTX 30/40/50, Blackwell, DGX Spark, Station and more
* **macOS:** Currently supports chat and Data Recipes. **MLX training** is coming very soon
* **AMD:** Chat works. Train with [Unsloth Core](#unsloth-core-code-based). Studio support is coming soon.
* **AMD:** Chat + Data works. Train with [Unsloth Core](#unsloth-core-code-based). Studio support is out soon.
* **Coming soon:** Training support for Apple MLX, AMD, and Intel.
* **Multi-GPU:** Available now, with a major upgrade on the way
@ -57,19 +72,20 @@ Unsloth Studio (Beta) works on **Windows, Linux, WSL** and **macOS**.
```bash
curl -fsSL https://unsloth.ai/install.sh | sh
```
If you don't have `curl`, use `wget`. Launch after setup via:
```bash
source unsloth_studio/bin/activate
unsloth studio -H 0.0.0.0 -p 8888
```
#### Windows:
```powershell
irm https://unsloth.ai/install.ps1 | iex
```
Launch after setup via:
```powershell
& .\unsloth_studio\Scripts\unsloth.exe studio -H 0.0.0.0 -p 8888
#### Launch
```bash
unsloth studio -H 0.0.0.0 -p 8888
```
#### Update
To update, use the same install commands as above. Or run (does not work on Windows):
```bash
unsloth studio update
```
#### Docker
@ -82,64 +98,8 @@ docker run -d -e JUPYTER_PASSWORD="mypassword" \
unsloth/unsloth
```
#### macOS, Linux, WSL developer installs:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv unsloth_studio --python 3.13
source unsloth_studio/bin/activate
uv pip install unsloth --torch-backend=auto
unsloth studio setup
unsloth studio -H 0.0.0.0 -p 8888
```
#### Windows PowerShell developer installs:
```powershell
winget install -e --id Python.Python.3.13
winget install --id=astral-sh.uv -e
uv venv unsloth_studio --python 3.13
.\unsloth_studio\Scripts\activate
uv pip install unsloth --torch-backend=auto
unsloth studio setup
unsloth studio -H 0.0.0.0 -p 8888
```
#### Nightly - MacOS, Linux, WSL:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
git clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio
cd unsloth_studio
uv venv --python 3.13
source .venv/bin/activate
uv pip install -e . --torch-backend=auto
unsloth studio setup
unsloth studio -H 0.0.0.0 -p 8888
```
Then to launch every time:
```bash
cd unsloth_studio
source .venv/bin/activate
unsloth studio -H 0.0.0.0 -p 8888
```
#### Nightly - Windows:
Run in Windows Powershell:
```bash
winget install -e --id Python.Python.3.13
winget install --id=astral-sh.uv -e
git clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio
cd unsloth_studio
uv venv --python 3.13
.\.venv\Scripts\activate
uv pip install -e . --torch-backend=auto
unsloth studio setup
unsloth studio -H 0.0.0.0 -p 8888
```
Then to launch every time:
```bash
cd unsloth_studio
.\.venv\Scripts\activate
unsloth studio -H 0.0.0.0 -p 8888
```
#### Developer, Nightly, Uninstall
To see developer, nightly and uninstallation etc. instructions, see [advanced installation](#-advanced-installation).
### Unsloth Core (code-based)
#### Linux, WSL:
@ -164,17 +124,19 @@ You can use the same Docker image as Unsloth Studio.
For RTX 50x, B200, 6000 GPUs: `uv pip install unsloth --torch-backend=auto`. Read our guides for: [Blackwell](https://unsloth.ai/docs/blog/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark](https://unsloth.ai/docs/blog/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth). <br>
To install Unsloth on **AMD** and **Intel** GPUs, follow our [AMD Guide](https://unsloth.ai/docs/get-started/install/amd) and [Intel Guide](https://unsloth.ai/docs/get-started/install/intel).
## Free Notebooks
## 📒 Free Notebooks
Train for free with our notebooks. Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model.
Train for free with our notebooks. You can use our new [free Unsloth Studio notebook](https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb) to run and train models for free in a web UI.
Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model.
| Model | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
| **Gemma 4 (E2B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma4_(E2B)-Vision.ipynb) | 1.5x faster | 50% less |
| **Qwen3.5 (4B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision.ipynb) | 1.5x faster | 60% less |
| **gpt-oss (20B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-Fine-tuning.ipynb) | 2x faster | 70% less |
| **Qwen3.5 GSPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision_GRPO.ipynb) | 2x faster | 70% less |
| **gpt-oss (20B): GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) | 2x faster | 80% less |
| **Qwen3: Advanced GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) | 2x faster | 50% less |
| **Gemma 3 (4B) Vision** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb) | 1.7x faster | 60% less |
| **Qwen3: Advanced GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) | 2x faster | 70% less |
| **embeddinggemma (300M)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/EmbeddingGemma_(300M).ipynb) | 2x faster | 20% less |
| **Mistral Ministral 3 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Ministral_3_VL_(3B)_Vision.ipynb) | 1.5x faster | 60% less |
| **Llama 3.1 (8B) Alpaca** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less |
@ -186,6 +148,8 @@ Train for free with our notebooks. Read our [guide](https://unsloth.ai/docs/get-
- See detailed documentation for Unsloth [here](https://unsloth.ai/docs)
## 🦥 Unsloth News
- **Qwen3.6**: Qwen3.6-35B-A3B can now be trained and run in Unsloth Studio. [Blog](https://unsloth.ai/docs/models/qwen3.6)
- **Gemma 4**: Run and train Googles new models directly in Unsloth. [Blog](https://unsloth.ai/docs/models/gemma-4)
- **Introducing Unsloth Studio**: our new web UI for running and training LLMs. [Blog](https://unsloth.ai/docs/new/studio)
- **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35-A3B, 112B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune)
- Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe)
@ -196,13 +160,83 @@ Train for free with our notebooks. Read our [guide](https://unsloth.ai/docs/get-
- **FP8 & Vision RL**: You can now do FP8 & VLM GRPO on consumer GPUs. [FP8 Blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/vision-reinforcement-learning-vlm-rl)
- **gpt-oss** by OpenAI: Read our [RL blog](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/gpt-oss-reinforcement-learning), [Flex Attention](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training) blog and [Guide](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune).
## 🔗 Links and Resources
## 📥 Advanced Installation
The below advanced instructions are for Unsloth Studio. For Unsloth Core advanced installation, [view our docs](https://unsloth.ai/docs/get-started/install/pip-install#advanced-pip-installation).
#### Developer installs: macOS, Linux, WSL:
```bash
git clone https://github.com/unslothai/unsloth
cd unsloth
./install.sh --local
unsloth studio -H 0.0.0.0 -p 8888
```
Then to update :
```bash
unsloth studio update
```
#### Developer installs: Windows PowerShell:
```powershell
git clone https://github.com/unslothai/unsloth.git
cd unsloth
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
.\install.ps1 --local
unsloth studio -H 0.0.0.0 -p 8888
```
Then to update :
```bash
unsloth studio update
```
#### Nightly: MacOS, Linux, WSL:
```bash
git clone https://github.com/unslothai/unsloth
cd unsloth
git checkout nightly
./install.sh --local
unsloth studio -H 0.0.0.0 -p 8888
```
Then to launch every time:
```bash
unsloth studio -H 0.0.0.0 -p 8888
```
#### Nightly: Windows:
Run in Windows Powershell:
```bash
git clone https://github.com/unslothai/unsloth.git
cd unsloth
git checkout nightly
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
.\install.ps1 --local
unsloth studio -H 0.0.0.0 -p 8888
```
Then to launch every time:
```bash
unsloth studio -H 0.0.0.0 -p 8888
```
#### Uninstall
You can uninstall Unsloth Studio by deleting its install folder usually located under `$HOME/.unsloth/studio` on Mac/Linux/WSL and `%USERPROFILE%\.unsloth\studio` on Windows. Using the `rm -rf` commands will **delete everything**, including your history, cache:
* **MacOS, WSL, Linux:** `rm -rf ~/.unsloth/studio`
* **Windows (PowerShell):** `Remove-Item -Recurse -Force "$HOME\.unsloth\studio"`
For more info, [see our docs](https://unsloth.ai/docs/new/studio/install#uninstall).
#### Deleting model files
You can delete old model files either from the bin icon in model search or by removing the relevant cached model folder from the default Hugging Face cache directory. By default, HF uses:
* **MacOS, Linux, WSL:** `~/.cache/huggingface/hub/`
* **Windows:** `%USERPROFILE%\.cache\huggingface\hub\`
## 💚 Community and Links
| Type | Links |
| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ |
| <img width="16" src="https://cdn.prod.website-files.com/6257adef93867e50d84d30e2/66e3d80db9971f10a9757c99_Symbol.svg" />  **Discord** | [Join Discord server](https://discord.com/invite/unsloth) |
| <img width="15" src="https://redditinc.com/hs-fs/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" />  **r/unsloth Reddit** | [Join Reddit community](https://reddit.com/r/unsloth) |
| 📚 **Documentation & Wiki** | [Read Our Docs](https://unsloth.ai/docs) |
| <img width="13" src="https://upload.wikimedia.org/wikipedia/commons/0/09/X_(formerly_Twitter)_logo_late_2025.svg" />  **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai) |
| 💾 **Installation** | [Pip & Docker Install](https://unsloth.ai/docs/get-started/install) |
| 🔮 **Our Models** | [Unsloth Catalog](https://unsloth.ai/docs/get-started/unsloth-model-catalog) |
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog) |

View file

@ -29,7 +29,22 @@ _restore_gitignores() {
}
trap _restore_gitignores EXIT
npm install
# Use bun for install if available (faster), fall back to npm.
_install_ok=false
if command -v bun &>/dev/null; then
if bun install; then
_install_ok=true
else
echo "⚠ bun install failed, falling back to npm"
rm -rf node_modules
fi
fi
if [ "$_install_ok" != "true" ]; then
if ! npm install; then
echo "❌ ERROR: package install failed" >&2
exit 1
fi
fi
npm run build # outputs to studio/frontend/dist/
_restore_gitignores

Binary file not shown.

Before

Width:  |  Height:  |  Size: 56 KiB

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

After

Width:  |  Height:  |  Size: 59 KiB

File diff suppressed because it is too large Load diff

1483
install.sh

File diff suppressed because it is too large Load diff

View file

@ -46,12 +46,9 @@ studio = [
"*.ps1",
"*.bat",
"frontend/dist/**/*",
"frontend/public/**/*",
"frontend/src/**/*",
"frontend/*.json",
"frontend/*.ts",
"frontend/*.js",
"frontend/*.lock",
"frontend/*.html",
"frontend/*.yaml",
"frontend/.git*",
@ -61,7 +58,8 @@ studio = [
]
[tool.setuptools.packages.find]
exclude = ["images*", "tests*", "kernels/moe*"]
include = ["unsloth*", "unsloth_cli*", "studio", "studio.backend*"]
exclude = ["images*", "tests*", "*.node_modules", "*.node_modules.*"]
[project.optional-dependencies]
triton = [
@ -84,13 +82,13 @@ huggingfacenotorch = [
"huggingface_hub>=0.34.0",
"hf_transfer",
"diffusers",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0",
"trl>=0.18.2,!=0.19.0,<=0.24.0",
"sentence-transformers",
]
huggingface = [
"unsloth[huggingfacenotorch]",
"unsloth_zoo>=2026.3.4",
"unsloth_zoo>=2026.4.8",
"torchvision",
"unsloth[triton]",
]
@ -580,10 +578,10 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3 ; ('linux' in sys_platform)",
]
colab-new = [
"unsloth_zoo>=2026.3.4",
"unsloth_zoo>=2026.4.8",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0",
"datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0",
"sentencepiece>=0.2.0",
"tqdm",

169
scripts/install_gemma4_mlx.sh Executable file
View file

@ -0,0 +1,169 @@
#!/bin/bash
set -e
# ============================================================
# Gemma 4 MLX — One-command setup + inference
#
# Usage:
# bash install_gemma4_mlx.sh [--venv-dir DIR]
#
# This script:
# 1. Creates a Python virtual environment
# 2. Installs uv, mlx-vlm, transformers
# ============================================================
# ── Output style (inspired by unsloth/install.sh) ─────────────
RULE=""
_rule_i=0
while [ "$_rule_i" -lt 52 ]; do
RULE="${RULE}"
_rule_i=$((_rule_i + 1))
done
if [ -n "${NO_COLOR:-}" ]; then
C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST=
elif [ -t 1 ] || [ -n "${FORCE_COLOR:-}" ]; then
_ESC="$(printf '\033')"
C_TITLE="${_ESC}[38;5;117m"
C_DIM="${_ESC}[38;5;245m"
C_OK="${_ESC}[38;5;108m"
C_WARN="${_ESC}[38;5;136m"
C_ERR="${_ESC}[91m"
C_RST="${_ESC}[0m"
else
C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST=
fi
step() { printf " ${C_DIM}%-18.18s${C_RST}${3:-$C_OK}%s${C_RST}\n" "$1" "$2"; }
substep() { printf " ${C_DIM}%-18s${2:-$C_DIM}%s${C_RST}\n" "" "$1"; }
fail() { step "error" "$1" "$C_ERR"; exit 1; }
# ── Parse flags ───────────────────────────────────────────────
VENV_DIR=""
_next_is_venv=false
for arg in "$@"; do
if [ "$_next_is_venv" = true ]; then
VENV_DIR="$arg"
_next_is_venv=false
continue
fi
case "$arg" in
--venv-dir) _next_is_venv=true ;;
esac
done
# Default venv location
if [ -z "$VENV_DIR" ]; then
VENV_DIR="$HOME/.unsloth/unsloth_gemma4_mlx"
fi
# ── Banner ────────────────────────────────────────────────────
echo ""
printf " ${C_TITLE}%s${C_RST}\n" "💎 Gemma 4 MLX Installer"
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""
# ── Platform check ────────────────────────────────────────────
if [ "$(uname)" != "Darwin" ]; then
fail "MLX requires macOS with Apple Silicon. Detected: $(uname)"
fi
_ARCH=$(uname -m)
if [ "$_ARCH" != "arm64" ]; then
step "warning" "Apple Silicon recommended (detected: $_ARCH)" "$C_WARN"
fi
step "platform" "macOS ($_ARCH)"
# ── Detect Python ─────────────────────────────────────────────
PYTHON=""
for _candidate in python3.12 python3.11 python3.13 python3; do
if command -v "$_candidate" >/dev/null 2>&1; then
PYTHON="$_candidate"
break
fi
done
if [ -z "$PYTHON" ]; then
fail "Python 3 not found. Install via: brew install python@3.12"
fi
_PY_VERSION=$("$PYTHON" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')")
step "python" "$PYTHON ($_PY_VERSION)"
# ── Create virtual environment ────────────────────────────────
if [ -x "$VENV_DIR/bin/python" ]; then
step "venv" "using existing environment"
substep "$VENV_DIR"
else
step "venv" "creating virtual environment"
substep "$VENV_DIR"
mkdir -p "$(dirname "$VENV_DIR")"
"$PYTHON" -m venv "$VENV_DIR"
fi
# ── Install uv ───────────────────────────────────────────────
if ! command -v uv >/dev/null 2>&1; then
step "uv" "installing uv package manager..."
_uv_tmp=$(mktemp)
curl -LsSf "https://astral.sh/uv/install.sh" -o "$_uv_tmp"
sh "$_uv_tmp" </dev/null >/dev/null 2>&1
rm -f "$_uv_tmp"
if [ -f "$HOME/.local/bin/env" ]; then
. "$HOME/.local/bin/env"
fi
export PATH="$HOME/.local/bin:$PATH"
substep "done"
else
step "uv" "found $(uv --version 2>/dev/null || echo 'uv')"
fi
_VENV_PY="$VENV_DIR/bin/python"
# ── Install dependencies ──────────────────────────────────────
step "install" "installing mlx-vlm..."
uv pip install --python "$_VENV_PY" -q mlx-vlm
substep "done"
step "install" "installing transformers>=5.5.0..."
if uv pip install --python "$_VENV_PY" -q "transformers>=5.5.0" 2>/dev/null; then
substep "installed from PyPI"
else
substep "PyPI install failed (Python <3.10?), trying GitHub..."
if uv pip install --python "$_VENV_PY" -q "git+https://github.com/huggingface/transformers.git@v5.5-release" 2>/dev/null; then
substep "installed from huggingface/transformers v5.5-release"
else
step "warning" "could not install transformers>=5.5.0" "$C_WARN"
substep "tried: PyPI, huggingface/transformers v5.5-release"
fi
fi
# ── Verify installation ──────────────────────────────────────
if "$_VENV_PY" -c "import mlx_vlm"; then
substep "mlx-vlm verified"
else
fail "Installation verification failed."
fi
# ── Done ──────────────────────────────────────────────────────
echo ""
printf " ${C_TITLE}%s${C_RST}\n" "Gemma 4 MLX installed!"
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""
step "available models" "unsloth/gemma-4-E2B-it-UD-MLX-4bit"
substep "unsloth/gemma-4-E4B-it-UD-MLX-4bit"
substep "unsloth/gemma-4-26b-a4b-it-UD-MLX-4bit"
substep "unsloth/gemma-4-31b-it-UD-MLX-4bit"
echo ""
step "venv activate" "source ${VENV_DIR}/bin/activate"
echo ""
step "text chat" "python -m mlx_vlm.chat --model unsloth/gemma-4-E2B-it-UD-MLX-4bit"
echo ""
step "vision chat" "python -m mlx_vlm.chat --model unsloth/gemma-4-31b-it-UD-MLX-4bit"
substep "Use /image path/to/image.jpg to load an image"
echo ""
step "gradio UI" "python -m mlx_vlm.chat_ui --model unsloth/gemma-4-31b-it-UD-MLX-4bit"
echo ""
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""

View file

@ -0,0 +1,191 @@
#!/bin/bash
set -e
# ============================================================
# Qwen3.6 MLX — One-command setup + inference
#
# Usage:
# bash install_qwen3_6_mlx.sh [--venv-dir DIR]
#
# This script:
# 1. Creates a Python virtual environment
# 2. Installs uv, mlx-vlm, transformers, torch, torchvision
# ============================================================
# ── Output style (inspired by unsloth/install.sh) ─────────────
RULE=""
_rule_i=0
while [ "$_rule_i" -lt 52 ]; do
RULE="${RULE}"
_rule_i=$((_rule_i + 1))
done
if [ -n "${NO_COLOR:-}" ]; then
C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST=
elif [ -t 1 ] || [ -n "${FORCE_COLOR:-}" ]; then
_ESC="$(printf '\033')"
C_TITLE="${_ESC}[38;5;117m"
C_DIM="${_ESC}[38;5;245m"
C_OK="${_ESC}[38;5;108m"
C_WARN="${_ESC}[38;5;136m"
C_ERR="${_ESC}[91m"
C_RST="${_ESC}[0m"
else
C_TITLE= C_DIM= C_OK= C_WARN= C_ERR= C_RST=
fi
step() { printf " ${C_DIM}%-18.18s${C_RST}${3:-$C_OK}%s${C_RST}\n" "$1" "$2"; }
substep() { printf " ${C_DIM}%-18s${2:-$C_DIM}%s${C_RST}\n" "" "$1"; }
fail() { step "error" "$1" "$C_ERR"; exit 1; }
# ── Parse flags ───────────────────────────────────────────────
VENV_DIR=""
_next_is_venv=false
for arg in "$@"; do
if [ "$_next_is_venv" = true ]; then
VENV_DIR="$arg"
_next_is_venv=false
continue
fi
case "$arg" in
--venv-dir) _next_is_venv=true ;;
esac
done
# Default venv location
if [ -z "$VENV_DIR" ]; then
VENV_DIR="$HOME/.unsloth/unsloth_qwen3_6_mlx"
fi
# ── Banner ────────────────────────────────────────────────────
echo ""
printf " ${C_TITLE}%s${C_RST}\n" "Qwen3.6 MLX Installer"
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""
# ── Platform check ────────────────────────────────────────────
if [ "$(uname)" != "Darwin" ]; then
fail "MLX requires macOS with Apple Silicon. Detected: $(uname)"
fi
_ARCH=$(uname -m)
if [ "$_ARCH" != "arm64" ]; then
step "warning" "Apple Silicon recommended (detected: $_ARCH)" "$C_WARN"
fi
step "platform" "macOS ($_ARCH)"
# ── Detect Python ─────────────────────────────────────────────
PYTHON=""
for _candidate in python3.12 python3.11 python3.13 python3; do
if command -v "$_candidate" >/dev/null 2>&1; then
PYTHON="$_candidate"
break
fi
done
if [ -z "$PYTHON" ]; then
fail "Python 3 not found. Install via: brew install python@3.12"
fi
_PY_VERSION=$("$PYTHON" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')")
step "python" "$PYTHON ($_PY_VERSION)"
# ── Create virtual environment ────────────────────────────────
if [ -x "$VENV_DIR/bin/python" ]; then
step "venv" "using existing environment"
substep "$VENV_DIR"
else
step "venv" "creating virtual environment"
substep "$VENV_DIR"
mkdir -p "$(dirname "$VENV_DIR")"
"$PYTHON" -m venv "$VENV_DIR"
fi
# ── Install uv ───────────────────────────────────────────────
if ! command -v uv >/dev/null 2>&1; then
step "uv" "installing uv package manager..."
_uv_tmp=$(mktemp)
curl -LsSf "https://astral.sh/uv/install.sh" -o "$_uv_tmp"
sh "$_uv_tmp" </dev/null
rm -f "$_uv_tmp"
if [ -f "$HOME/.local/bin/env" ]; then
. "$HOME/.local/bin/env"
fi
export PATH="$HOME/.local/bin:$PATH"
substep "done"
else
step "uv" "found $(uv --version 2>/dev/null || echo 'uv')"
fi
_VENV_PY="$VENV_DIR/bin/python"
# ── Install dependencies ──────────────────────────────────────
step "install" "installing mlx-vlm..."
uv pip install --python "$_VENV_PY" -q mlx-vlm
substep "done"
step "install" "installing transformers>=5.2.0..."
if uv pip install --python "$_VENV_PY" -q "transformers>=5.2.0"; then
substep "installed from PyPI"
else
substep "PyPI install failed, trying GitHub..."
if uv pip install --python "$_VENV_PY" -q "git+https://github.com/huggingface/transformers.git"; then
substep "installed from huggingface/transformers main"
else
fail "Could not install transformers>=5.2.0 (required for Qwen3.5/3.6 model support). Please check your Python version (>=3.10 required) and network connection, then try again."
fi
fi
step "install" "installing torch + torchvision (needed for Qwen3 VL processor)..."
uv pip install --python "$_VENV_PY" -q torch torchvision
substep "done"
# ── Verify installation ──────────────────────────────────────
if "$_VENV_PY" -c "import mlx_vlm; import torch; import torchvision; import transformers"; then
substep "mlx-vlm + torch + transformers verified"
else
fail "Installation verification failed. Please ensure Python >=3.10 and try again."
fi
# ── Apply patches for multi-turn image chat ──────────────────
_PATCH_BASE="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/fix/ui-fix/unsloth/models/patches/mlx_vlm_qwen3_5"
_SITE_PKGS=$("$_VENV_PY" -c "import site; print(site.getsitepackages()[0])")
step "patch" "fixing multi-turn image chat..."
if curl -sSLf "${_PATCH_BASE}/qwen3_5.py" -o "${_SITE_PKGS}/mlx_vlm/models/qwen3_5/qwen3_5.py"; then
substep "patched qwen3_5.py (MRoPE position reset)"
else
step "warning" "failed to download qwen3_5.py patch — multi-turn image chat may not work" "$C_WARN"
fi
if curl -sSLf "${_PATCH_BASE}/generate.py" -o "${_SITE_PKGS}/mlx_vlm/generate.py"; then
substep "patched generate.py (mask trim on cache reuse)"
else
step "warning" "failed to download generate.py patch — multi-turn image chat may not work" "$C_WARN"
fi
# Clear pycache so patches take effect
find "${_SITE_PKGS}/mlx_vlm" -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null || true
substep "cleared bytecode cache"
# ── Done ──────────────────────────────────────────────────────
echo ""
printf " ${C_TITLE}%s${C_RST}\n" "Qwen3.6 MLX installed!"
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""
step "available models" "unsloth/Qwen3.6-35B-A3B-UD-MLX-3bit"
substep "unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit"
substep "unsloth/Qwen3.6-35B-A3B-MLX-8bit"
echo ""
step "venv activate" "source ${VENV_DIR}/bin/activate"
echo ""
step "vision chat" "python -m mlx_vlm.chat --model unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit"
substep "Use /image path/to/image.jpg to load an image"
echo ""
step "gradio UI" "python -m mlx_vlm.chat_ui --model unsloth/Qwen3.6-35B-A3B-UD-MLX-4bit"
echo ""
printf " ${C_DIM}%s${C_RST}\n" "$RULE"
echo ""

View file

@ -1,157 +1,153 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "6b87de59",
"metadata": {
"id": "6b87de59"
},
"source": [
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
"<div class=\"align-center\">\n",
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"<a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
"</div>\n",
"\n",
"To install Unsloth Studio on your local device, follow [our guide](https://unsloth.ai/docs/new/unsloth-studio/install). Unsloth Studio is licensed [AGPL-3.0](https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0).\n",
"\n",
"### Unsloth Studio\n",
"\n",
"Train and run open models with [**Unsloth Studio**](https://unsloth.ai/docs/new/unsloth-studio/start). Currently, installation may take 30+ mins so use a newer GPU.\n",
"\n",
"\n",
"We are actively working on making Unsloth Studio install on Colab T4 GPUs faster.\n",
"\n",
"[Features](https://unsloth.ai/docs/new/unsloth-studio#features) • [Quickstart](https://unsloth.ai/docs/new/unsloth-studio/start) • [Data Recipes](https://unsloth.ai/docs/new/unsloth-studio/data-recipe) • [Studio Chat](https://unsloth.ai/docs/new/unsloth-studio/chat) • [Export](https://unsloth.ai/docs/new/unsloth-studio/export)"
]
},
{
"cell_type": "markdown",
"id": "e4206349",
"metadata": {
"id": "e4206349"
},
"source": [
"<p align=\"left\"><img src=\"https://github.com/unslothai/unsloth/raw/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png\" width=\"600\"></p>"
]
},
{
"cell_type": "markdown",
"id": "27da2957",
"metadata": {
"id": "27da2957"
},
"source": [
"### Setup: Clone repo and run setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e68f91",
"metadata": {
"id": "27e68f91"
},
"outputs": [],
"source": [
"!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\n",
"%cd /content/unsloth\n",
"!chmod +x studio/setup.sh && ./studio/setup.sh"
]
},
{
"cell_type": "markdown",
"id": "3e1771a9",
"metadata": {
"id": "3e1771a9"
},
"source": [
"### Start Unsloth Studio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "277e431e",
"metadata": {
"id": "277e431e"
},
"outputs": [],
"source": [
"import sys, time\n",
"sys.path.insert(0, \"/content/unsloth/studio/backend\")\n",
"from colab import start\n",
"start()"
]
},
{
"cell_type": "code",
"source": [
"from google.colab import output\n",
"output.serve_kernel_port_as_iframe(8888, height = 1200, width = \"100%\")\n",
"for _ in range(10000): time.sleep(300), print(\"=\", end = \"\")"
],
"metadata": {
"id": "wb9UELh--XzX"
},
"id": "wb9UELh--XzX",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "f2b0c6a1",
"metadata": {
"id": "f2b0c6a1"
},
"source": [
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
"\n",
"Some other resources:\n",
"1. Looking to use Unsloth locally? Read our [Installation Guide](https://unsloth.ai/docs/get-started/install) for details on installing Unsloth on Windows, Docker, AMD, Intel GPUs.\n",
"2. Learn how to do Reinforcement Learning with our [RL Guide and notebooks](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide).\n",
"3. Read our guides and notebooks for [Text-to-speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning) and [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) model support.\n",
"4. Explore our [LLM Tutorials Directory](https://unsloth.ai/docs/models/tutorials-how-to-fine-tune-and-run-llms) to find dedicated guides for each model.\n",
"5. Need help with Inference? Read our [Inference & Deployment page](https://unsloth.ai/docs/basics/inference-and-deployment) for details on using vLLM, llama.cpp, Ollama etc.\n",
"\n",
"<div class=\"align-center\">\n",
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
" <a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
"\n",
" <b>This notebook is licensed <a href=\"https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0\">AGPL-3.0</a></b>\n",
"</div>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/unslothai/unsloth/blob/main/studio/Unsloth_Studio_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cell_type": "markdown",
"id": "6b87de59",
"metadata": {
"id": "6b87de59"
},
"source": [
"To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n",
"<div class=\"align-center\">\n",
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"<a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a> Join Discord if you need help + ⭐ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐\n",
"</div>\n",
"\n",
"To install Unsloth Studio on your local device, follow [our guide](https://unsloth.ai/docs/new/unsloth-studio/install). Unsloth Studio is licensed [AGPL-3.0](https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0).\n",
"\n",
"### Unsloth Studio\n",
"\n",
"Train and run open models with [**Unsloth Studio**](https://unsloth.ai/docs/new/unsloth-studio/start). NEW! Installation should now only take 2 mins!\n",
"\n",
"\n",
"We are actively working on making Unsloth Studio install on Colab T4 GPUs faster.\n",
"\n",
"[Features](https://unsloth.ai/docs/new/unsloth-studio#features) • [Quickstart](https://unsloth.ai/docs/new/unsloth-studio/start) • [Data Recipes](https://unsloth.ai/docs/new/unsloth-studio/data-recipe) • [Studio Chat](https://unsloth.ai/docs/new/unsloth-studio/chat) • [Export](https://unsloth.ai/docs/new/unsloth-studio/export)"
]
},
{
"cell_type": "markdown",
"id": "e4206349",
"metadata": {
"id": "e4206349"
},
"source": [
"<p align=\"left\"><img src=\"https://github.com/unslothai/unsloth/raw/main/studio/frontend/public/studio%20github%20landscape%20colab%20display.png\" width=\"600\"></p>"
]
},
{
"cell_type": "markdown",
"id": "27da2957",
"metadata": {
"id": "27da2957"
},
"source": [
"### Setup: Clone repo and run setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e68f91",
"metadata": {
"id": "27e68f91"
},
"outputs": [],
"source": "!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\n%cd /content/unsloth\n!chmod +x studio/setup.sh && ./studio/setup.sh"
},
{
"cell_type": "markdown",
"id": "3e1771a9",
"metadata": {
"id": "3e1771a9"
},
"source": [
"### Start Unsloth Studio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "277e431e",
"metadata": {
"id": "277e431e"
},
"outputs": [],
"source": [
"import sys, time\n",
"sys.path.insert(0, \"/content/unsloth/studio/backend\")\n",
"from colab import start\n",
"start()"
]
},
{
"cell_type": "code",
"source": [
"from google.colab import output\n",
"output.serve_kernel_port_as_iframe(8888, height = 1200, width = \"100%\")\n",
"for _ in range(10000): time.sleep(300), print(\"=\", end = \"\")"
],
"metadata": {
"id": "wb9UELh--XzX"
},
"id": "wb9UELh--XzX",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "f2b0c6a1",
"metadata": {
"id": "f2b0c6a1"
},
"source": [
"And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
"\n",
"Some other resources:\n",
"1. Looking to use Unsloth locally? Read our [Installation Guide](https://unsloth.ai/docs/get-started/install) for details on installing Unsloth on Windows, Docker, AMD, Intel GPUs.\n",
"2. Learn how to do Reinforcement Learning with our [RL Guide and notebooks](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide).\n",
"3. Read our guides and notebooks for [Text-to-speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning) and [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) model support.\n",
"4. Explore our [LLM Tutorials Directory](https://unsloth.ai/docs/models/tutorials-how-to-fine-tune-and-run-llms) to find dedicated guides for each model.\n",
"5. Need help with Inference? Read our [Inference & Deployment page](https://unsloth.ai/docs/basics/inference-and-deployment) for details on using vLLM, llama.cpp, Ollama etc.\n",
"\n",
"<div class=\"align-center\">\n",
" <a href=\"https://unsloth.ai\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\"></a>\n",
" <a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord.png\" width=\"145\"></a>\n",
" <a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
" Join Discord if you need help + ⭐️ <i>Star us on <a href=\"https://github.com/unslothai/unsloth\">Github</a> </i> ⭐️\n",
"\n",
" <b>This notebook is licensed <a href=\"https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0\">AGPL-3.0</a></b>\n",
"</div>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@ -10,13 +10,13 @@ training:
load_in_4bit: false
output_dir: outputs
num_epochs: 1
learning_rate: 0.0002
learning_rate: 2e-5
batch_size: 1
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -1,6 +1,14 @@
{
"_comment": "Per-model-family inference parameter defaults. Sources: (1) Ollama params blobs, (2) Existing Unsloth Studio YAML configs. Patterns ordered longest-match-first.",
"families": {
"qwen3.6": {
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"min_p": 0.0,
"repetition_penalty": 1.0,
"presence_penalty": 1.5
},
"qwen3.5": {
"temperature": 0.7,
"top_p": 0.8,
@ -93,6 +101,14 @@
"min_p": 0.0,
"repetition_penalty": 1.0
},
"gemma-4": {
"temperature": 1.0,
"top_p": 0.95,
"top_k": 64,
"min_p": 0.0,
"repetition_penalty": 1.0,
"presence_penalty": 0.0
},
"gemma-3n": {
"temperature": 1.0,
"top_p": 0.95,
@ -361,12 +377,12 @@
}
},
"patterns": [
"qwen3.5",
"qwen3.6", "qwen3.5",
"qwen3-coder", "qwen3-next", "qwen3-vl", "qwen3",
"qwen2.5-coder", "qwen2.5-vl", "qwen2.5-omni", "qwen2.5-math", "qwen2.5",
"qwen2-vl", "qwen2",
"qwq",
"gemma-3n", "gemma-3", "medgemma", "gemma-2",
"gemma-4", "gemma-3n", "gemma-3", "medgemma", "gemma-2",
"llama-4", "llama-3.3", "llama-3.2", "llama-3.1", "llama-3",
"phi-4", "phi-3",
"mistral-nemo", "mistral-small", "mistral-large", "magistral", "ministral",

View file

@ -16,7 +16,7 @@ training:
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -6,13 +6,13 @@ training:
max_seq_length: 2048
# num_epochs: 4
num_epochs: 0
learning_rate: 5e-5
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_ratio: 0.1
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@ -12,7 +12,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -11,7 +11,7 @@ training:
warmup_ratio: 0.03
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -13,7 +13,7 @@ training:
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@ -13,7 +13,7 @@ training:
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-26B-A4B-it
# Also applies to: google/gemma-4-26B-A4B-it, unsloth/gemma-4-26B-A4B-it-GGUF
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-26B-A4B (base/pretrained)
# Also applies to: google/gemma-4-26B-A4B
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-31B-it
# Also applies to: google/gemma-4-31B-it, unsloth/gemma-4-31B-it-GGUF
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-31B (base/pretrained)
# Also applies to: google/gemma-4-31B
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-E2B-it
# Also applies to: google/gemma-4-E2B-it, unsloth/gemma-4-E2B-it-GGUF
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-E2B (base/pretrained)
# Also applies to: google/gemma-4-E2B
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-E4B-it
# Also applies to: google/gemma-4-E4B-it, unsloth/gemma-4-E4B-it-GGUF
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -0,0 +1,47 @@
# Model defaults for unsloth/gemma-4-E4B (base/pretrained)
# Also applies to: google/gemma-4-E4B
training:
trust_remote_code: false
max_seq_length: 2048
num_epochs: 0
learning_rate: 2e-4
batch_size: 2
gradient_accumulation_steps: 4
warmup_steps: 5
max_steps: 30
save_steps: 30
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true
gradient_checkpointing: "unsloth"
optim: "adamw_8bit"
lr_scheduler_type: "linear"
lora:
lora_r: 8
lora_alpha: 8
lora_dropout: 0.0
target_modules:
- "all-linear"
use_rslora: false
use_loftq: false
finetune_vision_layers: true
finetune_language_layers: true
finetune_attention_modules: true
finetune_mlp_modules: true
logging:
enable_wandb: false
wandb_project: "llm-finetuning"
enable_tensorboard: false
tensorboard_dir: "runs"
log_frequency: 10
inference:
trust_remote_code: false
temperature: 1.0
top_p: 0.95
top_k: 64
min_p: 0.0

View file

@ -13,7 +13,7 @@ training:
warmup_steps: 0
max_steps: 30
save_steps: 30
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: true

View file

@ -16,7 +16,7 @@ training:
warmup_steps: 5
max_steps: 0
save_steps: 0
weight_decay: 0.01
weight_decay: 0.001
random_seed: 3407
packing: false
train_on_completions: false

View file

@ -10,10 +10,12 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import jwt
from .storage import (
API_KEY_PREFIX,
get_jwt_secret,
get_user_and_secret,
load_jwt_secret,
save_refresh_token,
validate_api_key,
verify_refresh_token,
)
@ -137,6 +139,18 @@ async def _get_current_subject(
...
"""
token = credentials.credentials
# --- API key path (sk-unsloth-...) ---
if token.startswith(API_KEY_PREFIX):
username = validate_api_key(token)
if username is None:
raise HTTPException(
status_code = status.HTTP_401_UNAUTHORIZED,
detail = "Invalid or expired API key",
)
return username
# --- JWT path ---
subject = _decode_subject_without_verification(token)
if subject is None:
raise HTTPException(

View file

@ -72,7 +72,22 @@ def clear_bootstrap_password() -> None:
def _hash_token(token: str) -> str:
"""SHA-256 hash helper used for refresh token storage."""
"""SHA-256 hash helper used for refresh token storage.
Plain SHA-256 is intentional here: refresh tokens are high-entropy
random strings from ``secrets.token_urlsafe(48)`` (384 bits of
entropy), so a slow KDF (Argon2 / bcrypt / PBKDF2) provides zero
additional security no attacker can brute-force 2^384 regardless
of hash speed while adding tens of ms of CPU to every refresh.
See the OWASP Password Storage Cheat Sheet on fast-vs-slow hashing
of high-entropy inputs.
API keys use the separate ``_pbkdf2_api_key`` helper below, which
runs PBKDF2-HMAC-SHA256 with a persistent server-side salt not
for cryptographic reasons (128-bit random tokens don't need slow
hashing), but because CodeQL's ``py/weak-sensitive-data-hashing``
query mislabels API keys as passwords and demands a KDF.
"""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
@ -103,6 +118,29 @@ def get_connection() -> sqlite3.Connection:
);
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS api_keys (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
key_prefix TEXT NOT NULL,
key_hash TEXT NOT NULL UNIQUE,
name TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
last_used_at TEXT,
expires_at TEXT,
is_active INTEGER NOT NULL DEFAULT 1
);
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS app_secrets (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)
columns = {row["name"] for row in conn.execute("PRAGMA table_info(auth_user)")}
if "must_change_password" not in columns:
conn.execute(
@ -112,6 +150,89 @@ def get_connection() -> sqlite3.Connection:
return conn
# ── API-key PBKDF2 salt ────────────────────────────────────────────────
#
# Module-level cache for the persistent API-key PBKDF2 salt. Populated
# lazily on first use via ``_get_or_create_api_key_pbkdf2_salt``. Not
# protected by a lock because (a) the ``INSERT OR IGNORE`` provides
# atomicity at the SQLite layer and (b) concurrent populations converge
# on the same value, so the worst case is a harmless duplicate read on
# startup.
_api_key_pbkdf2_salt_cache: Optional[bytes] = None
def _get_or_create_api_key_pbkdf2_salt() -> bytes:
"""Return the persistent API-key PBKDF2 salt, generating it once if missing.
Stored as a hex-encoded 32-byte random value in the ``app_secrets``
table under key ``"api_key_pbkdf2_salt"``. Regenerated only if the row
is missing (i.e. fresh install, or operator manually deleted the row
and accepts invalidating existing API keys).
"""
global _api_key_pbkdf2_salt_cache
if _api_key_pbkdf2_salt_cache is not None:
return _api_key_pbkdf2_salt_cache
conn = get_connection()
try:
cur = conn.execute(
"SELECT value FROM app_secrets WHERE key = ?",
("api_key_pbkdf2_salt",),
)
row = cur.fetchone()
if row is None:
new_value = secrets.token_hex(32) # 32 bytes -> 64 hex chars
conn.execute(
"INSERT OR IGNORE INTO app_secrets (key, value) VALUES (?, ?)",
("api_key_pbkdf2_salt", new_value),
)
conn.commit()
cur = conn.execute(
"SELECT value FROM app_secrets WHERE key = ?",
("api_key_pbkdf2_salt",),
)
row = cur.fetchone()
salt = bytes.fromhex(row["value"])
finally:
conn.close()
_api_key_pbkdf2_salt_cache = salt
return salt
_API_KEY_PBKDF2_ITERATIONS = 100_000
def _pbkdf2_api_key(raw_key: str) -> str:
"""PBKDF2-HMAC-SHA256 an API key with a persistent server-side salt.
Used for API-key storage ONLY, not refresh tokens. Matches the
PBKDF2 algorithm + iteration count used by the password hasher in
``auth/hashing.py`` so the codebase is consistent on which KDF it
uses for credential storage.
Notes on why a slow KDF here is *only* a CodeQL appeasement and
*not* a cryptographic requirement: API keys are cryptographically
random 128-bit tokens (via ``secrets.token_hex``), so brute force
against 2^128 is infeasible regardless of hash speed. CodeQL's
``py/weak-sensitive-data-hashing`` query mislabels these tokens as
"password" sensitive data and then demands a KDF from its
allowlist (Argon2 / scrypt / bcrypt / PBKDF2). Per the query's
own recommendation page we use PBKDF2. The persistent salt is
still loaded from ``app_secrets`` so an attacker dumping the
``api_keys`` table alone cannot derive hashes for candidate
tokens without also obtaining the salt row.
"""
salt = _get_or_create_api_key_pbkdf2_salt()
dk = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
salt,
_API_KEY_PBKDF2_ITERATIONS,
)
return dk.hex()
def is_initialized() -> bool:
"""Check if auth is ready for login (at least one user exists in DB)."""
conn = get_connection()
@ -357,3 +478,105 @@ def revoke_user_refresh_tokens(username: str) -> None:
conn.commit()
finally:
conn.close()
# ---------------------------------------------------------------------------
# API key management
# ---------------------------------------------------------------------------
API_KEY_PREFIX = "sk-unsloth-"
def create_api_key(
username: str,
name: str,
expires_at: Optional[str] = None,
) -> Tuple[str, dict]:
"""Create a new API key for *username*.
Returns ``(raw_key, row_dict)`` where *raw_key* is shown to the user
exactly once. The database only stores the SHA-256 hash.
"""
raw_key = API_KEY_PREFIX + secrets.token_hex(16)
key_hash = _pbkdf2_api_key(raw_key)
key_prefix = raw_key[len(API_KEY_PREFIX) : len(API_KEY_PREFIX) + 8]
now = datetime.now(timezone.utc).isoformat()
conn = get_connection()
try:
conn.execute(
"""
INSERT INTO api_keys (username, key_prefix, key_hash, name, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(username, key_prefix, key_hash, name, now, expires_at),
)
conn.commit()
cur = conn.execute("SELECT * FROM api_keys WHERE key_hash = ?", (key_hash,))
row = cur.fetchone()
return raw_key, dict(row)
finally:
conn.close()
def list_api_keys(username: str) -> list:
"""Return all API keys for *username* (never exposes ``key_hash``)."""
conn = get_connection()
try:
cur = conn.execute(
"""
SELECT id, username, key_prefix, name, created_at, last_used_at, expires_at, is_active
FROM api_keys
WHERE username = ?
ORDER BY created_at DESC
""",
(username,),
)
return [dict(row) for row in cur.fetchall()]
finally:
conn.close()
def revoke_api_key(username: str, key_id: int) -> bool:
"""Soft-delete an API key. Returns True if a matching row was found."""
conn = get_connection()
try:
cursor = conn.execute(
"UPDATE api_keys SET is_active = 0 WHERE id = ? AND username = ?",
(key_id, username),
)
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
def validate_api_key(raw_key: str) -> Optional[str]:
"""Validate *raw_key* and return the owning username, or ``None``.
Also updates ``last_used_at`` on success.
"""
key_hash = _pbkdf2_api_key(raw_key)
conn = get_connection()
try:
cur = conn.execute(
"SELECT id, username, is_active, expires_at FROM api_keys WHERE key_hash = ?",
(key_hash,),
)
row = cur.fetchone()
if row is None:
return None
if not row["is_active"]:
return None
if row["expires_at"] is not None:
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
return None
conn.execute(
"UPDATE api_keys SET last_used_at = ? WHERE id = ?",
(datetime.now(timezone.utc).isoformat(), row["id"]),
)
conn.commit()
return row["username"]
finally:
conn.close()

View file

@ -18,31 +18,6 @@ if _backend_dir not in sys.path:
import _platform_compat # noqa: F401
def _bootstrap_studio_venv() -> None:
"""Expose the Studio venv's site-packages to the current interpreter.
On Colab, notebook cells run outside the venv subshell. Instead of
installing the full stack into system Python, we prepend the venv's
site-packages so that packages like structlog, fastapi, etc. are
importable from notebook cells and take priority over system copies.
"""
venv_lib = Path.home() / ".unsloth" / "studio" / ".venv" / "lib"
if not venv_lib.exists():
import warnings
warnings.warn(
f"Studio venv not found at {venv_lib.parent} -- run 'unsloth studio setup' first",
stacklevel = 2,
)
return
for sp in venv_lib.glob("python*/site-packages"):
sp_str = str(sp)
if sp_str not in sys.path:
sys.path.insert(0, sp_str)
_bootstrap_studio_venv()
from loggers import get_logger
logger = get_logger(__name__)
@ -91,7 +66,10 @@ def show_link(port: int = 8888):
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="white"><polygon points="5,3 19,12 5,21"/></svg>
Open Unsloth Studio
</a>
<p style="color: #333333; margin: 16px 0 0 0; font-size: 13px; font-family: monospace;">
<p style="color: #333333; margin: 12px 0 0 0; font-size: 14px; font-weight: bold;">
If the link doesn't work, you can scroll down to view the UI generated directly in Colab.
</p>
<p style="color: #333333; margin: 16px 0 0 0; font-size: 13px; font-family: monospace; font-weight: bold;">
{short_url}
</p>
</div>

View file

@ -31,6 +31,7 @@ __all__ = [
# Config
"ModelConfig",
"is_vision_model",
"scan_trained_models",
"scan_trained_loras",
"load_model_defaults",
"get_base_model_from_lora",
@ -72,6 +73,7 @@ def __getattr__(name):
if name in (
"is_vision_model",
"ModelConfig",
"scan_trained_models",
"scan_trained_loras",
"load_model_defaults",
"get_base_model_from_lora",
@ -79,14 +81,15 @@ def __getattr__(name):
from utils.models import (
is_vision_model,
ModelConfig,
scan_trained_loras,
scan_trained_models,
load_model_defaults,
get_base_model_from_lora,
)
globals()["is_vision_model"] = is_vision_model
globals()["ModelConfig"] = ModelConfig
globals()["scan_trained_loras"] = scan_trained_loras
globals()["scan_trained_models"] = scan_trained_models
globals()["scan_trained_loras"] = scan_trained_models
globals()["load_model_defaults"] = load_model_defaults
globals()["get_base_model_from_lora"] = get_base_model_from_lora
return globals()[name]

View file

@ -4,7 +4,7 @@
"version": "0.0.1",
"type": "module",
"dependencies": {
"oxc-parser": "^0.116.0",
"oxc-parser": "^0.123.0",
"oxlint": "^1.51.0"
}
}

View file

@ -167,12 +167,7 @@ def _validate_recipe_runtime_support(
recipe: dict[str, Any],
model_providers: list[Any],
) -> None:
if not _recipe_has_llm_columns(recipe):
raise ValueError(
"Recipe Studio currently requires at least one AI generation step."
)
if not model_providers:
if _recipe_has_llm_columns(recipe) and not model_providers:
raise ValueError("Add a Provider connection block before running this recipe.")
@ -266,6 +261,21 @@ def create_data_designer(
model_providers = build_model_providers(recipe)
_validate_recipe_runtime_support(recipe, model_providers)
# DataDesigner requires at least one model provider in its registry even
# when the pipeline contains no LLM columns. Supply a lightweight stub
# so sampler/expression-only recipes can run without a real provider.
if not model_providers:
from data_designer.config.models import ModelProvider
model_providers = [
ModelProvider(
name = "_unused",
endpoint = "http://localhost",
provider_type = "openai",
api_key = None,
)
]
return DataDesigner(
artifact_path = artifact_path,
model_providers = model_providers,

View file

@ -310,7 +310,7 @@ class ExportBackend:
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
private: bool = False,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""
Export merged model (for PEFT models).
@ -323,14 +323,21 @@ class ExportBackend:
private: Whether to make the repo private
Returns:
Tuple of (success: bool, message: str)
Tuple of (success, message, output_path). output_path is the
resolved absolute on-disk directory of the saved model when
``save_directory`` was set, else None.
"""
if not self.current_model or not self.current_tokenizer:
return False, "No model loaded. Please select a checkpoint first."
return False, "No model loaded. Please select a checkpoint first.", None
if not self.is_peft:
return False, "This is not a PEFT model. Use 'Export Base Model' instead."
return (
False,
"This is not a PEFT model. Use 'Export Base Model' instead.",
None,
)
output_path: Optional[str] = None
try:
# Determine save method
if format_type == "4-bit (FP4)":
@ -354,6 +361,7 @@ class ExportBackend:
# Write export metadata so the Chat page can identify the base model
self._write_export_metadata(save_directory)
logger.info(f"Model saved successfully to {save_directory}")
output_path = str(Path(save_directory).resolve())
# Push to hub if requested
if push_to_hub:
@ -361,6 +369,7 @@ class ExportBackend:
return (
False,
"Repository ID and Hugging Face token required for Hub upload",
None,
)
logger.info(f"Pushing merged model to Hub: {repo_id}")
@ -378,14 +387,14 @@ class ExportBackend:
)
logger.info(f"Model pushed successfully to {repo_id}")
return True, "Model exported successfully"
return True, "Model exported successfully", output_path
except Exception as e:
logger.error(f"Error exporting merged model: {e}")
import traceback
logger.error(traceback.format_exc())
return False, f"Export failed: {str(e)}"
return False, f"Export failed: {str(e)}", None
def export_base_model(
self,
@ -395,22 +404,26 @@ class ExportBackend:
hf_token: Optional[str] = None,
private: bool = False,
base_model_id: Optional[str] = None,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""
Export base model (for non-PEFT models).
Returns:
Tuple of (success: bool, message: str)
Tuple of (success, message, output_path). output_path is the
resolved absolute on-disk directory of the saved model when
``save_directory`` was set, else None.
"""
if not self.current_model or not self.current_tokenizer:
return False, "No model loaded. Please select a checkpoint first."
return False, "No model loaded. Please select a checkpoint first.", None
if self.is_peft:
return (
False,
"This is a PEFT model. Use 'Merged Model' export type instead.",
None,
)
output_path: Optional[str] = None
try:
# Save locally if requested
if save_directory:
@ -424,6 +437,7 @@ class ExportBackend:
# Write export metadata so the Chat page can identify the base model
self._write_export_metadata(save_directory)
logger.info(f"Model saved successfully to {save_directory}")
output_path = str(Path(save_directory).resolve())
# Push to hub if requested
if push_to_hub:
@ -431,6 +445,7 @@ class ExportBackend:
return (
False,
"Repository ID and Hugging Face token required for Hub upload",
None,
)
logger.info(f"Pushing base model to Hub: {repo_id}")
@ -472,16 +487,16 @@ class ExportBackend:
)
logger.info(f"Model pushed successfully to {repo_id}")
else:
return False, "Local save directory required for Hub upload"
return False, "Local save directory required for Hub upload", None
return True, "Model exported successfully"
return True, "Model exported successfully", output_path
except Exception as e:
logger.error(f"Error exporting base model: {e}")
import traceback
logger.error(traceback.format_exc())
return False, f"Export failed: {str(e)}"
return False, f"Export failed: {str(e)}", None
def export_gguf(
self,
@ -490,7 +505,7 @@ class ExportBackend:
push_to_hub: bool = False,
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""
Export model in GGUF format.
@ -502,11 +517,14 @@ class ExportBackend:
hf_token: Hugging Face token
Returns:
Tuple of (success: bool, message: str)
Tuple of (success, message, output_path). output_path is the
resolved absolute on-disk directory containing the .gguf
files when ``save_directory`` was set, else None.
"""
if not self.current_model or not self.current_tokenizer:
return False, "No model loaded. Please select a checkpoint first."
return False, "No model loaded. Please select a checkpoint first.", None
output_path: Optional[str] = None
try:
# Convert quantization method to lowercase for unsloth
quant_method = quantization_method.lower()
@ -601,6 +619,7 @@ class ExportBackend:
abs_save_dir,
"\n ".join(os.path.basename(f) for f in final_ggufs) or "(none)",
)
output_path = str(Path(abs_save_dir).resolve())
# Push to hub if requested
if push_to_hub:
@ -608,6 +627,7 @@ class ExportBackend:
return (
False,
"Repository ID and Hugging Face token required for Hub upload",
None,
)
logger.info(f"Pushing GGUF model to Hub: {repo_id}")
@ -620,14 +640,18 @@ class ExportBackend:
)
logger.info(f"GGUF model pushed successfully to {repo_id}")
return True, f"GGUF model exported successfully ({quantization_method})"
return (
True,
f"GGUF model exported successfully ({quantization_method})",
output_path,
)
except Exception as e:
logger.error(f"Error exporting GGUF model: {e}")
import traceback
logger.error(traceback.format_exc())
return False, f"GGUF export failed: {str(e)}"
return False, f"GGUF export failed: {str(e)}", None
def export_lora_adapter(
self,
@ -636,19 +660,22 @@ class ExportBackend:
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
private: bool = False,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""
Export LoRA adapter only (not merged).
Returns:
Tuple of (success: bool, message: str)
Tuple of (success, message, output_path). output_path is the
resolved absolute on-disk directory of the saved adapter
when ``save_directory`` was set, else None.
"""
if not self.current_model or not self.current_tokenizer:
return False, "No model loaded. Please select a checkpoint first."
return False, "No model loaded. Please select a checkpoint first.", None
if not self.is_peft:
return False, "This is not a PEFT model. No adapter to export."
return False, "This is not a PEFT model. No adapter to export.", None
output_path: Optional[str] = None
try:
# Save locally if requested
if save_directory:
@ -659,6 +686,7 @@ class ExportBackend:
self.current_model.save_pretrained(save_directory)
self.current_tokenizer.save_pretrained(save_directory)
logger.info(f"Adapter saved successfully to {save_directory}")
output_path = str(Path(save_directory).resolve())
# Push to hub if requested
if push_to_hub:
@ -666,6 +694,7 @@ class ExportBackend:
return (
False,
"Repository ID and Hugging Face token required for Hub upload",
None,
)
logger.info(f"Pushing LoRA adapter to Hub: {repo_id}")
@ -676,14 +705,14 @@ class ExportBackend:
)
logger.info(f"Adapter pushed successfully to {repo_id}")
return True, "LoRA adapter exported successfully"
return True, "LoRA adapter exported successfully", output_path
except Exception as e:
logger.error(f"Error exporting LoRA adapter: {e}")
import traceback
logger.error(traceback.format_exc())
return False, f"Adapter export failed: {str(e)}"
return False, f"Adapter export failed: {str(e)}", None
# Global export backend instance

View file

@ -16,19 +16,25 @@ Pattern follows core/inference/orchestrator.py.
import atexit
import structlog
from collections import deque
from loggers import get_logger
import multiprocessing as mp
import queue
import threading
import time
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any, Deque, Dict, List, Optional, Tuple
from utils.paths import outputs_root
logger = get_logger(__name__)
_CTX = mp.get_context("spawn")
# Maximum number of captured log lines kept in memory per export
# orchestrator. Acts as scrollback for the live export log panel in the
# UI. 4000 lines is ~1 MB worst-case at 256 chars/line.
_LOG_BUFFER_MAXLEN = 4000
class ExportOrchestrator:
"""
@ -44,6 +50,9 @@ class ExportOrchestrator:
self._proc: Optional[mp.Process] = None
self._cmd_queue: Any = None
self._resp_queue: Any = None
# Serializes export operations (load_checkpoint, export_*,
# cleanup) so concurrent HTTP requests can never interleave
# commands on the subprocess queue. Previously unused.
self._lock = threading.Lock()
# Local state mirrors (updated from subprocess responses)
@ -51,9 +60,103 @@ class ExportOrchestrator:
self.is_vision: bool = False
self.is_peft: bool = False
# ── Live log capture ─────────────────────────────────────
# Thread-safe ring buffer of log lines forwarded from the
# worker subprocess. Powers the GET /api/export/logs/stream
# SSE endpoint that the export dialog consumes.
self._log_buffer: Deque[Dict[str, Any]] = deque(maxlen = _LOG_BUFFER_MAXLEN)
self._log_lock = threading.Lock()
# Monotonically increasing sequence number. Never reset across
# operations, so SSE clients can use it as a stable cursor even
# if clear_logs() is called mid-session.
self._log_seq: int = 0
# Snapshot of _log_seq captured at the start of the current run
# (updated by clear_logs()). The SSE endpoint defaults its
# cursor to this value so a client that connects AFTER the
# worker has already emitted its first lines still sees the
# full run. Every line appended during the current run has seq
# strictly greater than _run_start_seq, and every line from
# prior runs has seq less than or equal to it.
self._run_start_seq: int = 0
# True while an export operation (load/export/cleanup) is
# running. The SSE endpoint ends the stream 1 second after
# this flips back to False to drain any trailing log lines.
self._export_active: bool = False
atexit.register(self._cleanup)
logger.info("ExportOrchestrator initialized (subprocess mode)")
# ------------------------------------------------------------------
# Live log capture helpers
# ------------------------------------------------------------------
def _append_log(self, entry: Dict[str, Any]) -> None:
"""Append a log line from the worker subprocess to the buffer.
Entries look like {"type": "log", "stream": "stdout"|"stderr",
"line": "...", "ts": ...}. Each is stamped with a monotonic
seq number before it lands in the buffer so SSE clients can
cursor through new lines.
"""
line = entry.get("line")
if not line:
return
with self._log_lock:
self._log_seq += 1
self._log_buffer.append(
{
"seq": self._log_seq,
"stream": entry.get("stream", "stdout"),
"line": line,
"ts": entry.get("ts", time.time()),
}
)
def clear_logs(self) -> None:
"""Drop any buffered log lines from a previous operation.
Called at the start of each export op so the UI shows only the
output of the current run. The seq counter is NOT reset, so an
SSE client that captured the cursor before clear_logs() will
still see new lines (with strictly greater seq numbers).
Also snapshots the current seq into ``_run_start_seq`` so the
SSE endpoint can anchor its default cursor at the start of
this run. Anything appended after this call has seq strictly
greater than the snapshot and is reachable via
``get_logs_since(get_run_start_seq())``.
"""
with self._log_lock:
self._log_buffer.clear()
self._run_start_seq = self._log_seq
def get_logs_since(self, cursor: int) -> Tuple[List[Dict[str, Any]], int]:
"""Return log entries with seq > cursor, plus the new cursor."""
with self._log_lock:
new_entries = [entry for entry in self._log_buffer if entry["seq"] > cursor]
if new_entries:
return new_entries, new_entries[-1]["seq"]
return [], cursor
def get_current_log_seq(self) -> int:
"""Return the current seq counter without reading any entries."""
with self._log_lock:
return self._log_seq
def get_run_start_seq(self) -> int:
"""Return the seq value captured at the start of the current run.
The SSE endpoint uses this as the default cursor so a client
that connects AFTER the worker has already started emitting
output still sees every line from the current run.
"""
with self._log_lock:
return self._run_start_seq
def is_export_active(self) -> bool:
"""True while an export / load / cleanup command is running."""
return self._export_active
# ------------------------------------------------------------------
# Subprocess lifecycle
# ------------------------------------------------------------------
@ -179,8 +282,26 @@ class ExportOrchestrator:
error_msg = resp.get("error", "Unknown error")
raise RuntimeError(f"Subprocess error: {error_msg}")
if rtype == "log":
# Forwarded stdout/stderr line from the worker process.
self._append_log(resp)
continue
if rtype == "status":
logger.info("Export subprocess status: %s", resp.get("message", ""))
message = resp.get("message", "")
logger.info("Export subprocess status: %s", message)
# Surface status messages in the live log panel too so
# users see high level progress (e.g. "Importing
# Unsloth...", "Loading checkpoint: ...") alongside
# subprocess output.
if message:
self._append_log(
{
"stream": "status",
"line": message,
"ts": resp.get("ts", time.time()),
}
)
continue
# Other response types during wait — skip
@ -217,6 +338,7 @@ class ExportOrchestrator:
max_seq_length: int = 2048,
load_in_4bit: bool = True,
trust_remote_code: bool = False,
hf_token: Optional[str] = None,
) -> Tuple[bool, str]:
"""Load a checkpoint for export.
@ -227,39 +349,50 @@ class ExportOrchestrator:
"max_seq_length": max_seq_length,
"load_in_4bit": load_in_4bit,
"trust_remote_code": trust_remote_code,
"hf_token": hf_token,
}
# Always kill existing subprocess and spawn fresh.
if self._ensure_subprocess_alive():
self._shutdown_subprocess()
elif self._proc is not None:
self._shutdown_subprocess(timeout = 2)
with self._lock:
# Start a fresh log buffer for this operation so the UI
# sees only the current run's output.
self.clear_logs()
self._export_active = True
try:
# Always kill existing subprocess and spawn fresh.
if self._ensure_subprocess_alive():
self._shutdown_subprocess()
elif self._proc is not None:
self._shutdown_subprocess(timeout = 2)
logger.info("Spawning fresh export subprocess for '%s'", checkpoint_path)
self._spawn_subprocess(sub_config)
logger.info(
"Spawning fresh export subprocess for '%s'", checkpoint_path
)
self._spawn_subprocess(sub_config)
try:
resp = self._wait_response("loaded", timeout = 300)
except RuntimeError as exc:
self._shutdown_subprocess(timeout = 5)
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return False, str(exc)
try:
resp = self._wait_response("loaded")
except RuntimeError as exc:
self._shutdown_subprocess(timeout = 5)
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return False, str(exc)
if resp.get("success"):
self.current_checkpoint = resp.get("checkpoint")
self.is_vision = resp.get("is_vision", False)
self.is_peft = resp.get("is_peft", False)
logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path)
return True, resp.get("message", "Loaded successfully")
else:
error = resp.get("message", "Failed to load checkpoint")
logger.error("Failed to load checkpoint: %s", error)
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return False, error
if resp.get("success"):
self.current_checkpoint = resp.get("checkpoint")
self.is_vision = resp.get("is_vision", False)
self.is_peft = resp.get("is_peft", False)
logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path)
return True, resp.get("message", "Loaded successfully")
else:
error = resp.get("message", "Failed to load checkpoint")
logger.error("Failed to load checkpoint: %s", error)
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return False, error
finally:
self._export_active = False
def export_merged_model(
self,
@ -269,7 +402,7 @@ class ExportOrchestrator:
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
private: bool = False,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""Export merged PEFT model."""
return self._run_export(
"merged",
@ -291,7 +424,7 @@ class ExportOrchestrator:
hf_token: Optional[str] = None,
private: bool = False,
base_model_id: Optional[str] = None,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""Export base model (non-PEFT)."""
return self._run_export(
"base",
@ -312,7 +445,7 @@ class ExportOrchestrator:
push_to_hub: bool = False,
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""Export model in GGUF format."""
return self._run_export(
"gguf",
@ -332,7 +465,7 @@ class ExportOrchestrator:
repo_id: Optional[str] = None,
hf_token: Optional[str] = None,
private: bool = False,
) -> Tuple[bool, str]:
) -> Tuple[bool, str, Optional[str]]:
"""Export LoRA adapter only."""
return self._run_export(
"lora",
@ -345,46 +478,74 @@ class ExportOrchestrator:
},
)
def _run_export(self, export_type: str, params: dict) -> Tuple[bool, str]:
"""Send an export command to the subprocess and wait for result."""
if not self._ensure_subprocess_alive():
return False, "No export subprocess running. Load a checkpoint first."
def _run_export(
self, export_type: str, params: dict
) -> Tuple[bool, str, Optional[str]]:
"""Send an export command to the subprocess and wait for result.
cmd = {"type": "export", "export_type": export_type, **params}
Returns ``(success, message, output_path)``. ``output_path`` is the
resolved on-disk directory the worker actually wrote to (None when
the export only pushed to Hub or failed before any file was
written). Surfaced via the export route's ``details.output_path``
so the dialog's success screen can show the user where the model
landed.
"""
with self._lock:
if not self._ensure_subprocess_alive():
return (
False,
"No export subprocess running. Load a checkpoint first.",
None,
)
try:
self._send_cmd(cmd)
resp = self._wait_response(
f"export_{export_type}_done",
timeout = 3600, # GGUF for 30B+ models can take 30+ min
)
return resp.get("success", False), resp.get("message", "")
except RuntimeError as exc:
return False, str(exc)
self.clear_logs()
self._export_active = True
try:
cmd = {"type": "export", "export_type": export_type, **params}
try:
self._send_cmd(cmd)
resp = self._wait_response(
f"export_{export_type}_done",
timeout = 3600, # GGUF for 30B+ models can take 30+ min
)
return (
resp.get("success", False),
resp.get("message", ""),
resp.get("output_path"),
)
except RuntimeError as exc:
return False, str(exc), None
finally:
self._export_active = False
def cleanup_memory(self) -> bool:
"""Cleanup export-related models from memory."""
if not self._ensure_subprocess_alive():
# No subprocess — just clear local state
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return True
with self._lock:
if not self._ensure_subprocess_alive():
# No subprocess — just clear local state
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return True
try:
self._send_cmd({"type": "cleanup"})
resp = self._wait_response("cleanup_done", timeout = 30)
success = resp.get("success", False)
except RuntimeError:
success = False
self._export_active = True
try:
try:
self._send_cmd({"type": "cleanup"})
resp = self._wait_response("cleanup_done", timeout = 30)
success = resp.get("success", False)
except RuntimeError:
success = False
# Shut down subprocess after cleanup — no model loaded
self._shutdown_subprocess()
# Shut down subprocess after cleanup — no model loaded
self._shutdown_subprocess()
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return success
self.current_checkpoint = None
self.is_vision = False
self.is_peft = False
return success
finally:
self._export_active = False
def scan_checkpoints(
self, outputs_dir: str = str(outputs_root())

View file

@ -17,10 +17,12 @@ Pattern follows core/inference/worker.py and core/training/worker.py.
from __future__ import annotations
import errno
import structlog
from loggers import get_logger
import os
import sys
import threading
import time
import traceback
from pathlib import Path
@ -29,38 +31,164 @@ from typing import Any
logger = get_logger(__name__)
def _activate_transformers_version(model_name: str) -> None:
"""Activate the correct transformers version BEFORE any ML imports.
# Gate that controls whether captured stdout/stderr lines are forwarded
# to the parent's resp_queue (and from there to the export-dialog SSE
# stream). Closed by default so the noisy bootstrap phase -- transformers
# venv activation, Unsloth/torch imports, base-model resolution, "Top
# GGUF/hub models" lists, vision detection, weight loading bars -- is
# suppressed in the UI. _handle_export() opens the gate at the start of
# the actual export work and leaves it open; the orchestrator always
# spawns a fresh subprocess for the next checkpoint load (see
# orchestrator._spawn_subprocess) which resets this state.
#
# Lines dropped while the gate is closed are still echoed to the saved
# original stdout/stderr fds so the server console / log file keeps the
# full output for debugging.
_log_forward_gate = threading.Event()
If the model needs transformers 5.x, prepend the pre-installed .venv_t5/
directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).
def _setup_log_capture(resp_queue: Any) -> None:
"""Redirect fds 1 and 2 through pipes so every line printed by this
worker process and any child process it spawns is forwarded to the
parent process via resp_queue as {"type": "log", ...} messages.
Must be called BEFORE LogConfig.setup_logging and BEFORE any ML
imports, otherwise library handlers may capture the original stderr
reference and bypass the pipe.
Lines are also echoed back to the original stdout/stderr so the
server console keeps receiving the full subprocess output, even
while ``_log_forward_gate`` is closed.
"""
try:
saved_out_fd = os.dup(1)
saved_err_fd = os.dup(2)
except OSError:
# dup failed (exotic platforms) - give up quietly, export still
# works, just no live log streaming.
return
try:
r_out, w_out = os.pipe()
r_err, w_err = os.pipe()
except OSError:
os.close(saved_out_fd)
os.close(saved_err_fd)
return
try:
os.dup2(w_out, 1)
os.dup2(w_err, 2)
except OSError:
for fd in (saved_out_fd, saved_err_fd, r_out, w_out, r_err, w_err):
try:
os.close(fd)
except OSError:
pass
return
# Close the write ends we just dup2'd (fds 1 and 2 are the real
# write ends now).
os.close(w_out)
os.close(w_err)
# Replace Python's sys.stdout/sys.stderr with line-buffered writers
# bound to the (now-redirected) fds 1 and 2.
try:
sys.stdout = os.fdopen(1, "w", buffering = 1, encoding = "utf-8", errors = "replace")
sys.stderr = os.fdopen(2, "w", buffering = 1, encoding = "utf-8", errors = "replace")
except Exception:
pass
def _reader(read_fd: int, stream_name: str, echo_fd: int) -> None:
buf = bytearray()
while True:
try:
chunk = os.read(read_fd, 4096)
except OSError as exc:
if exc.errno == errno.EBADF:
break
continue
if not chunk:
break
# Echo to the original fd so the server console still sees
# the full output.
try:
os.write(echo_fd, chunk)
except OSError:
pass
buf.extend(chunk)
# Split on \n OR \r so tqdm-style progress bars update.
while True:
nl = -1
for i, b in enumerate(buf):
if b == 0x0A or b == 0x0D:
nl = i
break
if nl < 0:
break
line = bytes(buf[:nl]).decode("utf-8", errors = "replace")
del buf[: nl + 1]
if not line:
continue
if not _log_forward_gate.is_set():
# Gate closed (bootstrap phase) -- already echoed to
# the saved console fd above; drop the line so the
# export dialog doesn't see import / vendoring noise.
continue
try:
resp_queue.put_nowait(
{
"type": "log",
"stream": stream_name,
"line": line,
"ts": time.time(),
}
)
except Exception:
# Queue put failed (full, closed, etc.) - drop the
# line rather than crash the reader thread.
pass
if buf and _log_forward_gate.is_set():
try:
resp_queue.put_nowait(
{
"type": "log",
"stream": stream_name,
"line": bytes(buf).decode("utf-8", errors = "replace"),
"ts": time.time(),
}
)
except Exception:
pass
t_out = threading.Thread(
target = _reader,
args = (r_out, "stdout", saved_out_fd),
daemon = True,
name = "export-log-stdout",
)
t_err = threading.Thread(
target = _reader,
args = (r_err, "stderr", saved_err_fd),
daemon = True,
name = "export-log-stderr",
)
t_out.start()
t_err.start()
def _activate_transformers_version(model_name: str) -> None:
"""Activate the correct transformers version BEFORE any ML imports."""
# Ensure backend is on path for utils imports
backend_path = str(Path(__file__).resolve().parent.parent.parent)
if backend_path not in sys.path:
sys.path.insert(0, backend_path)
from utils.transformers_version import (
needs_transformers_5,
_resolve_base_model,
_ensure_venv_t5_exists,
_VENV_T5_DIR,
)
from utils.transformers_version import activate_transformers_for_subprocess
resolved = _resolve_base_model(model_name)
if needs_transformers_5(resolved):
if not _ensure_venv_t5_exists():
raise RuntimeError(
f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}"
)
if _VENV_T5_DIR not in sys.path:
sys.path.insert(0, _VENV_T5_DIR)
logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR)
# Propagate to child subprocesses (e.g. GGUF converter)
_pp = os.environ.get("PYTHONPATH", "")
os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "")
else:
logger.info("Using default transformers (4.57.x) for %s", model_name)
activate_transformers_for_subprocess(model_name)
def _send_response(resp_queue: Any, response: dict) -> None:
@ -78,6 +206,19 @@ def _handle_load(backend, cmd: dict, resp_queue: Any) -> None:
load_in_4bit = cmd.get("load_in_4bit", True)
trust_remote_code = cmd.get("trust_remote_code", False)
# Auto-enable trust_remote_code for NemotronH/Nano models.
if not trust_remote_code:
_NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
_cp_lower = checkpoint_path.lower()
if any(sub in _cp_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and (
_cp_lower.startswith("unsloth/") or _cp_lower.startswith("nvidia/")
):
trust_remote_code = True
logger.info(
"Auto-enabled trust_remote_code for Nemotron model: %s",
checkpoint_path,
)
try:
_send_response(
resp_queue,
@ -126,9 +267,17 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
export_type = cmd["export_type"] # "merged", "base", "gguf", "lora"
response_type = f"export_{export_type}_done"
# Open the log forwarding gate so the user sees the actual export
# progress (Unsloth merge bars, file copies, GGUF conversion, etc.)
# in the live log panel. The gate stays open for the rest of this
# subprocess's life; the orchestrator spawns a fresh subprocess for
# the next checkpoint load, which resets the gate to closed.
_log_forward_gate.set()
output_path: Any = None
try:
if export_type == "merged":
success, message = backend.export_merged_model(
success, message, output_path = backend.export_merged_model(
save_directory = cmd.get("save_directory", ""),
format_type = cmd.get("format_type", "16-bit (FP16)"),
push_to_hub = cmd.get("push_to_hub", False),
@ -137,7 +286,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
private = cmd.get("private", False),
)
elif export_type == "base":
success, message = backend.export_base_model(
success, message, output_path = backend.export_base_model(
save_directory = cmd.get("save_directory", ""),
push_to_hub = cmd.get("push_to_hub", False),
repo_id = cmd.get("repo_id"),
@ -146,7 +295,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
base_model_id = cmd.get("base_model_id"),
)
elif export_type == "gguf":
success, message = backend.export_gguf(
success, message, output_path = backend.export_gguf(
save_directory = cmd.get("save_directory", ""),
quantization_method = cmd.get("quantization_method", "Q4_K_M"),
push_to_hub = cmd.get("push_to_hub", False),
@ -154,7 +303,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
hf_token = cmd.get("hf_token"),
)
elif export_type == "lora":
success, message = backend.export_lora_adapter(
success, message, output_path = backend.export_lora_adapter(
save_directory = cmd.get("save_directory", ""),
push_to_hub = cmd.get("push_to_hub", False),
repo_id = cmd.get("repo_id"),
@ -170,6 +319,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
"type": response_type,
"success": success,
"message": message,
"output_path": output_path,
"ts": time.time(),
},
)
@ -181,6 +331,7 @@ def _handle_export(backend, cmd: dict, resp_queue: Any) -> None:
"type": response_type,
"success": False,
"message": str(exc),
"output_path": None,
"stack": traceback.format_exc(limit = 20),
"ts": time.time(),
},
@ -226,10 +377,26 @@ def run_export_process(
"""
import queue as _queue
# Install fd-level stdout/stderr capture FIRST so every subsequent
# print and every child process inherits the redirected fds. This
# is what powers the live export log stream in the UI.
_setup_log_capture(resp_queue)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTHONWARNINGS"] = (
"ignore" # Suppress warnings at C-level before imports
)
# Force unbuffered output from any child Python process (e.g. the
# GGUF converter) so their prints surface in the log stream as they
# happen rather than at the end.
os.environ["PYTHONUNBUFFERED"] = "1"
# tqdm defaults to a 10-second mininterval when stdout is not a tty
# (which it isn't here -- we redirected fd 1/2 to a pipe). That makes
# multi-step progress bars look frozen in the export log panel. Force
# frequent flushes so the user sees movement during merge / GGUF
# conversion. Has no effect on single-step bars (e.g. "Copying 1
# files") which only emit start/end events regardless.
os.environ.setdefault("TQDM_MININTERVAL", "0.5")
import warnings
from loggers.config import LogConfig

View file

@ -0,0 +1,447 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Minimal HTML-to-Markdown converter using only the standard library.
Replaces the external ``html2text`` (GPL-3.0) dependency with a ~250-line
``html.parser.HTMLParser`` subclass. Covers headings, links, bold/italic,
lists, tables, blockquotes, code blocks, and entity decoding.
"""
from __future__ import annotations
import html
import re
from html.parser import HTMLParser
__all__ = ["html_to_markdown"]
_SKIP_TAGS = frozenset(
{
"script",
"style",
"head",
"noscript",
"svg",
"math",
"nav",
"footer",
}
)
_BLOCK_TAGS = frozenset(
{
"p",
"div",
"section",
"article",
"main",
"aside",
"figure",
"figcaption",
"details",
"summary",
"dl",
"dt",
"dd",
}
)
_HEADING_TAGS = frozenset({"h1", "h2", "h3", "h4", "h5", "h6"})
_INLINE_EMPHASIS = {"strong": "**", "b": "**", "em": "*", "i": "*"}
class _MarkdownRenderer(HTMLParser):
"""HTMLParser subclass that emits Markdown tokens into a list."""
def __init__(self):
super().__init__(convert_charrefs = False)
self._out: list[str] = []
self._skip_depth: int = 0
# Link state
self._link_href: str | None = None
self._link_text_parts: list[str] = []
self._in_link: bool = False
# List state
self._list_stack: list[str] = [] # "ul" or "ol"
self._ol_counter: list[int] = []
# Table state
self._in_table: bool = False
self._current_row: list[str] = []
self._cell_parts: list[str] = []
self._in_cell: bool = False
self._header_row_done: bool = False
self._row_has_th: bool = False
self._is_first_row: bool = False
# Pre/code state
self._in_pre: bool = False
self._pre_parts: list[str] = []
self._in_inline_code: bool = False
# Blockquote state -- stack of output buffers so nested
# blockquotes each collect their own content and get prefixed
# with the correct number of ">" markers on close.
self._bq_stack: list[list[str]] = []
# ------------------------------------------------------------------
def _emit(self, text: str) -> None:
if self._in_link:
self._link_text_parts.append(text)
elif self._in_cell:
self._cell_parts.append(text)
elif self._in_pre:
self._pre_parts.append(text)
elif self._bq_stack:
self._bq_stack[-1].append(text)
else:
self._out.append(text)
# ------------------------------------------------------------------
def _prefix_blockquote(self, content: str) -> str:
"""Prefix every line of *content* with ``> ``."""
# Strip trailing whitespace first, then collapse blank lines
content = re.sub(r"[ \t]+$", "", content, flags = re.MULTILINE)
content = re.sub(r"\n{3,}", "\n\n", content).strip()
if not content:
return ""
lines = content.split("\n")
prefixed: list[str] = []
for line in lines:
if line.strip():
prefixed.append("> " + line)
else:
prefixed.append(">")
return "\n".join(prefixed)
# ------------------------------------------------------------------
# Table helpers -- flush open cells and rows so that HTML with
# omitted optional end tags (</td>, </tr>) does not lose data.
# ------------------------------------------------------------------
def _finish_cell(self) -> None:
if not self._in_cell:
return
self._in_cell = False
cell_text = "".join(self._cell_parts).strip().replace("\n", " ")
cell_text = cell_text.replace("|", "\\|")
self._current_row.append(cell_text)
self._cell_parts = []
def _finish_row(self) -> None:
if not self._current_row:
return
line = "| " + " | ".join(self._current_row) + " |"
self._emit(line + "\n")
if not self._header_row_done and (self._row_has_th or self._is_first_row):
sep = "| " + " | ".join("---" for _ in self._current_row) + " |"
self._emit(sep + "\n")
self._header_row_done = True
self._is_first_row = False
self._current_row = []
self._row_has_th = False
# ------------------------------------------------------------------
# Link text helper -- normalize whitespace so block-level content
# inside an <a> does not produce multiline Markdown link labels.
# ------------------------------------------------------------------
def _finish_link(self) -> None:
text = re.sub(r"\s+", " ", "".join(self._link_text_parts)).strip()
href = self._link_href or ""
self._in_link = False
if href and text:
self._emit(f"[{text}]({href})")
elif text:
self._emit(text)
# ------------------------------------------------------------------
# Tag handlers
# ------------------------------------------------------------------
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
tag = tag.lower()
if tag in _SKIP_TAGS:
self._skip_depth += 1
return
if self._skip_depth:
return
attr_dict = dict(attrs)
if tag in _HEADING_TAGS:
level = int(tag[1])
self._emit("\n\n" + "#" * level + " ")
elif tag == "a":
self._link_href = attr_dict.get("href")
self._link_text_parts = []
self._in_link = True
elif tag in _INLINE_EMPHASIS:
self._emit(_INLINE_EMPHASIS[tag])
elif tag == "br":
self._emit("\n")
elif tag in _BLOCK_TAGS:
self._emit("\n\n")
elif tag == "hr":
self._emit("\n\n---\n\n")
elif tag == "blockquote":
self._emit("\n\n")
self._bq_stack.append([])
elif tag == "ul":
self._list_stack.append("ul")
self._emit("\n")
elif tag == "ol":
self._list_stack.append("ol")
start_attr = attr_dict.get("start")
try:
start = int(start_attr) if start_attr is not None else 1
except (ValueError, TypeError):
start = 1
self._ol_counter.append(start - 1)
self._emit("\n")
elif tag == "li":
indent = " " * max(0, len(self._list_stack) - 1)
if self._list_stack and self._list_stack[-1] == "ol":
if self._ol_counter:
self._ol_counter[-1] += 1
self._emit(f"\n{indent}{self._ol_counter[-1]}. ")
else:
self._emit(f"\n{indent}1. ")
else:
self._emit(f"\n{indent}* ")
elif tag == "pre":
self._pre_parts = []
self._in_pre = True
elif tag == "code" and not self._in_pre:
self._in_inline_code = True
self._emit("`")
elif tag == "table":
self._in_table = True
self._header_row_done = False
self._is_first_row = True
self._emit("\n\n")
elif tag == "tr":
# Flush any open cell/row from a previous row that may
# have omitted its optional </td> or </tr> end tags.
self._finish_cell()
self._finish_row()
elif tag in ("th", "td"):
# Flush any open cell (handles omitted </td>/<th>)
self._finish_cell()
self._cell_parts = []
self._in_cell = True
if tag == "th":
self._row_has_th = True
elif tag == "img":
# Skip images -- keeps fetched page text focused on readable
# content and avoids data-URI amplification.
return
def handle_endtag(self, tag: str) -> None:
tag = tag.lower()
if tag in _SKIP_TAGS:
self._skip_depth = max(0, self._skip_depth - 1)
return
if self._skip_depth:
return
if tag in _HEADING_TAGS:
self._emit("\n\n")
elif tag == "a":
self._finish_link()
elif tag in _INLINE_EMPHASIS:
self._emit(_INLINE_EMPHASIS[tag])
elif tag in _BLOCK_TAGS:
self._emit("\n\n")
elif tag == "blockquote":
if self._bq_stack:
content = "".join(self._bq_stack.pop())
prefixed = self._prefix_blockquote(content)
if prefixed:
self._emit("\n\n" + prefixed + "\n\n")
elif tag == "ul":
if self._list_stack and self._list_stack[-1] == "ul":
self._list_stack.pop()
self._emit("\n")
elif tag == "ol":
if self._list_stack and self._list_stack[-1] == "ol":
self._list_stack.pop()
if self._ol_counter:
self._ol_counter.pop()
self._emit("\n")
elif tag == "pre":
raw = "".join(self._pre_parts)
self._in_pre = False
block = "```\n" + raw + "\n```"
self._emit("\n\n" + block + "\n\n")
elif tag == "code" and not self._in_pre:
self._in_inline_code = False
self._emit("`")
elif tag in ("th", "td"):
self._finish_cell()
elif tag == "tr":
self._finish_cell()
self._finish_row()
elif tag == "table":
# Flush any remaining row (handles omitted </tr>)
self._finish_cell()
self._finish_row()
self._in_table = False
self._emit("\n")
# ------------------------------------------------------------------
# Text / entity handlers
# ------------------------------------------------------------------
def handle_data(self, data: str) -> None:
if self._skip_depth:
return
if self._in_pre:
self._pre_parts.append(data)
return
# Preserve literal whitespace inside inline <code> spans
if self._in_inline_code:
self._emit(data)
return
# Collapse all whitespace (including newlines) per HTML rules
text = re.sub(r"\s+", " ", data)
# Suppress whitespace-only text nodes between table structural
# elements (indentation from source HTML) to prevent leading
# spaces from breaking Markdown table row alignment.
if self._in_table and not self._in_cell and not text.strip():
return
self._emit(text)
def handle_entityref(self, name: str) -> None:
if self._skip_depth:
return
self._emit(html.unescape(f"&{name};"))
def handle_charref(self, name: str) -> None:
if self._skip_depth:
return
self._emit(html.unescape(f"&#{name};"))
# ------------------------------------------------------------------
# Flush pending buffers (handles truncated HTML from capped fetches)
# ------------------------------------------------------------------
def flush_pending(self) -> None:
"""Flush any open side-buffers into ``_out``.
Called after ``close()`` to recover content from truncated HTML
where closing tags were never seen (common when ``_fetch_page_text``
caps the download by byte count).
"""
# Flush innermost buffers first so their content propagates outward.
if self._in_link:
self._finish_link()
if self._in_inline_code:
self._in_inline_code = False
self._emit("`")
self._finish_cell()
self._finish_row()
if self._in_pre:
raw = "".join(self._pre_parts)
self._in_pre = False
block = "```\n" + raw + "\n```"
self._emit("\n\n" + block + "\n\n")
# Flatten any open blockquote buffers (innermost first)
while self._bq_stack:
content = "".join(self._bq_stack.pop())
prefixed = self._prefix_blockquote(content)
if not prefixed:
continue
if self._bq_stack:
self._bq_stack[-1].append("\n\n" + prefixed + "\n\n")
else:
self._out.append("\n\n" + prefixed + "\n\n")
# ------------------------------------------------------------------
# Post-processing
# ------------------------------------------------------------------
def _cleanup(text: str) -> str:
"""Normalize whitespace and blank lines in the final output.
Preserves content inside fenced code blocks verbatim so that
intentional blank lines in ``<pre>`` content are not collapsed.
"""
lines = text.split("\n")
out: list[str] = []
in_fence = False
blank_run = 0
for line in lines:
stripped = line.rstrip(" \t")
if stripped.startswith("```"):
in_fence = not in_fence
blank_run = 0
out.append(stripped)
continue
if in_fence:
# Preserve code block content exactly as-is
out.append(line)
continue
if not stripped:
blank_run += 1
if blank_run <= 1:
out.append("")
continue
blank_run = 0
out.append(stripped)
return "\n".join(out).strip()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def html_to_markdown(source_html: str) -> str:
"""Convert an HTML string to Markdown.
Handles headings, links, bold/italic, lists (ordered and unordered),
tables, blockquotes, code blocks, and HTML entities. ``<script>``,
``<style>``, and ``<head>`` sections are stripped entirely.
"""
# Normalize line endings before parsing
source_html = source_html.replace("\r\n", "\n").replace("\r", "\n")
renderer = _MarkdownRenderer()
renderer.feed(source_html)
renderer.close()
renderer.flush_pending()
raw = "".join(renderer._out)
return _cleanup(raw)

View file

@ -0,0 +1,521 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""
Anthropic Messages API OpenAI format translation utilities.
Pure functions and a stateful stream emitter no FastAPI, no I/O.
"""
from __future__ import annotations
import json
from typing import Any, Optional, Union
def anthropic_messages_to_openai(
messages: list[dict],
system: Optional[Union[str, list]] = None,
) -> list[dict]:
"""Convert Anthropic messages + system to OpenAI-format message dicts."""
result: list[dict] = []
# System prompt
if system:
if isinstance(system, str):
result.append({"role": "system", "content": system})
elif isinstance(system, list):
parts = []
for block in system:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(block["text"])
elif isinstance(block, str):
parts.append(block)
if parts:
result.append({"role": "system", "content": "\n".join(parts)})
for msg in messages:
role = msg["role"] if isinstance(msg, dict) else msg.role
content = msg["content"] if isinstance(msg, dict) else msg.content
if isinstance(content, str):
result.append({"role": role, "content": content})
continue
# Content is a list of blocks
text_parts: list[str] = []
tool_calls: list[dict] = []
tool_results: list[dict] = []
for block in content:
b = block if isinstance(block, dict) else block.model_dump()
btype = b.get("type", "")
if btype == "text":
text_parts.append(b["text"])
elif btype == "tool_use":
tool_calls.append(
{
"id": b["id"],
"type": "function",
"function": {
"name": b["name"],
"arguments": json.dumps(b["input"]),
},
}
)
elif btype == "tool_result":
tc = b.get("content", "")
if isinstance(tc, list):
tc = " ".join(
p["text"]
for p in tc
if isinstance(p, dict) and p.get("type") == "text"
)
tool_results.append(
{
"role": "tool",
"tool_call_id": b["tool_use_id"],
"content": str(tc),
}
)
if role == "assistant":
msg_dict: dict[str, Any] = {"role": "assistant"}
if text_parts:
msg_dict["content"] = "\n".join(text_parts)
if tool_calls:
msg_dict["tool_calls"] = tool_calls
result.append(msg_dict)
elif role == "user":
if text_parts:
result.append({"role": "user", "content": "\n".join(text_parts)})
for tr in tool_results:
result.append(tr)
return result
def anthropic_tools_to_openai(tools: list) -> list[dict]:
"""Convert Anthropic tool definitions to OpenAI function-tool format."""
result = []
for t in tools:
td = t if isinstance(t, dict) else t.model_dump()
result.append(
{
"type": "function",
"function": {
"name": td["name"],
"description": td.get("description", ""),
"parameters": td.get("input_schema", {}),
},
}
)
return result
def anthropic_tool_choice_to_openai(tc: Any) -> Any:
"""Translate Anthropic `tool_choice` into OpenAI `tool_choice`.
Anthropic formats (all dict shapes with a ``type`` discriminator):
- ``{"type": "auto"}`` ``"auto"``
- ``{"type": "any"}`` ``"required"``
- ``{"type": "none"}`` ``"none"``
- ``{"type": "tool", "name": "get_weather"}``
``{"type": "function", "function": {"name": "get_weather"}}``
Returns ``None`` for ``None`` or any unrecognized shape (caller may
then fall back to its own default, typically ``"auto"``).
"""
if tc is None:
return None
if not isinstance(tc, dict):
return None
t = tc.get("type")
if t == "auto":
return "auto"
if t == "any":
return "required"
if t == "none":
return "none"
if t == "tool":
name = tc.get("name")
if not name:
return None
return {"type": "function", "function": {"name": name}}
return None
def build_anthropic_sse_event(event_type: str, data: dict) -> str:
"""Format a single Anthropic SSE event."""
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
class AnthropicStreamEmitter:
"""Converts generator events from generate_chat_completion_with_tools()
into Anthropic Messages SSE strings."""
def __init__(self) -> None:
self.block_index: int = 0
self._text_block_open: bool = False
self._prev_text: str = ""
self._usage: dict = {}
def start(self, message_id: str, model: str) -> list[str]:
"""Emit message_start and open the first text content block."""
events = []
events.append(
build_anthropic_sse_event(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": model,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0},
},
},
)
)
events.extend(self._open_text_block())
return events
def feed(self, event: dict) -> list[str]:
"""Process one generator event, return SSE strings."""
etype = event.get("type", "")
if etype == "content":
return self._handle_content(event)
elif etype == "tool_start":
return self._handle_tool_start(event)
elif etype == "tool_end":
return self._handle_tool_end(event)
elif etype == "metadata":
self._usage = event.get("usage", {})
return []
# status events — no Anthropic equivalent
return []
def finish(self, stop_reason: str = "end_turn") -> list[str]:
"""Close any open block and emit message_delta + message_stop."""
events = []
if self._text_block_open:
events.append(self._close_block())
events.append(
build_anthropic_sse_event(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": {
"output_tokens": self._usage.get("completion_tokens", 0),
},
},
)
)
events.append(
build_anthropic_sse_event(
"message_stop",
{
"type": "message_stop",
},
)
)
return events
def _handle_content(self, event: dict) -> list[str]:
cumulative = event.get("text", "")
new_text = cumulative[len(self._prev_text) :]
self._prev_text = cumulative
if not new_text:
return []
if not self._text_block_open:
events = self._open_text_block()
else:
events = []
events.append(
build_anthropic_sse_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": self.block_index,
"delta": {"type": "text_delta", "text": new_text},
},
)
)
return events
def _handle_tool_start(self, event: dict) -> list[str]:
events = []
# Close current text block if open
if self._text_block_open:
events.append(self._close_block())
# Open a tool_use block
self.block_index += 1
events.append(
build_anthropic_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": self.block_index,
"content_block": {
"type": "tool_use",
"id": event.get("tool_call_id", ""),
"name": event.get("tool_name", ""),
"input": {},
},
},
)
)
# Emit the arguments as input_json_delta
args = event.get("arguments", {})
if args:
events.append(
build_anthropic_sse_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": self.block_index,
"delta": {
"type": "input_json_delta",
"partial_json": json.dumps(args),
},
},
)
)
return events
def _handle_tool_end(self, event: dict) -> list[str]:
events = []
# Close the tool_use block
events.append(self._close_block())
# Emit custom tool_result event (non-standard, ignored by SDKs)
events.append(
build_anthropic_sse_event(
"tool_result",
{
"type": "tool_result",
"tool_use_id": event.get("tool_call_id", ""),
"content": event.get("result", ""),
},
)
)
# Open a new text block for the model's next response
self.block_index += 1
events.extend(self._open_text_block())
# Reset text tracking for the next synthesis turn
self._prev_text = ""
return events
def _open_text_block(self) -> list[str]:
self._text_block_open = True
return [
build_anthropic_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": self.block_index,
"content_block": {"type": "text", "text": ""},
},
)
]
def _close_block(self) -> str:
self._text_block_open = False
return build_anthropic_sse_event(
"content_block_stop",
{
"type": "content_block_stop",
"index": self.block_index,
},
)
class AnthropicPassthroughEmitter:
"""Converts llama-server's OpenAI-format streaming chunks into Anthropic SSE.
Used for the client-side tool-use pass-through path: the client (e.g. Claude
Code) sends its own tool definitions in the ``tools`` field and expects to
execute them itself. We forward them to llama-server and translate the
streaming response back to Anthropic format without executing anything.
"""
def __init__(self) -> None:
self.block_index: int = -1
self._current_block_type: Optional[str] = None # "text" | "tool_use" | None
self._tool_call_states: dict = {} # delta index -> {block_index, id, name}
self._usage: dict = {}
self._stop_reason: str = "end_turn"
def start(self, message_id: str, model: str) -> list[str]:
return [
build_anthropic_sse_event(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": model,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0},
},
},
)
]
def feed_chunk(self, chunk: dict) -> list[str]:
"""Process one OpenAI streaming chat.completion.chunk."""
events: list[str] = []
# usage-only chunks carry token totals
usage = chunk.get("usage")
if usage:
self._usage = usage
choices = chunk.get("choices") or []
if not choices:
return events
choice = choices[0]
delta = choice.get("delta") or {}
finish_reason = choice.get("finish_reason")
# ── Text content ──
content = delta.get("content")
if content:
if self._current_block_type != "text":
if self._current_block_type is not None:
events.append(self._close_current_block())
events.extend(self._open_text_block())
events.append(
build_anthropic_sse_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": self.block_index,
"delta": {"type": "text_delta", "text": content},
},
)
)
# ── Tool calls (streaming deltas) ──
tool_calls = delta.get("tool_calls") or []
for tc in tool_calls:
tc_idx = tc.get("index", 0)
fn = tc.get("function") or {}
if tc_idx not in self._tool_call_states:
# New tool call — close prior block, open tool_use block
if self._current_block_type is not None:
events.append(self._close_current_block())
tc_id = tc.get("id", "")
tc_name = fn.get("name", "")
self.block_index += 1
self._current_block_type = "tool_use"
self._tool_call_states[tc_idx] = {
"block_index": self.block_index,
"id": tc_id,
"name": tc_name,
}
events.append(
build_anthropic_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": self.block_index,
"content_block": {
"type": "tool_use",
"id": tc_id,
"name": tc_name,
"input": {},
},
},
)
)
args_delta = fn.get("arguments", "")
if args_delta:
events.append(
build_anthropic_sse_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": self._tool_call_states[tc_idx]["block_index"],
"delta": {
"type": "input_json_delta",
"partial_json": args_delta,
},
},
)
)
# ── Finish reason ──
if finish_reason:
if finish_reason == "tool_calls":
self._stop_reason = "tool_use"
elif finish_reason == "length":
self._stop_reason = "max_tokens"
else:
self._stop_reason = "end_turn"
return events
def finish(self) -> list[str]:
events: list[str] = []
if self._current_block_type is not None:
events.append(self._close_current_block())
events.append(
build_anthropic_sse_event(
"message_delta",
{
"type": "message_delta",
"delta": {
"stop_reason": self._stop_reason,
"stop_sequence": None,
},
"usage": {
"output_tokens": self._usage.get("completion_tokens", 0),
},
},
)
)
events.append(
build_anthropic_sse_event(
"message_stop",
{"type": "message_stop"},
)
)
return events
def _open_text_block(self) -> list[str]:
self.block_index += 1
self._current_block_type = "text"
return [
build_anthropic_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": self.block_index,
"content_block": {"type": "text", "text": ""},
},
)
]
def _close_current_block(self) -> str:
idx = self.block_index
self._current_block_type = None
return build_anthropic_sse_event(
"content_block_stop",
{
"type": "content_block_stop",
"index": idx,
},
)

View file

@ -6,6 +6,15 @@
import utils.hardware.hardware as hw
DEFAULT_MODELS_GGUF = [
"unsloth/gemma-4-E2B-it-GGUF",
"unsloth/gemma-4-E4B-it-GGUF",
"unsloth/gemma-4-31B-it-GGUF",
"unsloth/gemma-4-26B-A4B-it-GGUF",
"unsloth/Qwen3.6-35B-A3B-GGUF",
"unsloth/Qwen3.5-4B-GGUF",
"unsloth/Qwen3.5-9B-GGUF",
"unsloth/Qwen3.5-35B-A3B-GGUF",
"unsloth/Qwen3.5-0.8B-GGUF",
"unsloth/Llama-3.2-1B-Instruct-GGUF",
"unsloth/Llama-3.2-3B-Instruct-GGUF",
"unsloth/Llama-3.1-8B-Instruct-GGUF",
@ -15,6 +24,19 @@ DEFAULT_MODELS_GGUF = [
]
DEFAULT_MODELS_STANDARD = [
"unsloth/gemma-4-E2B-it-GGUF",
"unsloth/gemma-4-E4B-it-GGUF",
"unsloth/gemma-4-31B-it-GGUF",
"unsloth/gemma-4-26B-A4B-it-GGUF",
"unsloth/Qwen3.6-35B-A3B-GGUF",
"unsloth/Qwen3.5-4B-GGUF",
"unsloth/Qwen3.5-9B-GGUF",
"unsloth/Qwen3.5-35B-A3B-GGUF",
"unsloth/Qwen3.5-0.8B-GGUF",
"unsloth/gemma-4-E2B-it",
"unsloth/gemma-4-E4B-it",
"unsloth/gemma-4-31B-it",
"unsloth/gemma-4-26B-A4B-it",
"unsloth/Qwen3-4B-Instruct-2507",
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",

View file

@ -18,7 +18,14 @@ from typing import Optional, Union, Generator, Tuple
from utils.models import ModelConfig, get_base_model_from_lora
from utils.paths import is_model_cached
from utils.utils import format_error_message
from utils.hardware import get_device, clear_gpu_cache, log_gpu_memory
from utils.hardware import (
get_device,
clear_gpu_cache,
log_gpu_memory,
get_device_map,
raise_if_offloaded,
get_visible_gpu_count,
)
from core.inference.audio_codecs import AudioCodecManager
from io import StringIO
import structlog
@ -241,10 +248,15 @@ class InferenceBackend:
load_in_4bit: bool = True,
hf_token: Optional[str] = None,
trust_remote_code: bool = False,
gpu_ids: Optional[list[int]] = None,
) -> bool:
"""
Load any model: base, LoRA adapter, text, or vision.
"""
# GGUF uses max_seq_length=0 as "model default"; Unsloth crashes on it.
if max_seq_length <= 0:
max_seq_length = 2048
try:
model_name = config.identifier
@ -260,6 +272,10 @@ class InferenceBackend:
return False
self.loading_models.add(model_name)
device_map = get_device_map(gpu_ids)
logger.info(
f"Using device_map='{device_map}' ({get_visible_gpu_count()} GPU(s) visible)"
)
self.models[model_name] = {
"is_vision": config.is_vision,
@ -290,6 +306,7 @@ class InferenceBackend:
config.path,
auto_model = CsmForConditionalGeneration,
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -325,6 +342,7 @@ class InferenceBackend:
config.path,
dtype = torch.float32,
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -345,6 +363,7 @@ class InferenceBackend:
llm_path,
dtype = torch.float32,
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -361,6 +380,7 @@ class InferenceBackend:
config.path,
max_seq_length = max_seq_length,
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -378,6 +398,7 @@ class InferenceBackend:
whisper_language = "English",
whisper_task = "transcribe",
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -405,6 +426,7 @@ class InferenceBackend:
model_name = config.path,
max_seq_length = max_seq_length,
load_in_4bit = False,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -420,6 +442,11 @@ class InferenceBackend:
audio_type, self.device, model_repo_path = model_repo_path
)
# Reject CPU/disk offload for audio models too
raise_if_offloaded(
self.models[model_name]["model"], device_map, "Inference"
)
self.active_model_name = model_name
self.loading_models.discard(model_name)
logger.info(f"Successfully loaded audio model: {model_name}")
@ -441,6 +468,7 @@ class InferenceBackend:
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -497,6 +525,7 @@ class InferenceBackend:
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
device_map = device_map,
token = hf_token if hf_token and hf_token.strip() else None,
trust_remote_code = trust_remote_code,
)
@ -507,6 +536,10 @@ class InferenceBackend:
self.models[model_name]["model"] = model
self.models[model_name]["tokenizer"] = tokenizer
raise_if_offloaded(
self.models[model_name]["model"], device_map, "Inference"
)
# Load chat template info
self._load_chat_template_info(model_name)
@ -615,6 +648,7 @@ class InferenceBackend:
dtype = None,
load_in_4bit: bool = True,
hf_token: Optional[str] = None,
gpu_ids: Optional[list[int]] = None,
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
Final Corrected Version:
@ -639,7 +673,12 @@ class InferenceBackend:
base_model_name, None, is_lora = False
)
if not self.load_model(
base_config, max_seq_length, dtype, load_in_4bit, hf_token
base_config,
max_seq_length,
dtype,
load_in_4bit,
hf_token,
gpu_ids = gpu_ids,
):
return False, None, None
@ -927,6 +966,12 @@ class InferenceBackend:
logger.warning(f"Could not apply get_chat_template: {e}")
# Step 2: Format with tokenizer.apply_chat_template()
if system_prompt:
template_messages = [
{"role": "system", "content": system_prompt}
] + messages
else:
template_messages = messages
try:
if not (hasattr(tokenizer, "chat_template") and tokenizer.chat_template):
raise ValueError(
@ -937,7 +982,7 @@ class InferenceBackend:
f"one via tokenizer.chat_template before inference."
)
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize = False, add_generation_prompt = True
template_messages, tokenize = False, add_generation_prompt = True
)
logger.debug(f"Formatted prompt: {formatted_prompt[:200]}...")
except Exception as e:
@ -992,30 +1037,51 @@ class InferenceBackend:
# Prepare vision messages
if image:
vision_messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": user_message},
],
}
]
user_msg = {
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": user_message},
],
}
if system_prompt:
vision_messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
},
user_msg,
]
else:
vision_messages = [user_msg]
input_text = processor.apply_chat_template(
vision_messages, add_generation_prompt = True, tokenize = False
)
try:
input_text = processor.apply_chat_template(
vision_messages, add_generation_prompt = True, tokenize = False
)
except Exception as e:
if system_prompt:
logger.warning(
f"Vision processor for '{self.active_model_name}' may not support "
f"system messages; retrying without. Original error: {e}"
)
vision_messages = [user_msg]
input_text = processor.apply_chat_template(
vision_messages, add_generation_prompt = True, tokenize = False
)
else:
raise
inputs = processor(
image,
input_text,
add_special_tokens = False,
return_tensors = "pt",
).to(self.device)
).to(model.device)
else:
# Text-only for vision model
formatted_prompt = self.format_chat_prompt(messages, system_prompt)
inputs = raw_tokenizer(formatted_prompt, return_tensors = "pt").to(
self.device
model.device
)
# Stream with TextIteratorStreamer + background thread
@ -1155,7 +1221,7 @@ class InferenceBackend:
return_dict = True,
return_tensors = "pt",
truncation = False,
).to(self.device)
).to(model.device)
try:
from transformers import TextIteratorStreamer

File diff suppressed because it is too large Load diff

View file

@ -17,6 +17,7 @@ Pattern follows core/training/training.py.
import atexit
import base64
import os
import structlog
from loggers import get_logger
import multiprocessing as mp
@ -27,11 +28,17 @@ import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Generator, Optional, Tuple, Union
from utils.hardware import prepare_gpu_selection
logger = get_logger(__name__)
_CTX = mp.get_context("spawn")
class DownloadStallError(RuntimeError):
"""Raised when the worker reports no download progress for too long."""
# Dispatcher timeout constants (seconds)
_DISPATCH_READ_TIMEOUT = 30.0
_DISPATCH_POLL_INTERVAL = 0.5
@ -102,12 +109,13 @@ class InferenceOrchestrator:
self._top_models_ready.wait(timeout = 5)
top_gguf = self._top_gguf_cache or []
top_hub = self._top_hub_cache or []
# GGUFs first, then hub models, then static fallbacks.
# Curated static defaults first (editorial picks like new models),
# then HF download-ranked models to backfill.
# Send extras so the frontend still has 4 per category
# after removing already-downloaded models.
result: list[str] = []
seen: set[str] = set()
for m in top_gguf + top_hub + self._static_models:
for m in self._static_models + top_gguf + top_hub:
if m not in seen:
result.append(m)
seen.add(m)
@ -262,12 +270,17 @@ class InferenceOrchestrator:
except (EOFError, OSError, ValueError):
return None
def _wait_response(self, expected_type: str, timeout: float = 120.0) -> dict:
def _wait_response(self, expected_type: str, timeout: float = 300.0) -> dict:
"""Block until a response of the expected type arrives.
Also handles 'status' and 'error' events during the wait.
Returns the matching response dict.
Raises RuntimeError on timeout or subprocess crash.
The *timeout* is an **inactivity** timeout: it resets whenever the
subprocess sends a status message, so long-running operations (large
downloads, slow model loads) won't be killed as long as the subprocess
keeps reporting progress.
"""
deadline = time.monotonic() + timeout
@ -292,8 +305,15 @@ class InferenceOrchestrator:
if rtype == "status":
logger.info("Subprocess status: %s", resp.get("message", ""))
# Reset deadline — subprocess is still alive and working
deadline = time.monotonic() + timeout
continue
if rtype == "stall":
msg = resp.get("message", "Download stalled")
logger.warning("Subprocess reported stall: %s", msg)
raise DownloadStallError(msg)
# Other response types during wait — skip
logger.debug(
"Skipping response type '%s' while waiting for '%s'",
@ -302,7 +322,8 @@ class InferenceOrchestrator:
)
raise RuntimeError(
f"Timeout waiting for '{expected_type}' response after {timeout}s"
f"Timeout waiting for '{expected_type}' response "
f"(no activity for {timeout}s)"
)
def _drain_queue(self) -> list:
@ -571,6 +592,7 @@ class InferenceOrchestrator:
load_in_4bit: bool = True,
hf_token: Optional[str] = None,
trust_remote_code: bool = False,
gpu_ids: Optional[list[int]] = None,
) -> bool:
"""Load a model for inference.
@ -594,7 +616,16 @@ class InferenceOrchestrator:
"hf_token": hf_token or "",
"gguf_variant": getattr(config, "gguf_variant", None),
"trust_remote_code": trust_remote_code,
"gpu_ids": gpu_ids,
}
resolved_gpu_ids, gpu_selection = prepare_gpu_selection(
gpu_ids,
model_name = model_name,
hf_token = hf_token,
load_in_4bit = load_in_4bit,
)
sub_config["resolved_gpu_ids"] = resolved_gpu_ids
sub_config["gpu_selection"] = gpu_selection
# Always kill existing subprocess and spawn fresh.
# Reusing a subprocess after unsloth patches torch internals
@ -608,36 +639,66 @@ class InferenceOrchestrator:
# Dead subprocess — clean up
self._shutdown_subprocess(timeout = 2)
logger.info(
"Spawning fresh inference subprocess for '%s' (transformers %s.x)",
model_name,
needed_major,
disable_xet = sub_config.get("disable_xet", False) or (
os.environ.get("HF_HUB_DISABLE_XET") == "1"
)
self._spawn_subprocess(sub_config)
resp = self._wait_response("loaded", timeout = 180)
# Update local state from response
if resp.get("success"):
self._current_transformers_major = needed_major
model_info = resp.get("model_info", {})
self.active_model_name = model_info.get("identifier", model_name)
self.models[self.active_model_name] = {
"is_vision": model_info.get("is_vision", False),
"is_lora": model_info.get("is_lora", False),
"display_name": model_info.get("display_name", model_name),
"is_audio": model_info.get("is_audio", False),
"audio_type": model_info.get("audio_type"),
"has_audio_input": model_info.get("has_audio_input", False),
}
self.loading_models.discard(model_name)
logger.info("Model '%s' loaded successfully in subprocess", model_name)
return True
else:
error = resp.get("error", "Failed to load model")
self.loading_models.discard(model_name)
self.active_model_name = None
self.models.clear()
raise Exception(error)
for attempt in range(2):
logger.info(
"Spawning fresh inference subprocess for '%s' "
"(transformers %s.x, attempt %d/2%s)",
model_name,
needed_major,
attempt + 1,
", xet disabled" if disable_xet else "",
)
sub_config["disable_xet"] = disable_xet
self._spawn_subprocess(sub_config)
try:
resp = self._wait_response("loaded")
except DownloadStallError:
# First stall and Xet was enabled -> retry with Xet disabled
if attempt == 0 and not disable_xet:
logger.warning(
"Download stalled for '%s' -- retrying with "
"HF_HUB_DISABLE_XET=1",
model_name,
)
self._shutdown_subprocess(timeout = 5)
disable_xet = True
continue
# Second stall (or already had xet disabled) -> give up
self._shutdown_subprocess(timeout = 5)
raise RuntimeError(
f"Download stalled for '{model_name}' even with "
f"HF_HUB_DISABLE_XET=1 -- check your network connection"
)
# Got a response — check success
if resp.get("success"):
self._current_transformers_major = needed_major
model_info = resp.get("model_info", {})
self.active_model_name = model_info.get("identifier", model_name)
self.models[self.active_model_name] = {
"is_vision": model_info.get("is_vision", False),
"is_lora": model_info.get("is_lora", False),
"display_name": model_info.get("display_name", model_name),
"is_audio": model_info.get("is_audio", False),
"audio_type": model_info.get("audio_type"),
"has_audio_input": model_info.get("has_audio_input", False),
}
self.loading_models.discard(model_name)
logger.info(
"Model '%s' loaded successfully in subprocess", model_name
)
return True
else:
error = resp.get("error", "Failed to load model")
self.loading_models.discard(model_name)
self.active_model_name = None
self.models.clear()
raise Exception(error)
except Exception:
self.loading_models.discard(model_name)
@ -661,7 +722,7 @@ class InferenceOrchestrator:
"model_name": model_name,
}
)
resp = self._wait_response("unloaded", timeout = 30)
resp = self._wait_response("unloaded")
# Update local state
self.models.pop(model_name, None)

View file

@ -8,25 +8,260 @@ Supports web search (DuckDuckGo), Python code execution, and terminal commands.
"""
import ast
import http.client
import os
os.environ["UNSLOTH_IS_PRESENT"] = "1"
import random
import re
import shlex
import ssl
import subprocess
import sys
import tempfile
import threading
import urllib.request
from loggers import get_logger
logger = get_logger(__name__)
_EXEC_TIMEOUT = 300 # 5 minutes
# Pre-import modules used in _sandbox_preexec at module level so that
# the preexec_fn closure does not trigger the import machinery in the
# forked child (which can deadlock in multi-threaded servers).
_libc = None
if sys.platform == "linux":
try:
import ctypes
import ctypes.util
_libc_name = ctypes.util.find_library("c")
if _libc_name:
_libc = ctypes.CDLL(_libc_name, use_errno = True)
except (OSError, AttributeError):
pass
_resource = None
if sys.platform != "win32":
try:
import resource as _resource
except ImportError:
pass
# Strict raster-image allowlist for sandbox file serving.
# No .svg (XSS risk via embedded scripts), no .html, no .pdf.
_IMAGE_EXTS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"})
_MAX_OUTPUT_CHARS = 8000 # truncate long output
_BASH_BLOCKED_WORDS = {"rm", "sudo", "dd", "chmod", "mkfs", "shutdown", "reboot"}
_BLOCKED_COMMANDS_COMMON = frozenset(
{
"rm",
"sudo",
"su",
"dd",
"chmod",
"chown",
"mkfs",
"shutdown",
"reboot",
"passwd",
"mount",
"umount",
"fdisk",
"kill",
"killall",
"pkill",
}
)
_BLOCKED_COMMANDS_WIN = frozenset(
{
"rmdir",
"takeown",
"icacls",
"runas",
"powershell",
"pwsh",
}
)
_BLOCKED_COMMANDS = (
_BLOCKED_COMMANDS_COMMON | _BLOCKED_COMMANDS_WIN
if sys.platform == "win32"
else _BLOCKED_COMMANDS_COMMON
)
def _find_blocked_commands(command: str) -> set[str]:
"""Detect blocked commands using shlex tokenization and regex scanning.
Catches: full paths (/usr/bin/sudo), quoted strings ("sudo"),
split-quotes (su""do), backslash escapes (\\rm), and command-position
words after ;, |, &&, $().
"""
blocked = set()
# 1. shlex tokenization (handles quotes, escapes, concatenation)
try:
tokens = (
shlex.split(command)
if sys.platform != "win32"
else shlex.split(command, posix = False)
)
except ValueError:
tokens = command.split()
for token in tokens:
base = os.path.basename(token).lower()
# Strip common Windows executable extensions so that
# runas.exe, shutdown.bat, etc. match the blocklist.
stem, ext = os.path.splitext(base)
if ext in {".exe", ".com", ".bat", ".cmd"}:
base = stem
if base in _BLOCKED_COMMANDS:
blocked.add(base)
# 2. Regex: catch blocked words at shell command boundaries
# (semicolons, pipes, &&, ||, backticks, $(), <(), subshells, newlines)
# Uses a single combined pattern for all blocked words.
# Handles optional Unix path prefix (/usr/bin/) and Windows drive
# letter prefix (C:\Windows\...\).
lowered = command.lower()
if _BLOCKED_COMMANDS:
words_alt = "|".join(re.escape(w) for w in sorted(_BLOCKED_COMMANDS))
pattern = (
rf"(?:^|[;&|`\n(]\s*|[$]\(\s*|<\(\s*)"
rf"(?:[\w./\\-]*/|[a-zA-Z]:[/\\][\w./\\-]*)?"
rf"({words_alt})(?:\.(?:exe|com|bat|cmd))?\b"
)
blocked.update(re.findall(pattern, lowered))
# 3. Check for nested shell invocations (bash -c 'sudo whoami',
# bash -lc '...', bash --login -c '...', cmd /c '...').
# When a -c or /c flag is found, look backwards for a shell name
# (skipping intermediate flags like --login, -l, -x) and recursively
# scan the nested command string.
_SHELLS = {"bash", "sh", "zsh", "dash", "ksh", "csh", "tcsh", "fish"}
_SHELLS_WIN = {"cmd", "cmd.exe"}
for i, token in enumerate(tokens):
tok_lower = token.lower()
# Match -c exactly, or combined flags ending in c (e.g. -lc, -xc)
is_unix_c = tok_lower == "-c" or (
tok_lower.startswith("-")
and tok_lower.endswith("c")
and not tok_lower.startswith("--")
)
is_win_c = tok_lower == "/c"
if not (is_unix_c or is_win_c) or i < 1 or i + 1 >= len(tokens):
continue
# Look backwards past any flags to find the shell binary.
# On Unix, flags start with - (skip those). On Windows, flags
# start with / but so do absolute paths, so only skip short
# single-char /X flags (not /bin/bash style paths).
for j in range(i - 1, -1, -1):
prev = tokens[j]
if prev.startswith("-"):
continue # skip Unix flags like --login, -l
if is_win_c and prev.startswith("/") and len(prev) <= 3:
continue # skip Windows flags like /s, /q (not /bin/bash)
prev_base = os.path.basename(prev).lower()
if is_unix_c and prev_base in _SHELLS:
blocked |= _find_blocked_commands(tokens[i + 1])
elif is_win_c and prev_base in _SHELLS_WIN:
blocked |= _find_blocked_commands(tokens[i + 1])
break # stop at first non-flag token
return blocked
def _build_safe_env(workdir: str) -> dict[str, str]:
"""Build a minimal, credential-free environment for sandboxed subprocesses.
Strips HF_TOKEN, WANDB_API_KEY, AWS_*, GH_TOKEN, LD_PRELOAD, DYLD_*, etc.
Preserves the active Python interpreter and virtualenv directories in PATH
so that pip, uv, and packages installed in the Studio runtime remain
accessible.
"""
# Start with the directory containing the running Python interpreter
# so that subprocess calls to 'python', 'pip', etc. resolve to the
# same environment the Studio server is running in.
exe_dir = os.path.dirname(sys.executable)
path_entries = [exe_dir] if exe_dir else []
# If a virtualenv is active, include its bin/Scripts directory.
venv = os.environ.get("VIRTUAL_ENV")
if venv:
venv_bin = os.path.join(venv, "Scripts" if sys.platform == "win32" else "bin")
if venv_bin not in path_entries:
path_entries.append(venv_bin)
if sys.platform == "win32":
sysroot = os.environ.get("SystemRoot", r"C:\Windows")
path_entries.extend([os.path.join(sysroot, "System32"), sysroot])
else:
path_entries.extend(["/usr/local/bin", "/usr/bin", "/bin"])
# Deduplicate while preserving order
deduped = list(dict.fromkeys(p for p in path_entries if p))
env = {
"PATH": os.pathsep.join(deduped),
"HOME": workdir,
"TMPDIR": workdir,
"LANG": os.environ.get("LANG", "C.UTF-8"),
"TERM": "dumb",
"PYTHONIOENCODING": "utf-8",
}
if venv:
env["VIRTUAL_ENV"] = venv
# Windows needs SystemRoot for Python/subprocess to work
if sys.platform == "win32":
env["SystemRoot"] = os.environ.get("SystemRoot", r"C:\Windows")
return env
def _sandbox_preexec():
"""Pre-exec hook: drop privilege escalation ability and set resource limits.
On Linux, applies PR_SET_NO_NEW_PRIVS so sudo/su/pkexec fail at the
kernel level. On Linux and macOS, sets RLIMIT_FSIZE.
No-op on Windows (use creationflags instead).
Note: RLIMIT_NPROC is intentionally NOT set because Linux enforces it
per real UID, not per process tree, so it would starve the Studio
server and other sessions sharing the same user account.
All modules and handles are resolved at import time (module level) so
this function does not trigger Python imports in the forked child,
avoiding potential deadlocks in multi-threaded servers.
"""
if _libc is not None:
try:
# PR_SET_NO_NEW_PRIVS = 38, arg2 = 1 (enable)
_libc.prctl(38, 1, 0, 0, 0)
except (OSError, AttributeError):
pass # Not available (container, old kernel, etc.)
if _resource is not None:
try:
# Limit file size to 100MB (prevents disk filling)
_resource.setrlimit(
_resource.RLIMIT_FSIZE, (100 * 1024 * 1024, 100 * 1024 * 1024)
)
except (ValueError, OSError):
pass
def _get_shell_cmd(command: str) -> list[str]:
"""Return the platform-appropriate shell invocation for a command string."""
if sys.platform == "win32":
return ["cmd", "/c", command]
return ["bash", "-c", command]
# Per-session working directories so each chat thread gets its own sandbox.
# Falls back to a shared ~/studio_sandbox/ for API callers without a session_id.
# Falls back to a shared ~/studio_sandbox/_default for API callers without a
# session_id.
_workdirs: dict[str, str] = {}
@ -47,7 +282,7 @@ def _get_workdir(session_id: str | None = None) -> str:
if not os.path.realpath(workdir).startswith(os.path.realpath(sandbox_root)):
workdir = os.path.join(sandbox_root, "_invalid")
else:
workdir = sandbox_root
workdir = os.path.join(sandbox_root, "_default")
os.makedirs(workdir, exist_ok = True)
_workdirs[key] = workdir
return _workdirs[key]
@ -57,16 +292,23 @@ WEB_SEARCH_TOOL = {
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web for current information, recent events, or facts you are uncertain about.",
"description": (
"Search the web and fetch page content. Returns snippets for all results. "
"Use the url parameter to fetch full page text from a specific URL."
),
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query",
}
},
"url": {
"type": "string",
"description": "A URL to fetch full page content from (instead of searching). Use this to read a page found in search results.",
},
},
"required": ["query"],
"required": [],
},
},
}
@ -131,7 +373,11 @@ def execute_tool(
)
effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout
if name == "web_search":
return _web_search(arguments.get("query", ""), timeout = effective_timeout)
return _web_search(
arguments.get("query", ""),
url = arguments.get("url"),
timeout = effective_timeout,
)
if name == "python":
return _python_exec(
arguments.get("code", ""), cancel_event, effective_timeout, session_id
@ -143,9 +389,226 @@ def execute_tool(
return f"Unknown tool: {name}"
def _web_search(query: str, max_results: int = 5, timeout: int = _EXEC_TIMEOUT) -> str:
"""Search the web using DuckDuckGo and return formatted results."""
if not query.strip():
_MAX_PAGE_CHARS = 16000 # limit fetched page text (after HTML-to-MD conversion)
# Raw download cap. Must be larger than _MAX_PAGE_CHARS because SSR pages
# embed large <head> sections (CSS, JS, SVGs) that are stripped during
# HTML-to-Markdown conversion. 512 KB is enough to reach article content
# on GitBook / Next.js / Docusaurus pages whose <head> alone can be 200 KB.
_MAX_FETCH_BYTES = 512 * 1024
_USER_AGENTS = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:133.0) Gecko/20100101 Firefox/133.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:133.0) Gecko/20100101 Firefox/133.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/18.2 Safari/605.1.15",
)
_tls_ctx = ssl.create_default_context()
class _NoRedirect(urllib.request.HTTPRedirectHandler):
def redirect_request(self, req, fp, code, msg, headers, newurl):
return None
class _PinnedHTTPSConnection(http.client.HTTPSConnection):
"""HTTPS connection that connects to a pinned IP but uses a different
hostname for SNI and certificate verification.
The SSRF IP-pinning rewrites URLs to raw IPs. A normal HTTPSConnection
would then send no SNI and verify the cert against the IP, both of which
fail. This subclass splits the two concerns: TCP connects to the pinned
IP (``host`` parameter) while TLS uses ``sni_hostname`` for the
ClientHello and cert check.
"""
def __init__(self, host: str, *, sni_hostname: str, **kwargs):
super().__init__(host, **kwargs)
self._sni_hostname = sni_hostname
def connect(self):
# TCP connect to the pinned IP stored in self.host (+ tunnel if
# a proxy is configured via set_tunnel, though we do not use one).
http.client.HTTPConnection.connect(self)
# TLS handshake with the real hostname for SNI + cert verification.
self.sock = self._context.wrap_socket(
self.sock,
server_hostname = self._sni_hostname,
)
class _SNIHTTPSHandler(urllib.request.HTTPSHandler):
"""HTTPS handler that sends the correct SNI hostname during TLS handshake.
The SSRF IP-pinning rewrites URLs to raw IPs, which breaks SNI and cert
verification. This handler returns a ``_PinnedHTTPSConnection`` that
connects to the pinned IP but verifies TLS against the original hostname.
"""
def __init__(self, hostname: str):
super().__init__(context = _tls_ctx)
self._sni_hostname = hostname
def https_open(self, req):
return self.do_open(self._sni_connection, req)
def _sni_connection(self, host, **kwargs):
kwargs["context"] = _tls_ctx
return _PinnedHTTPSConnection(host, sni_hostname = self._sni_hostname, **kwargs)
def _validate_and_resolve_host(hostname: str, port: int) -> tuple[bool, str, str]:
"""Resolve *hostname*, reject non-public IPs, return a pinned IP string.
Returns ``(ok, reason_or_empty, resolved_ip)``. The caller should
connect to *resolved_ip* (with a ``Host`` header) to prevent DNS
rebinding between validation and the actual fetch.
"""
import ipaddress
import socket
try:
infos = socket.getaddrinfo(hostname, port, type = socket.SOCK_STREAM)
except OSError as e:
return False, f"Failed to resolve host: {e}", ""
if not infos:
return False, f"Failed to resolve host: no addresses for {hostname!r}", ""
for *_, sockaddr in infos:
ip = ipaddress.ip_address(sockaddr[0])
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
):
return False, f"Blocked: refusing to fetch non-public address {ip}.", ""
# Return the first resolved address for pinning
first_ip = infos[0][4][0]
return True, "", first_ip
def _fetch_page_text(
url: str, max_chars: int = _MAX_PAGE_CHARS, timeout: int = 30
) -> str:
"""Fetch a URL and return plain text content (HTML tags stripped).
Blocks private/loopback/link-local targets (SSRF protection) and caps
the download size to avoid unbounded memory usage.
"""
from urllib.parse import urlparse
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return f"Blocked: only http/https URLs are allowed (got {parsed.scheme!r})."
if not parsed.hostname:
return "Blocked: URL is missing a hostname."
port = parsed.port or (443 if parsed.scheme == "https" else 80)
ok, reason, pinned_ip = _validate_and_resolve_host(parsed.hostname, port)
if not ok:
return reason
try:
from urllib.error import HTTPError as _HTTPError
from urllib.parse import urljoin, urlunparse
max_bytes = _MAX_FETCH_BYTES
current_url = url
current_host = parsed.hostname
ua = random.choice(_USER_AGENTS)
for _hop in range(5):
# Pin to the validated IP to prevent DNS rebinding.
# Rewrite the URL to use the IP and set the Host header.
cp = urlparse(current_url)
# Bracket IPv6 addresses so the netloc is valid in a URL.
ip_str = f"[{pinned_ip}]" if ":" in pinned_ip else pinned_ip
ip_netloc = f"{ip_str}:{cp.port}" if cp.port else ip_str
pinned_url = urlunparse(cp._replace(netloc = ip_netloc))
opener = urllib.request.build_opener(
_NoRedirect,
_SNIHTTPSHandler(current_host),
)
req = urllib.request.Request(
pinned_url,
headers = {
"User-Agent": ua,
"Host": current_host,
},
)
try:
resp = opener.open(req, timeout = timeout)
except _HTTPError as e:
if e.code not in (301, 302, 303, 307, 308):
return (
f"Failed to fetch URL: HTTP {e.code} {getattr(e, 'reason', '')}"
)
location = e.headers.get("Location")
if not location:
return "Failed to fetch URL: redirect missing Location header."
current_url = urljoin(current_url, location)
rp = urlparse(current_url)
if rp.scheme not in ("http", "https") or not rp.hostname:
return "Blocked: redirect target is not a valid http/https URL."
rp_port = rp.port or (443 if rp.scheme == "https" else 80)
ok2, reason2, pinned_ip = _validate_and_resolve_host(
rp.hostname,
rp_port,
)
if not ok2:
return reason2
current_host = rp.hostname
continue
# Success -- read capped body
raw_bytes = resp.read(max_bytes)
break
else:
return "Failed to fetch URL: too many redirects."
charset = resp.headers.get_content_charset() or "utf-8"
raw_html = raw_bytes.decode(charset, errors = "replace")
except _HTTPError as e:
return f"Failed to fetch URL: HTTP {e.code} {getattr(e, 'reason', '')}"
except Exception as e:
return f"Failed to fetch URL: {e}"
# Convert HTML to Markdown using the builtin converter (no external deps)
from ._html_to_md import html_to_markdown
text = html_to_markdown(raw_html)
if not text:
return "(page returned no readable text)"
if len(text) > max_chars:
text = text[:max_chars] + f"\n\n... (truncated, {len(text)} chars total)"
return text
def _web_search(
query: str,
max_results: int = 5,
timeout: int = _EXEC_TIMEOUT,
url: str | None = None,
) -> str:
"""Search the web using DuckDuckGo and return formatted results.
If ``url`` is provided, fetches that page directly instead of searching.
"""
# Direct URL fetch mode
if url and url.strip():
fetch_timeout = 60 if timeout is None else min(timeout, 60)
return _fetch_page_text(url.strip(), timeout = fetch_timeout)
if not query or not query.strip():
return "No query provided."
try:
from ddgs import DDGS
@ -160,7 +623,13 @@ def _web_search(query: str, max_results: int = 5, timeout: int = _EXEC_TIMEOUT)
f"URL: {r.get('href', '')}\n"
f"Snippet: {r.get('body', '')}"
)
return "\n\n---\n\n".join(parts)
text = "\n\n---\n\n".join(parts)
text += (
"\n\n---\n\nIMPORTANT: These are only short snippets. "
"To get the full page content, call web_search with "
'the url parameter (e.g. {"url": "<URL>"}).'
)
return text
except Exception as e:
return f"Search failed: {e}"
@ -186,6 +655,7 @@ def _check_signal_escape_patterns(code: str):
signal_tampering = []
exception_catching = []
shell_escapes = []
warnings = []
def _ast_name_matches(node, names):
@ -203,10 +673,84 @@ def _check_signal_escape_patterns(code: str):
return full_name in names
return False
# Dangerous os/subprocess functions that can execute shell commands
_SHELL_EXEC_FUNCS = frozenset(
{
"os.system",
"os.popen",
"os.popen2",
"os.popen3",
"os.popen4",
"os.execl",
"os.execle",
"os.execlp",
"os.execlpe",
"os.execv",
"os.execve",
"os.execvp",
"os.execvpe",
"os.spawnl",
"os.spawnle",
"os.spawnlp",
"os.spawnlpe",
"os.spawnv",
"os.spawnve",
"os.spawnvp",
"os.spawnvpe",
"os.posix_spawn",
"os.posix_spawnp",
"subprocess.run",
"subprocess.call",
"subprocess.check_call",
"subprocess.check_output",
"subprocess.Popen",
"subprocess.getoutput",
"subprocess.getstatusoutput",
}
)
def _extract_string_from_node(node):
"""Extract a plain string value from an AST node, if it is a constant."""
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value
return None
def _extract_strings_from_list(node):
"""Extract string elements from an AST List or Tuple node."""
if isinstance(node, (ast.List, ast.Tuple)):
parts = []
for elt in node.elts:
s = _extract_string_from_node(elt)
if s is not None:
parts.append(s)
return parts
return []
# Keyword argument names that carry command content (as opposed to
# control flags like check=True, text=True, capture_output=True).
_CMD_KWARGS = frozenset({"args", "command", "executable", "path", "file"})
def _check_args_for_blocked(args_nodes):
"""Check if any call arguments contain blocked commands."""
found = set()
for arg in args_nodes:
s = _extract_string_from_node(arg)
if s is not None:
found |= _find_blocked_commands(s)
strs = _extract_strings_from_list(arg)
for s in strs:
found |= _find_blocked_commands(s)
return found
class SignalEscapeVisitor(ast.NodeVisitor):
def __init__(self):
self.imports_signal = False
self.signal_aliases = {"signal"}
self.os_aliases = {"os"}
self.subprocess_aliases = {"subprocess"}
# Maps bare function names to their fully-qualified form
# for from-import tracking (e.g. "system" -> "os.system")
self.shell_exec_aliases: dict[str, str] = {}
self.loop_depth = 0
def visit_Import(self, node):
@ -215,6 +759,10 @@ def _check_signal_escape_patterns(code: str):
self.imports_signal = True
if alias.asname:
self.signal_aliases.add(alias.asname)
elif alias.name == "os":
self.os_aliases.add(alias.asname or "os")
elif alias.name == "subprocess":
self.subprocess_aliases.add(alias.asname or "subprocess")
self.generic_visit(node)
def visit_ImportFrom(self, node):
@ -232,6 +780,16 @@ def _check_signal_escape_patterns(code: str):
"alarm",
):
self.signal_aliases.add(alias.asname or alias.name)
elif node.module in ("os", "subprocess"):
if node.module == "os":
self.os_aliases.add("os")
else:
self.subprocess_aliases.add("subprocess")
# Track from-imports of dangerous functions
for alias in node.names:
fq = f"{node.module}.{alias.name}"
if fq in _SHELL_EXEC_FUNCS:
self.shell_exec_aliases[alias.asname or alias.name] = fq
self.generic_visit(node)
def visit_While(self, node):
@ -296,6 +854,111 @@ def _check_signal_escape_patterns(code: str):
"description": "Modifies signal mask (may block SIGALRM)",
}
)
# --- Shell escape detection ---
# Resolve the fully qualified function name for os.*/subprocess.*
shell_func = None
if isinstance(func, ast.Attribute):
if isinstance(func.value, ast.Name):
if func.value.id in self.os_aliases:
shell_func = f"os.{func.attr}"
elif func.value.id in self.subprocess_aliases:
shell_func = f"subprocess.{func.attr}"
elif isinstance(func, ast.Name):
# Check from-import aliases: from os import system; system(...)
shell_func = self.shell_exec_aliases.get(func.id)
if shell_func and shell_func in _SHELL_EXEC_FUNCS:
# Expand **kwargs dicts to inspect their keys
expanded_kwargs: dict[str, ast.AST] = {}
has_opaque_kwargs = False
for kw in node.keywords:
if kw.arg is not None:
expanded_kwargs[kw.arg] = kw.value
elif isinstance(kw.value, ast.Dict):
for k, v in zip(kw.value.keys, kw.value.values):
key = _extract_string_from_node(k) if k else None
if key is not None:
expanded_kwargs[key] = v
else:
has_opaque_kwargs = True
cmd_kw_values = [
v for k, v in expanded_kwargs.items() if k in _CMD_KWARGS
]
all_call_args = list(node.args) + cmd_kw_values
blocked_in_args = _check_args_for_blocked(all_call_args)
if has_opaque_kwargs:
# Can't inspect dynamic **kwargs -- flag as unsafe
shell_escapes.append(
{
"type": "shell_escape_dynamic",
"line": node.lineno,
"description": (
f"{shell_func}() called with dynamic **kwargs"
),
}
)
elif blocked_in_args:
shell_escapes.append(
{
"type": "shell_escape",
"line": node.lineno,
"description": (
f"{shell_func}() invokes blocked command(s): "
f"{', '.join(sorted(blocked_in_args))}"
),
}
)
else:
# Only flag dynamic args for functions that interpret
# strings as shell commands, or when shell= might be
# enabled. Treat any non-literal-False shell= value
# as potentially True (conservative).
_STRING_SHELL_FUNCS = frozenset(
{
"os.system",
"os.popen",
"os.popen2",
"os.popen3",
"os.popen4",
"subprocess.getoutput",
"subprocess.getstatusoutput",
}
)
shell_node = expanded_kwargs.get("shell")
shell_safe = shell_node is None or (
isinstance(shell_node, ast.Constant)
and shell_node.value is False
)
if shell_func in _STRING_SHELL_FUNCS or not shell_safe:
def _is_safe_literal(n):
if _extract_string_from_node(n) is not None:
return True
if isinstance(n, (ast.List, ast.Tuple)):
return all(
_extract_string_from_node(e) is not None
for e in n.elts
)
return False
has_non_literal = any(
not _is_safe_literal(a) for a in all_call_args
)
if has_non_literal:
shell_escapes.append(
{
"type": "shell_escape_dynamic",
"line": node.lineno,
"description": (
f"{shell_func}() called with non-literal "
f"shell command (potential shell escape)"
),
}
)
self.generic_visit(node)
def visit_ExceptHandler(self, node):
@ -311,7 +974,12 @@ def _check_signal_escape_patterns(code: str):
}
)
elif isinstance(node.type, ast.Name):
if node.type.id in ("TimeoutError", "BaseException", "Exception"):
# Only flag BaseException and TimeoutError, NOT Exception.
# except Exception does not catch SystemExit or
# KeyboardInterrupt, so it cannot suppress timeout
# enforcement. Flagging Exception causes false positives
# on normal error-handling patterns.
if node.type.id in ("TimeoutError", "BaseException"):
exception_catching.append(
{
"type": f"catches_{node.type.id}_in_loop",
@ -322,7 +990,7 @@ def _check_signal_escape_patterns(code: str):
elif isinstance(node.type, ast.Tuple):
for elt in node.type.elts:
if isinstance(elt, ast.Name):
if elt.id in ("TimeoutError", "BaseException", "Exception"):
if elt.id in ("TimeoutError", "BaseException"):
exception_catching.append(
{
"type": f"catches_{elt.id}_in_loop",
@ -338,10 +1006,15 @@ def _check_signal_escape_patterns(code: str):
if visitor.imports_signal and not signal_tampering:
warnings.append("Code imports 'signal' module - review manually for safety")
is_safe = len(signal_tampering) == 0 and len(exception_catching) == 0
is_safe = (
len(signal_tampering) == 0
and len(exception_catching) == 0
and len(shell_escapes) == 0
)
return is_safe, {
"signal_tampering": signal_tampering,
"exception_catching": exception_catching,
"shell_escapes": shell_escapes,
"warnings": warnings,
}
@ -353,13 +1026,27 @@ def _check_code_safety(code: str) -> str | None:
"""
safe, info = _check_signal_escape_patterns(code)
if not safe:
# SyntaxError from ast.parse -- let these through so the subprocess
# produces a normal Python traceback instead of a misleading
# "unsafe code detected" message.
if info.get("error"):
return None
reasons = [
item.get("description", "") for item in info.get("signal_tampering", [])
]
return (
f"Error: unsafe code detected ({'; '.join(reasons)}). "
f"Please remove signal manipulation from your code."
)
shell_reasons = [
item.get("description", "") for item in info.get("shell_escapes", [])
]
exception_reasons = [
item.get("description", "") for item in info.get("exception_catching", [])
]
all_reasons = [r for r in reasons + shell_reasons + exception_reasons if r]
if all_reasons:
return (
f"Error: unsafe code detected ({'; '.join(all_reasons)}). "
f"Please remove unsafe patterns from your code."
)
return None
@ -396,6 +1083,17 @@ def _python_exec(
tmp_path = None
workdir = _get_workdir(session_id)
# Snapshot image mtimes so we detect both new and overwritten files.
_before: dict[str, int] = {}
if os.path.isdir(workdir):
for _name in os.listdir(workdir):
if os.path.splitext(_name)[1].lower() in _IMAGE_EXTS:
_p = os.path.join(workdir, _name)
if os.path.isfile(_p):
try:
_before[_name] = os.stat(_p).st_mtime_ns
except OSError:
pass
try:
fd, tmp_path = tempfile.mkstemp(
suffix = ".py", prefix = "studio_exec_", dir = workdir
@ -403,13 +1101,20 @@ def _python_exec(
with os.fdopen(fd, "w") as f:
f.write(code)
proc = subprocess.Popen(
[sys.executable, tmp_path],
safe_env = _build_safe_env(workdir)
popen_kwargs = dict(
stdout = subprocess.PIPE,
stderr = subprocess.STDOUT,
text = True,
cwd = workdir,
env = safe_env,
)
if sys.platform != "win32":
popen_kwargs["preexec_fn"] = _sandbox_preexec
else:
popen_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
proc = subprocess.Popen([sys.executable, tmp_path], **popen_kwargs)
# Spawn cancel watcher if we have a cancel event
if cancel_event is not None:
@ -431,7 +1136,29 @@ def _python_exec(
result = output or ""
if proc.returncode != 0:
result = f"Exit code {proc.returncode}:\n{result}"
return _truncate(result) if result.strip() else "(no output)"
result = _truncate(result) if result.strip() else "(no output)"
# Detect new or overwritten image files and append sentinel for frontend
if session_id and os.path.isdir(workdir):
new_images = []
for _name in os.listdir(workdir):
if os.path.splitext(_name)[1].lower() not in _IMAGE_EXTS:
continue
_p = os.path.join(workdir, _name)
if not os.path.isfile(_p):
continue
try:
_mtime = os.stat(_p).st_mtime_ns
except OSError:
continue
if _name not in _before or _mtime != _before[_name]:
new_images.append(_name)
if new_images:
import json as _json
result += f"\n__IMAGES__:{_json.dumps(sorted(new_images))}"
return result
except Exception as e:
return f"Execution error: {e}"
@ -453,21 +1180,27 @@ def _bash_exec(
if not command or not command.strip():
return "No command provided."
# Block dangerous commands
tokens = set(command.lower().split())
blocked = tokens & _BASH_BLOCKED_WORDS
# Block dangerous commands (shlex + regex based)
blocked = _find_blocked_commands(command)
if blocked:
return f"Blocked command(s) for safety: {', '.join(sorted(blocked))}"
try:
workdir = _get_workdir(session_id)
proc = subprocess.Popen(
["bash", "-c", command],
safe_env = _build_safe_env(workdir)
popen_kwargs = dict(
stdout = subprocess.PIPE,
stderr = subprocess.STDOUT,
text = True,
cwd = workdir,
env = safe_env,
)
if sys.platform != "win32":
popen_kwargs["preexec_fn"] = _sandbox_preexec
else:
popen_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
proc = subprocess.Popen(_get_shell_cmd(command), **popen_kwargs)
if cancel_event is not None:
watcher = threading.Thread(

View file

@ -22,6 +22,7 @@ from loggers import get_logger
import os
import queue as _queue
import sys
import threading
import time
import traceback
from io import BytesIO
@ -29,40 +30,19 @@ from pathlib import Path
from typing import Any
logger = get_logger(__name__)
from utils.hardware import apply_gpu_ids
def _activate_transformers_version(model_name: str) -> None:
"""Activate the correct transformers version BEFORE any ML imports.
If the model needs transformers 5.x, prepend the pre-installed .venv_t5/
directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).
"""
"""Activate the correct transformers version BEFORE any ML imports."""
# Ensure backend is on path for utils imports
backend_path = str(Path(__file__).resolve().parent.parent.parent)
if backend_path not in sys.path:
sys.path.insert(0, backend_path)
from utils.transformers_version import (
needs_transformers_5,
_resolve_base_model,
_ensure_venv_t5_exists,
_VENV_T5_DIR,
)
from utils.transformers_version import activate_transformers_for_subprocess
resolved = _resolve_base_model(model_name)
if needs_transformers_5(resolved):
if not _ensure_venv_t5_exists():
raise RuntimeError(
f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}"
)
if _VENV_T5_DIR not in sys.path:
sys.path.insert(0, _VENV_T5_DIR)
logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR)
# Propagate to child subprocesses (e.g. GGUF converter)
_pp = os.environ.get("PYTHONPATH", "")
os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "")
else:
logger.info("Using default transformers (4.57.x) for %s", model_name)
activate_transformers_for_subprocess(model_name)
def _decode_image(image_base64: str):
@ -113,6 +93,157 @@ def _build_model_config(config: dict):
return mc
def _get_hf_download_state(
model_names: list[str] | None = None,
) -> tuple[int, bool] | None:
"""Return (total_bytes, has_incomplete) for the HF Hub cache, or None on error.
When *model_names* is provided, only those models' ``blobs/``
directories are checked instead of scanning every cached model --
much faster on systems with many models. Accepts multiple names so
that LoRA loads can watch both the adapter repo and the base model
repo simultaneously.
*has_incomplete* is True when any ``*.incomplete`` files exist in the
watched blobs directories, indicating that ``huggingface_hub`` is
actively downloading.
Returns None if the state cannot be determined (import error,
permission error, etc.) so callers can skip stall logic.
"""
try:
from huggingface_hub.constants import HF_HUB_CACHE
cache = Path(HF_HUB_CACHE)
if not cache.exists():
return (0, False)
total = 0
has_incomplete = False
blobs_dirs: list[Path] = []
if model_names:
from utils.paths import resolve_cached_repo_id_case
for name in model_names:
if not name:
continue
# Skip local filesystem paths -- HF model IDs use forward
# slashes (org/model) but never start with / . ~ or contain
# backslashes. This distinguishes them from absolute paths,
# relative paths, and Windows paths.
if name.startswith(("/", ".", "~")) or "\\" in name:
continue
name = resolve_cached_repo_id_case(name)
# HF cache dir format: models--org--name (slashes -> --)
cache_dir_name = "models--" + name.replace("/", "--")
blobs_dir = cache / cache_dir_name / "blobs"
if blobs_dir.exists():
blobs_dirs.append(blobs_dir)
else:
blobs_dirs = list(cache.glob("models--*/blobs"))
for bdir in blobs_dirs:
for f in bdir.iterdir():
try:
if f.is_file():
total += f.stat().st_size
if f.name.endswith(".incomplete"):
has_incomplete = True
except OSError:
pass
return (total, has_incomplete)
except Exception as e:
logger.debug("Failed to determine HF download state: %s", e)
return None
def _start_heartbeat(
resp_queue: Any,
interval: float = 30.0,
stall_timeout: float = 180.0,
xet_disabled: bool = False,
model_names: list[str] | None = None,
) -> threading.Event:
"""Start a daemon thread that sends periodic status heartbeats.
Monitors the HF Hub cache directory for download activity. A stall
is only reported when ``*.incomplete`` files are present (indicating
``huggingface_hub`` is actively downloading) **and** the total cache
size has not changed for *stall_timeout* seconds.
Once the download finishes (no more ``.incomplete`` files), the stall
timer resets, so post-download initialization (quantization, GPU
weight loading) is never misclassified as a stalled download.
Returns a stop event -- set it to terminate the heartbeat thread.
"""
stop = threading.Event()
transport = "https" if xet_disabled else "xet"
def _beat():
state = _get_hf_download_state(model_names)
last_size = state[0] if state is not None else 0
last_change = time.monotonic()
while not stop.wait(interval):
state = _get_hf_download_state(model_names)
now = time.monotonic()
# Skip stall logic if we cannot measure the cache
if state is None:
_send_response(
resp_queue,
{
"type": "status",
"message": f"Loading model ({transport} transport)...",
"ts": time.time(),
},
)
continue
current_size, has_incomplete = state
if current_size != last_size:
last_size = current_size
last_change = now
# Only fire stall when .incomplete files are present,
# confirming a download is actively in progress.
# Once downloads finish (no .incomplete), reset the timer
# so model init time is not counted as a stall.
if not has_incomplete:
last_change = now
elif now - last_change >= stall_timeout:
_send_response(
resp_queue,
{
"type": "stall",
"message": (
f"Download appears stalled ({transport} transport) "
f"-- no progress for {int(now - last_change)}s"
),
"ts": time.time(),
},
)
# Only fire once -- the orchestrator will kill us
return
_send_response(
resp_queue,
{
"type": "status",
"message": f"Loading model ({transport} transport)...",
"ts": time.time(),
},
)
t = threading.Thread(target = _beat, daemon = True)
t.start()
return stop
def _handle_load(backend, config: dict, resp_queue: Any) -> None:
"""Handle a load command: load a model into the backend."""
try:
@ -156,13 +287,52 @@ def _handle_load(backend, config: dict, resp_queue: Any) -> None:
except Exception as e:
logger.warning("Could not read adapter_config.json: %s", e)
success = backend.load_model(
config = mc,
max_seq_length = config.get("max_seq_length", 2048),
load_in_4bit = load_in_4bit,
hf_token = hf_token,
trust_remote_code = config.get("trust_remote_code", False),
# Auto-enable trust_remote_code for NemotronH/Nano models only.
# NemotronH has config parsing bugs requiring trust_remote_code=True.
# Other transformers 5.x models are native and do NOT need it.
# NOTE: Must NOT match Llama-Nemotron (standard Llama architecture).
_NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
trust_remote_code = config.get("trust_remote_code", False)
if not trust_remote_code:
model_name = config["model_name"]
_mn_lower = model_name.lower()
if any(sub in _mn_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and (
_mn_lower.startswith("unsloth/") or _mn_lower.startswith("nvidia/")
):
trust_remote_code = True
logger.info(
"Auto-enabled trust_remote_code for Nemotron model: %s",
model_name,
)
# Send heartbeats every 30s so the orchestrator knows we're still alive
# (download / weight loading can take a long time on slow connections)
xet_disabled = os.environ.get("HF_HUB_DISABLE_XET") == "1"
# Watch both the model repo and base model repo (for LoRA loads
# where the base model download is the actual bottleneck)
watch_repos = [mc.identifier]
base = getattr(mc, "base_model", None)
if base and str(base) != mc.identifier:
watch_repos.append(str(base))
heartbeat_stop = _start_heartbeat(
resp_queue,
interval = 30.0,
xet_disabled = xet_disabled,
model_names = watch_repos,
)
try:
success = backend.load_model(
config = mc,
max_seq_length = config.get("max_seq_length", 2048),
load_in_4bit = load_in_4bit,
hf_token = hf_token,
trust_remote_code = trust_remote_code,
gpu_ids = config.get("resolved_gpu_ids"),
)
finally:
heartbeat_stop.set()
if success:
# Build model_info for the parent to mirror
@ -474,6 +644,10 @@ def run_inference_process(
"ignore" # Suppress warnings at C-level before imports
)
if config.get("disable_xet"):
os.environ["HF_HUB_DISABLE_XET"] = "1"
logger.info("Xet transport disabled (HF_HUB_DISABLE_XET=1)")
import warnings
from loggers.config import LogConfig
@ -485,6 +659,8 @@ def run_inference_process(
env = os.getenv("ENVIRONMENT_TYPE", "production"),
)
apply_gpu_ids(config.get("resolved_gpu_ids"))
model_name = config["model_name"]
# ── 1. Activate correct transformers version BEFORE any ML imports ──

View file

@ -33,7 +33,14 @@ if sys.platform in ("win32", "darwin"):
sys.path.insert(0, _compile_cache)
import torch
from utils.hardware import clear_gpu_cache, safe_num_proc, dataset_map_num_proc
from utils.hardware import (
clear_gpu_cache,
safe_num_proc,
dataset_map_num_proc,
get_device_map,
raise_if_offloaded,
get_visible_gpu_count,
)
torch._dynamo.config.recompile_limit = 64
from unsloth import FastLanguageModel, FastVisionModel, is_bfloat16_supported
@ -81,8 +88,8 @@ class TrainingProgress:
epoch: float = 0
step: int = 0
total_steps: int = 0
loss: float = 0.0
learning_rate: float = 0.0
loss: Optional[float] = None
learning_rate: Optional[float] = None
is_training: bool = False
is_completed: bool = False
error: Optional[str] = None
@ -183,7 +190,11 @@ class UnslothTrainer:
self._cuda_audio_used = False
# --- Detect VLM ---
vision = is_vision_model(model_name) if not self.is_audio else False
vision = (
is_vision_model(model_name, hf_token = hf_token)
if not self.is_audio
else False
)
self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image
logger.info(
@ -244,7 +255,7 @@ class UnslothTrainer:
def on_log(self, args, state, control, logs = None, **kwargs):
if not logs:
return
loss_value = logs.get("loss", logs.get("train_loss", 0.0))
loss_value = logs.get("loss", logs.get("train_loss", None))
current_step = state.global_step
grad_norm = logs.get("grad_norm", None)
@ -268,7 +279,7 @@ class UnslothTrainer:
step = current_step,
epoch = round(state.epoch, 2) if state.epoch else 0,
loss = loss_value,
learning_rate = logs.get("learning_rate", 0.0),
learning_rate = logs.get("learning_rate", None),
elapsed_seconds = elapsed_seconds,
eta_seconds = eta_seconds,
grad_norm = grad_norm,
@ -487,6 +498,7 @@ class UnslothTrainer:
is_dataset_audio: bool = False,
trust_remote_code: bool = False,
full_finetuning: bool = False,
gpu_ids: Optional[list[int]] = None,
) -> bool:
"""Load model for training (supports both text and vision models)"""
self.load_in_4bit = load_in_4bit # Store for training_meta.json
@ -550,7 +562,11 @@ class UnslothTrainer:
self._cuda_audio_used = False
# VLM: vision model with image dataset (mutually exclusive with audio paths)
vision = is_vision_model(model_name) if not self.is_audio else False
vision = (
is_vision_model(model_name, hf_token = hf_token)
if not self.is_audio
else False
)
self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image
self.model_name = model_name
self.max_seq_length = max_seq_length
@ -624,6 +640,11 @@ class UnslothTrainer:
self._update_progress(error = friendly, is_training = False)
return False
device_map = get_device_map(gpu_ids)
logger.info(
f"Using device_map='{device_map}' ({get_visible_gpu_count()} GPU(s) visible)"
)
# Branch based on model type
if self._audio_type == "csm":
# CSM: FastModel + auto_model=CsmForConditionalGeneration + load_in_4bit=False
@ -636,6 +657,7 @@ class UnslothTrainer:
dtype = None,
auto_model = CsmForConditionalGeneration,
load_in_4bit = False,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -651,6 +673,7 @@ class UnslothTrainer:
model_name = model_name,
dtype = None,
load_in_4bit = False,
device_map = device_map,
full_finetuning = full_finetuning,
auto_model = WhisperForConditionalGeneration,
whisper_language = "English",
@ -672,6 +695,7 @@ class UnslothTrainer:
max_seq_length = max_seq_length,
dtype = None,
load_in_4bit = load_in_4bit,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -711,6 +735,7 @@ class UnslothTrainer:
max_seq_length = max_seq_length,
dtype = torch.float32, # Spark-TTS requires float32
load_in_4bit = False,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -725,6 +750,7 @@ class UnslothTrainer:
model_name,
max_seq_length = max_seq_length,
load_in_4bit = False,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -741,6 +767,7 @@ class UnslothTrainer:
max_seq_length = max_seq_length,
dtype = None,
load_in_4bit = load_in_4bit,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -754,6 +781,7 @@ class UnslothTrainer:
max_seq_length = max_seq_length,
dtype = None, # Auto-detect
load_in_4bit = load_in_4bit,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
@ -786,12 +814,15 @@ class UnslothTrainer:
max_seq_length = max_seq_length,
dtype = None, # Auto-detect
load_in_4bit = load_in_4bit,
device_map = device_map,
full_finetuning = full_finetuning,
token = hf_token,
trust_remote_code = trust_remote_code,
)
logger.info("Loaded text model")
raise_if_offloaded(self.model, device_map, "Studio training")
if self.should_stop:
return False
@ -824,6 +855,7 @@ class UnslothTrainer:
is_dataset_audio = is_dataset_audio,
trust_remote_code = trust_remote_code,
full_finetuning = full_finetuning,
gpu_ids = gpu_ids,
)
error_msg = str(e)
error_lower = error_msg.lower()
@ -2634,14 +2666,14 @@ class UnslothTrainer:
eval_steps: float = 0.00,
output_dir: str | None = None,
num_epochs: int = 3,
learning_rate: float = 5e-5,
learning_rate: float = 2e-4,
batch_size: int = 2,
gradient_accumulation_steps: int = 4,
warmup_steps: int = None,
warmup_ratio: float = None,
max_steps: int = 0,
save_steps: int = 0,
weight_decay: float = 0.01,
weight_decay: float = 0.001,
random_seed: int = 3407,
packing: bool = False,
train_on_completions: bool = False,
@ -3010,7 +3042,7 @@ class UnslothTrainer:
"fp16": not is_bfloat16_supported(),
"bf16": is_bfloat16_supported(),
"logging_steps": 1,
"weight_decay": training_args.get("weight_decay", 0.01),
"weight_decay": training_args.get("weight_decay", 0.001),
"seed": training_args.get("random_seed", 3407),
"output_dir": output_dir,
"report_to": _build_report_targets(training_args),

View file

@ -14,18 +14,21 @@ worker's mp.Queue, and exposes the same API surface to routes/training.py.
Pattern follows core/data_recipe/jobs/manager.py.
"""
import json as _json
import math
import multiprocessing as mp
import queue
import threading
import time
import structlog
from datetime import datetime, timezone
from loggers import get_logger
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Any
import matplotlib.pyplot as plt
from utils.hardware import prepare_gpu_selection
logger = get_logger(__name__)
@ -44,8 +47,8 @@ class TrainingProgress:
epoch: float = 0
step: int = 0
total_steps: int = 0
loss: float = 0.0
learning_rate: float = 0.0
loss: Optional[float] = None
learning_rate: Optional[float] = None
is_training: bool = False
is_completed: bool = False
error: Optional[str] = None
@ -63,6 +66,8 @@ class TrainingBackend:
Launches a fresh subprocess per training job, communicates via mp.Queue.
"""
FLUSH_THRESHOLD: int = 10
def __init__(self):
# Subprocess state
self._proc: Optional[mp.Process] = None
@ -91,13 +96,21 @@ class TrainingBackend:
self.current_job_id: Optional[str] = None
self._output_dir: Optional[str] = None
# DB persistence
self._metric_buffer: list[dict] = []
self._run_finalized: bool = False
self._db_run_created: bool = False
self._db_total_steps_set: bool = False
self._db_config: Optional[dict] = None
self._db_started_at: Optional[str] = None
logger.info("TrainingBackend initialized (subprocess mode)")
# ------------------------------------------------------------------
# Public API (called by routes/training.py)
# ------------------------------------------------------------------
def start_training(self, **kwargs) -> bool:
def start_training(self, job_id: str, **kwargs) -> bool:
"""Spawn a subprocess to run the full training pipeline.
All kwargs are serialized into a config dict and sent to the worker.
@ -108,30 +121,16 @@ class TrainingBackend:
logger.warning("Training subprocess already running")
return False
# Join prior pump thread to prevent it from consuming events
# from the new job's queue (it reads self._event_queue dynamically).
# Join prior pump thread — refuse to start if it won't die
if self._pump_thread is not None and self._pump_thread.is_alive():
self._pump_thread.join(timeout = 5.0)
if self._pump_thread.is_alive():
logger.warning("Previous pump thread did not exit within 5s")
logger.warning(
"Previous pump thread did not exit within 5s — refusing to start"
)
return False
self._pump_thread = None
# Reset state
self._should_stop = False
self._cancel_requested = False
self._progress = TrainingProgress(
is_training = True, status_message = "Initializing training..."
)
self.loss_history.clear()
self.lr_history.clear()
self.step_history.clear()
self.grad_norm_history.clear()
self.grad_norm_step_history.clear()
self.eval_loss_history.clear()
self.eval_step_history.clear()
self.eval_enabled = False
self._output_dir = None
# Build config dict for the subprocess
config = {
"model_name": kwargs["model_name"],
@ -161,7 +160,7 @@ class TrainingBackend:
"warmup_ratio": kwargs.get("warmup_ratio"),
"max_steps": kwargs.get("max_steps", 0),
"save_steps": kwargs.get("save_steps", 0),
"weight_decay": kwargs.get("weight_decay", 0.01),
"weight_decay": kwargs.get("weight_decay", 0.001),
"random_seed": kwargs.get("random_seed", 3407),
"packing": kwargs.get("packing", False),
"optim": kwargs.get("optim", "adamw_8bit"),
@ -187,29 +186,85 @@ class TrainingBackend:
"enable_tensorboard": kwargs.get("enable_tensorboard", False),
"tensorboard_dir": kwargs.get("tensorboard_dir", "runs"),
"trust_remote_code": kwargs.get("trust_remote_code", False),
"gpu_ids": kwargs.get("gpu_ids"),
}
# Derive load_in_4bit from training_type
if config["training_type"] != "LoRA/QLoRA":
config["load_in_4bit"] = False
# Spawn subprocess
# Spawn subprocess — use locals so state is untouched on failure
resolved_gpu_ids, gpu_selection = prepare_gpu_selection(
kwargs.get("gpu_ids"),
model_name = config["model_name"],
hf_token = config["hf_token"] or None,
training_type = config["training_type"],
load_in_4bit = config["load_in_4bit"],
batch_size = config.get("batch_size", 4),
max_seq_length = config.get("max_seq_length", 2048),
lora_rank = config.get("lora_r", 16),
target_modules = config.get("target_modules"),
gradient_checkpointing = config.get("gradient_checkpointing", "unsloth"),
optimizer = config.get("optim", "adamw_8bit"),
)
config["resolved_gpu_ids"] = resolved_gpu_ids
config["gpu_selection"] = gpu_selection
from .worker import run_training_process
self._event_queue = _CTX.Queue()
self._stop_queue = _CTX.Queue()
event_queue = _CTX.Queue()
stop_queue = _CTX.Queue()
self._proc = _CTX.Process(
proc = _CTX.Process(
target = run_training_process,
kwargs = {
"event_queue": self._event_queue,
"stop_queue": self._stop_queue,
"event_queue": event_queue,
"stop_queue": stop_queue,
"config": config,
},
daemon = True,
)
self._proc.start()
logger.info("Training subprocess started (pid=%s)", self._proc.pid)
try:
proc.start()
except Exception:
logger.error("Failed to start training subprocess", exc_info = True)
return False
logger.info("Training subprocess started (pid=%s)", proc.pid)
# Reset state — safe because old pump thread is confirmed dead
# and proc.start() succeeded
self.current_job_id = job_id
self._should_stop = False
self._cancel_requested = False
self._progress = TrainingProgress(
is_training = True, status_message = "Initializing training..."
)
self.loss_history.clear()
self.lr_history.clear()
self.step_history.clear()
self.grad_norm_history.clear()
self.grad_norm_step_history.clear()
self.eval_loss_history.clear()
self.eval_step_history.clear()
self.eval_enabled = False
self._output_dir = None
self._metric_buffer.clear()
self._run_finalized = False
self._db_run_created = False
self._db_total_steps_set = False
self._db_config = {
k: v for k, v in config.items() if k not in {"hf_token", "wandb_token"}
}
self._db_started_at = datetime.now(timezone.utc).isoformat()
# Assign subprocess handles after state reset
self._event_queue = event_queue
self._stop_queue = stop_queue
self._proc = proc
# Eagerly create DB run row so the run appears in history during model loading
self._ensure_db_run_created()
# Start event pump thread
self._pump_thread = threading.Thread(target = self._pump_loop, daemon = True)
@ -252,6 +307,11 @@ class TrainingBackend:
proc.kill()
proc.join(timeout = 2.0)
# Wait for pump thread to finish DB finalization before returning
# (8s covers SQLite's default 5s lock timeout plus execution overhead)
if self._pump_thread is not None and self._pump_thread.is_alive():
self._pump_thread.join(timeout = 8.0)
def is_training_active(self) -> bool:
"""Check if training is currently active."""
with self._lock:
@ -389,20 +449,54 @@ class TrainingBackend:
self._progress.error
or "Training process exited unexpectedly"
)
self._ensure_db_run_created()
self._finalize_run_in_db(
status = "stopped" if self._should_stop else "error",
error_message = None
if self._should_stop
else "Training process terminated unexpectedly",
)
return
def _handle_event(self, event: dict) -> None:
"""Apply a subprocess event to local state."""
"""Apply a subprocess event to local state.
State updates happen inside self._lock; DB I/O happens after
releasing it so status-polling API endpoints are never blocked
by slow SQLite writes.
"""
etype = event.get("type")
db_action: Optional[str] = None
db_action_kwargs: dict = {}
with self._lock:
if etype == "progress":
self._progress.step = event.get("step", self._progress.step)
self._progress.epoch = event.get("epoch", self._progress.epoch)
self._progress.loss = event.get("loss", self._progress.loss)
self._progress.learning_rate = event.get(
"learning_rate", self._progress.learning_rate
)
# loss/lr are sanitized below; update progress after coercion
_raw_loss = event.get("loss")
_raw_lr = event.get("learning_rate")
try:
_safe_loss = float(_raw_loss) if _raw_loss is not None else None
except (TypeError, ValueError):
logger.debug("Could not convert loss to float: %s", _raw_loss)
_safe_loss = None
if _safe_loss is not None and not math.isfinite(_safe_loss):
_safe_loss = None
try:
_safe_lr = float(_raw_lr) if _raw_lr is not None else None
except (TypeError, ValueError):
logger.debug(
"Could not convert learning_rate to float: %s", _raw_lr
)
_safe_lr = None
if _safe_lr is not None and not math.isfinite(_safe_lr):
_safe_lr = None
if _safe_loss is not None:
self._progress.loss = _safe_loss
if _safe_lr is not None:
self._progress.learning_rate = _safe_lr
self._progress.total_steps = event.get(
"total_steps", self._progress.total_steps
)
@ -416,30 +510,85 @@ class TrainingBackend:
if status:
self._progress.status_message = status
# Update metric histories
# Update metric histories — reuse sanitized values from above
step = event.get("step", 0)
loss = event.get("loss", 0.0)
lr = event.get("learning_rate", 0.0)
if step >= 0 and loss > 0:
loss = _safe_loss
lr = _safe_lr
if step > 0 and loss is not None:
self.loss_history.append(loss)
self.lr_history.append(lr)
self.lr_history.append(lr if lr is not None else 0.0)
self.step_history.append(step)
grad_norm = event.get("grad_norm")
gn = None
if grad_norm is not None:
try:
gn = float(grad_norm)
except (TypeError, ValueError):
gn = None
if gn is not None and math.isfinite(gn):
if step > 0 and gn is not None and math.isfinite(gn):
self.grad_norm_history.append(gn)
self.grad_norm_step_history.append(step)
else:
gn = None
eval_loss = event.get("eval_loss")
if eval_loss is not None:
self.eval_loss_history.append(eval_loss)
self.eval_step_history.append(step)
self.eval_enabled = True
try:
eval_loss = float(eval_loss)
except (TypeError, ValueError):
logger.debug(
"Could not convert eval_loss to float: %s", eval_loss
)
eval_loss = None
if step > 0 and eval_loss is not None and math.isfinite(eval_loss):
self.eval_loss_history.append(eval_loss)
self.eval_step_history.append(step)
self.eval_enabled = True
else:
eval_loss = None
# Buffer metric for DB flush (loss/lr already sanitized above)
self._metric_buffer.append(
{
"step": step,
"loss": loss,
"learning_rate": lr,
"grad_norm": gn,
"eval_loss": eval_loss,
"epoch": event.get("epoch"),
"num_tokens": event.get("num_tokens"),
"elapsed_seconds": event.get("elapsed_seconds"),
}
)
# Decide which DB action to take after releasing the lock
if not self._db_run_created and self.current_job_id and self._db_config:
db_action = "create_run"
db_action_kwargs = {
"job_id": self.current_job_id,
"model_name": self._db_config["model_name"],
"dataset_name": self._db_config.get("hf_dataset")
or next(
iter(self._db_config.get("local_datasets") or []), "unknown"
),
"config_json": _json.dumps(self._db_config),
"started_at": self._db_started_at
or datetime.now(timezone.utc).isoformat(),
"total_steps": event.get("total_steps"),
}
elif (
event.get("total_steps")
and self._db_run_created
and not self._db_total_steps_set
):
db_action = "update_total_steps"
db_action_kwargs = {
"job_id": self.current_job_id,
"total_steps": event["total_steps"],
}
elif len(self._metric_buffer) >= self.FLUSH_THRESHOLD:
db_action = "flush"
elif etype == "eval_configured":
self.eval_enabled = True
@ -454,6 +603,14 @@ class TrainingBackend:
self._output_dir = event.get("output_dir")
msg = event.get("status_message", "Training completed")
self._progress.status_message = msg
if not self._db_run_created and self.current_job_id and self._db_config:
db_action = "create_and_finalize"
else:
db_action = "finalize"
db_action_kwargs = {
"status": "stopped" if self._should_stop else "completed",
"output_dir": self._output_dir,
}
elif etype == "error":
self._progress.is_training = False
@ -462,6 +619,149 @@ class TrainingBackend:
stack = event.get("stack", "")
if stack:
logger.error("Stack trace:\n%s", stack)
if not self._db_run_created and self.current_job_id and self._db_config:
db_action = "create_and_finalize"
else:
db_action = "finalize"
db_action_kwargs = {
"status": "stopped" if self._should_stop else "error",
"error_message": event.get("error", "Unknown error"),
}
# --- DB I/O outside the lock ---
if db_action == "create_run":
try:
from storage.studio_db import create_run
create_run(
id = db_action_kwargs["job_id"],
model_name = db_action_kwargs["model_name"],
dataset_name = db_action_kwargs["dataset_name"],
config_json = db_action_kwargs["config_json"],
started_at = db_action_kwargs["started_at"],
total_steps = db_action_kwargs["total_steps"],
)
self._db_run_created = True
if db_action_kwargs["total_steps"]:
self._db_total_steps_set = True
except Exception:
logger.warning("Failed to create DB run record", exc_info = True)
elif db_action == "create_and_finalize":
self._ensure_db_run_created()
self._finalize_run_in_db(**db_action_kwargs)
elif db_action == "update_total_steps":
try:
from storage.studio_db import update_run_total_steps
update_run_total_steps(
db_action_kwargs["job_id"], db_action_kwargs["total_steps"]
)
self._db_total_steps_set = True
except Exception:
logger.warning("Failed to update total_steps in DB", exc_info = True)
elif db_action == "flush":
self._flush_metrics_to_db()
elif db_action == "finalize":
self._finalize_run_in_db(**db_action_kwargs)
def _ensure_db_run_created(self) -> None:
"""Create the DB row if it doesn't exist yet. Called outside the lock."""
if self._db_run_created or not self.current_job_id or not self._db_config:
return
try:
from storage.studio_db import create_run
dataset_name = self._db_config.get("hf_dataset") or next(
iter(self._db_config.get("local_datasets") or []), "unknown"
)
create_run(
id = self.current_job_id,
model_name = self._db_config["model_name"],
dataset_name = dataset_name,
config_json = _json.dumps(self._db_config),
started_at = self._db_started_at
or datetime.now(timezone.utc).isoformat(),
total_steps = self._progress.total_steps or None,
)
self._db_run_created = True
except Exception:
logger.warning(
"Failed to create DB run record for early failure", exc_info = True
)
def _finalize_run_in_db(
self,
status: str,
error_message: Optional[str] = None,
output_dir: Optional[str] = None,
) -> None:
"""Flush remaining metrics and mark a run as finished in the DB."""
if not self.current_job_id or not self._db_run_created or self._run_finalized:
return
self._flush_metrics_to_db()
try:
from storage.studio_db import finish_run
from utils.downsample import downsample
sparkline = downsample(self.loss_history, 50)
finish_run(
id = self.current_job_id,
status = status,
ended_at = datetime.now(timezone.utc).isoformat(),
final_step = self._progress.step,
final_loss = self._progress.loss
if (
self._progress.loss is not None
and math.isfinite(self._progress.loss)
)
else None,
duration_seconds = self._progress.elapsed_seconds,
loss_sparkline = _json.dumps(sparkline),
output_dir = output_dir,
error_message = error_message,
)
self._run_finalized = True
except Exception:
logger.warning(
"Failed to finalize run in DB (status=%s)", status, exc_info = True
)
def _flush_metrics_to_db(self) -> None:
"""Flush buffered metrics to the database and update live progress."""
if (
not self._metric_buffer
or not self.current_job_id
or not self._db_run_created
):
return
# Cap buffer to prevent unbounded memory growth
if len(self._metric_buffer) > 500:
logger.warning(
"Metric buffer exceeded 500 entries (%d) — trimming oldest",
len(self._metric_buffer),
)
self._metric_buffer = self._metric_buffer[-500:]
# Snapshot before insert so metrics arriving during the write are preserved
batch = list(self._metric_buffer)
try:
from storage.studio_db import insert_metrics_batch, update_run_progress
insert_metrics_batch(self.current_job_id, batch)
del self._metric_buffer[: len(batch)]
update_run_progress(
id = self.current_job_id,
step = self._progress.step,
loss = self._progress.loss
if (
self._progress.loss is not None
and math.isfinite(self._progress.loss)
)
else None,
duration_seconds = self._progress.elapsed_seconds,
)
except Exception:
# Leave buffer intact for retry on next flush
logger.warning("Failed to flush metrics to DB", exc_info = True)
@staticmethod
def _read_queue(q: Any, timeout_sec: float) -> Optional[dict]:
@ -561,11 +861,13 @@ class TrainingBackend:
if progress.error:
title = f"Error: {progress.error}"
elif progress.is_completed:
title = f"Training completed! Final loss: {progress.loss:.4f}"
loss_str = f"{progress.loss:.4f}" if progress.loss is not None else "--"
title = f"Training completed! Final loss: {loss_str}"
elif progress.status_message:
title = progress.status_message
elif progress.step > 0:
title = f"Epoch: {progress.epoch} | Step: {progress.step}/{progress.total_steps} | Loss: {progress.loss:.4f}"
loss_str = f"{progress.loss:.4f}" if progress.loss is not None else "--"
title = f"Epoch: {progress.epoch} | Step: {progress.step}/{progress.total_steps} | Loss: {loss_str}"
else:
title = "Training Loss"

View file

@ -16,47 +16,315 @@ from __future__ import annotations
import structlog
from loggers import get_logger
import os
import shutil
import sys
import time
import traceback
import subprocess as _sp
from pathlib import Path
from typing import Any
from typing import Any, Callable
logger = get_logger(__name__)
from utils.hardware import apply_gpu_ids
from utils.wheel_utils import (
direct_wheel_url,
flash_attn_wheel_url,
install_wheel,
probe_torch_wheel_env,
url_exists,
)
_CAUSAL_CONV1D_RELEASE_TAG = "v1.6.1.post4"
_CAUSAL_CONV1D_PACKAGE_VERSION = "1.6.1"
_MAMBA_SSM_RELEASE_TAG = "v2.3.1"
_MAMBA_SSM_PACKAGE_VERSION = "2.3.1"
_FLASH_ATTN_RUNTIME_MIN_SEQ_LEN = 32768
_FLASH_ATTN_SKIP_ENV = "UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL"
def _model_wants_causal_conv1d(model_name: str) -> bool:
name = model_name.lower()
return any(
key in name
for key in (
"qwen3.5",
"qwen3_5",
"qwen3-next",
"qwen3_next",
"nemotron_h",
"nemotron-h",
"nemotron-3-nano",
"falcon_h1",
"falcon-h1",
"granite-4.0-h",
"granitemoehybrid",
"lfm2",
)
)
def _install_package_wheel_first(
*,
event_queue: Any,
import_name: str,
display_name: str,
pypi_name: str,
pypi_version: str | None = None,
filename_prefix: str | None = None,
release_tag: str | None = None,
release_base_url: str | None = None,
wheel_url_builder: Callable[[dict[str, str] | None], str | None] | None = None,
pypi_spec: str | None = None,
pypi_status_message: str | None = None,
) -> bool:
try:
__import__(import_name)
logger.info("%s already installed", display_name)
return True
except ImportError:
pass
env = probe_torch_wheel_env(timeout = 30)
if wheel_url_builder is not None:
wheel_url = wheel_url_builder(env)
else:
wheel_url = direct_wheel_url(
filename_prefix = filename_prefix,
package_version = pypi_version,
release_tag = release_tag,
release_base_url = release_base_url,
env = env,
)
if wheel_url is None:
logger.info("No compatible %s wheel candidate", display_name)
elif url_exists(wheel_url):
_send_status(event_queue, f"Installing prebuilt {display_name} wheel...")
for installer, result in install_wheel(
wheel_url,
python_executable = sys.executable,
use_uv = bool(shutil.which("uv")),
run = _sp.run,
):
if result.returncode == 0:
logger.info("Installed prebuilt %s wheel successfully", display_name)
return True
logger.warning(
"%s failed to install %s wheel:\n%s",
installer,
display_name,
result.stdout,
)
else:
logger.info("No published %s wheel found: %s", display_name, wheel_url)
is_hip = env and env.get("hip_version")
if is_hip and not shutil.which("hipcc"):
logger.error(
"%s requires hipcc for source compilation on ROCm. "
"Install the ROCm HIP SDK: https://rocm.docs.amd.com",
display_name,
)
_send_status(
event_queue,
f"{display_name}: hipcc not found (ROCm HIP SDK required)",
)
return False
if pypi_spec is None:
pypi_spec = f"{pypi_name}=={pypi_version}"
if pypi_status_message is None:
if is_hip:
pypi_status_message = (
f"Compiling {display_name} from source for ROCm "
"(this may take several minutes)..."
)
else:
pypi_status_message = f"Installing {display_name} from PyPI..."
_send_status(event_queue, pypi_status_message)
# Prefer uv for faster dependency resolution when available
plain_pypi_install = pypi_version is None
if plain_pypi_install:
if shutil.which("uv"):
pypi_cmd = [
"uv",
"pip",
"install",
"--python",
sys.executable,
pypi_spec,
]
else:
pypi_cmd = [sys.executable, "-m", "pip", "install", pypi_spec]
else:
if shutil.which("uv"):
pypi_cmd = [
"uv",
"pip",
"install",
"--python",
sys.executable,
"--no-build-isolation",
"--no-deps",
]
# Avoid stale cache artifacts from partial HIP source builds
if is_hip:
pypi_cmd.append("--no-cache")
pypi_cmd.append(pypi_spec)
else:
pypi_cmd = [
sys.executable,
"-m",
"pip",
"install",
"--no-build-isolation",
"--no-deps",
"--no-cache-dir",
pypi_spec,
]
# Source compilation on ROCm can take 10-30 minutes; use a generous
# timeout. Non-HIP installs preserve the pre-existing "no timeout"
# behaviour so unrelated slow installs (e.g. causal-conv1d source
# build on Linux aarch64 or unsupported torch/CUDA combinations)
# are not aborted at 5 minutes by this PR.
_run_kwargs: dict[str, Any] = {
"stdout": _sp.PIPE,
"stderr": _sp.STDOUT,
"text": True,
}
if is_hip:
_run_kwargs["timeout"] = 1800
try:
result = _sp.run(pypi_cmd, **_run_kwargs)
except _sp.TimeoutExpired:
logger.error(
"%s installation timed out after %ds",
display_name,
_run_kwargs.get("timeout"),
)
_send_status(
event_queue,
f"{display_name} installation timed out after "
f"{_run_kwargs.get('timeout')}s",
)
return False
if result.returncode != 0:
if is_hip:
# Surface a clear error for ROCm source build failures
error_lines = (result.stdout or "").strip().splitlines()
snippet = "\n".join(error_lines[-5:]) if error_lines else "(no output)"
logger.error(
"Failed to compile %s for ROCm:\n%s",
display_name,
result.stdout,
)
_send_status(
event_queue,
f"Failed to compile {display_name} for ROCm. "
"Check that hipcc and ROCm development headers are installed.\n"
f"{snippet}",
)
else:
logger.error(
"Failed to install %s from PyPI:\n%s",
display_name,
result.stdout,
)
return False
if is_hip:
logger.info("Compiled and installed %s from source for ROCm", display_name)
else:
logger.info("Installed %s from PyPI", display_name)
return True
def _ensure_causal_conv1d_fast_path(event_queue: Any, model_name: str) -> None:
if not _model_wants_causal_conv1d(model_name):
return
_install_package_wheel_first(
event_queue = event_queue,
import_name = "causal_conv1d",
display_name = "causal-conv1d",
pypi_name = "causal-conv1d",
pypi_version = _CAUSAL_CONV1D_PACKAGE_VERSION,
filename_prefix = "causal_conv1d",
release_tag = _CAUSAL_CONV1D_RELEASE_TAG,
release_base_url = "https://github.com/Dao-AILab/causal-conv1d/releases/download",
)
_SSM_MODEL_SUBSTRINGS = (
"nemotron_h",
"nemotron-h",
"nemotron-3-nano",
"falcon_h1",
"falcon-h1",
"granite-4.0-h",
"granitemoehybrid",
)
def _ensure_mamba_ssm(event_queue: Any, model_name: str) -> None:
if not any(sub in model_name.lower() for sub in _SSM_MODEL_SUBSTRINGS):
return
logger.info("SSM model detected; setting up mamba-ssm after causal-conv1d")
_install_package_wheel_first(
event_queue = event_queue,
import_name = "mamba_ssm",
display_name = "mamba-ssm",
pypi_name = "mamba-ssm",
pypi_version = _MAMBA_SSM_PACKAGE_VERSION,
filename_prefix = "mamba_ssm",
release_tag = _MAMBA_SSM_RELEASE_TAG,
release_base_url = "https://github.com/state-spaces/mamba/releases/download",
)
def _should_try_runtime_flash_attn_install(max_seq_length: int) -> bool:
if os.getenv(_FLASH_ATTN_SKIP_ENV) == "1":
return False
if max_seq_length < _FLASH_ATTN_RUNTIME_MIN_SEQ_LEN:
return False
return sys.platform.startswith("linux")
def _ensure_flash_attn_for_long_context(event_queue: Any, max_seq_length: int) -> None:
if not _should_try_runtime_flash_attn_install(max_seq_length):
return
installed = _install_package_wheel_first(
event_queue = event_queue,
import_name = "flash_attn",
display_name = "flash-attn",
pypi_name = "flash-attn",
wheel_url_builder = flash_attn_wheel_url,
pypi_spec = "flash-attn",
pypi_status_message = "Installing flash-attn from PyPI for long-context training...",
)
if not installed:
_send_status(event_queue, "Continuing without flash-attn")
def _activate_transformers_version(model_name: str) -> None:
"""Activate the correct transformers version BEFORE any ML imports.
If the model needs transformers 5.x, prepend the pre-installed .venv_t5/
directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/).
"""
"""Activate the correct transformers version BEFORE any ML imports."""
# Ensure backend is on path for utils imports
backend_path = str(Path(__file__).resolve().parent.parent.parent)
if backend_path not in sys.path:
sys.path.insert(0, backend_path)
from utils.transformers_version import (
needs_transformers_5,
_resolve_base_model,
_ensure_venv_t5_exists,
_VENV_T5_DIR,
)
from utils.transformers_version import activate_transformers_for_subprocess
resolved = _resolve_base_model(model_name)
if needs_transformers_5(resolved):
if not _ensure_venv_t5_exists():
raise RuntimeError(
f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}"
)
if _VENV_T5_DIR not in sys.path:
sys.path.insert(0, _VENV_T5_DIR)
logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR)
# Propagate to child subprocesses (e.g. GGUF converter)
_pp = os.environ.get("PYTHONPATH", "")
os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "")
else:
logger.info("Using default transformers (4.57.x) for %s", model_name)
activate_transformers_for_subprocess(model_name)
def run_training_process(
@ -88,6 +356,8 @@ def run_training_process(
env = os.getenv("ENVIRONMENT_TYPE", "production"),
)
apply_gpu_ids(config.get("resolved_gpu_ids"))
model_name = config["model_name"]
# ── 1. Activate correct transformers version BEFORE any ML imports ──
@ -104,62 +374,47 @@ def run_training_process(
)
return
# ── 1a. Auto-enable trust_remote_code for unsloth/* transformers 5.x models ──
# Some newer architectures (e.g. NemotronH) have config parsing bugs in
# transformers that require trust_remote_code=True as a workaround.
# Only auto-enable for unsloth/* prefixed models (trusted source).
from utils.transformers_version import needs_transformers_5
# ── 1a. Auto-enable trust_remote_code for NemotronH/Nano models ──
# NemotronH has config parsing bugs in transformers that require
# trust_remote_code=True as a workaround. Other transformers 5.x models
# (Qwen3.5, Gemma 4, etc.) are native and do NOT need it — enabling it
# bypasses the compiler (disabling fused CE).
# NOTE: Must NOT match Llama-Nemotron (standard Llama architecture).
_NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
_lowered = model_name.lower()
if (
needs_transformers_5(model_name)
and model_name.lower().startswith("unsloth/")
any(sub in _lowered for sub in _NEMOTRON_TRUST_SUBSTRINGS)
and (_lowered.startswith("unsloth/") or _lowered.startswith("nvidia/"))
and not config.get("trust_remote_code", False)
):
config["trust_remote_code"] = True
logger.info(
"Auto-enabled trust_remote_code for unsloth/* transformers 5.x model: %s",
"Auto-enabled trust_remote_code for Nemotron model: %s",
model_name,
)
# ── 1b. Auto-install mamba-ssm for SSM/hybrid models (NemotronH, Falcon-H1) ──
_SSM_MODEL_SUBSTRINGS = ("nemotron_h", "nemotron-3-nano", "falcon_h1", "falcon-h1")
if any(sub in model_name.lower() for sub in _SSM_MODEL_SUBSTRINGS):
try:
import mamba_ssm # noqa: F401
logger.info("mamba-ssm already installed")
except ImportError:
logger.info(
"SSM model detected — installing mamba-ssm and causal-conv1d (this may take several minutes)..."
)
_send_status(
event_queue, "Installing mamba-ssm (first time only, ~7 min)..."
)
import subprocess as _sp
# --no-build-isolation: compile against current torch (no version conflicts)
# --no-deps: don't pull in torch/transformers/triton (already installed)
for _pkg in ["causal_conv1d", "mamba_ssm"]:
_r = _sp.run(
[
sys.executable,
"-m",
"pip",
"install",
"--no-build-isolation",
"--no-deps",
"--no-cache-dir",
_pkg,
],
stdout = _sp.PIPE,
stderr = _sp.STDOUT,
text = True,
)
if _r.returncode != 0:
logger.error("Failed to install %s:\n%s", _pkg, _r.stdout)
else:
logger.info("Installed %s successfully", _pkg)
logger.info("mamba-ssm installation complete")
# ── 1b. Set up causal-conv1d first, then install mamba-ssm if needed ──
try:
_ensure_causal_conv1d_fast_path(event_queue, model_name)
_ensure_mamba_ssm(event_queue, model_name)
_ensure_flash_attn_for_long_context(
event_queue,
int(config.get("max_seq_length", 2048)),
)
except Exception as exc:
event_queue.put(
{
"type": "error",
"error": (
f"Please choose another model to train, since "
f"causal-conv1d / mamba-ssm failed to install "
f"with error: {exc}"
),
"stack": traceback.format_exc(limit = 20),
"ts": time.time(),
}
)
return
# ── 1c. Set fork start method so dataset.map() can multiprocess ──
# The parent launched us via spawn (clean process), but the compiled
@ -242,7 +497,7 @@ def run_training_process(
# Wire up progress callback → event_queue
def _on_progress(progress: TrainingProgress):
has_train_loss = progress.step >= 0 and progress.loss > 0
has_train_loss = progress.step > 0 and progress.loss is not None
has_eval_loss = progress.eval_loss is not None
if has_train_loss or has_eval_loss:
event_queue.put(
@ -424,6 +679,7 @@ def run_training_process(
is_dataset_image = config.get("is_dataset_image", False),
is_dataset_audio = config.get("is_dataset_audio", False),
trust_remote_code = config.get("trust_remote_code", False),
gpu_ids = config.get("resolved_gpu_ids"),
)
if not success or trainer.should_stop:
if trainer.should_stop:
@ -533,7 +789,7 @@ def run_training_process(
warmup_ratio = config.get("warmup_ratio"),
max_steps = max_steps if max_steps and max_steps > 0 else 0,
save_steps = save_steps if save_steps and save_steps > 0 else 0,
weight_decay = config.get("weight_decay", 0.01),
weight_decay = config.get("weight_decay", 0.001),
random_seed = config.get("random_seed", 3407),
packing = config.get("packing", False),
train_on_completions = config.get("train_on_completions", False),
@ -879,7 +1135,7 @@ def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) ->
"lr_scheduler_type": config.get("lr_scheduler_type", "linear"),
"batch_sampler": BatchSamplers.NO_DUPLICATES,
"optim": config.get("optim", "adamw_8bit"),
"weight_decay": config.get("weight_decay", 0.01),
"weight_decay": config.get("weight_decay", 0.001),
"seed": config.get("random_seed", 3407),
}
@ -918,7 +1174,7 @@ def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) ->
def on_log(self, args, state, control, logs = None, **kwargs):
if not logs:
return
loss_value = logs.get("loss", logs.get("train_loss", 0.0))
loss_value = logs.get("loss", logs.get("train_loss", None))
current_step = state.global_step
elapsed = time.time() - training_start_time
@ -934,7 +1190,7 @@ def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) ->
"step": current_step,
"epoch": round(state.epoch, 2) if state.epoch else 0,
"loss": loss_value,
"learning_rate": logs.get("learning_rate", 0.0),
"learning_rate": logs.get("learning_rate", None),
"total_steps": total_steps,
"elapsed_seconds": elapsed,
"eta_seconds": eta,

View file

@ -23,9 +23,23 @@ if _backend_dir not in sys.path:
# See: https://github.com/python/cpython/issues/102396
import _platform_compat # noqa: F401
import mimetypes
import shutil
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import PackageNotFoundError, version as package_version
# Fix broken Windows registry MIME types. Some Windows installs map .js to
# "text/plain" in the registry (HKCR\.js\Content Type). Python's mimetypes
# module reads from the registry, and FastAPI/Starlette's StaticFiles uses
# mimetypes.guess_type() to set Content-Type headers. Browsers enforce strict
# MIME checking for ES module scripts (<script type="module">) and will refuse
# to execute .js files served as text/plain — resulting in a blank page.
# Calling add_type() *before* StaticFiles is instantiated ensures the correct
# types are used regardless of the OS registry.
if sys.platform == "win32":
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
# Suppress annoying dependency warnings in production
if os.getenv("ENVIRONMENT_TYPE", "production") == "production":
@ -34,7 +48,7 @@ if os.getenv("ENVIRONMENT_TYPE", "production") == "production":
# warnings.filterwarnings("ignore", category=DeprecationWarning)
# warnings.filterwarnings("ignore", module="triton.*")
from fastapi import FastAPI
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, HTMLResponse, Response
@ -49,15 +63,43 @@ from routes import (
export_router,
inference_router,
models_router,
training_history_router,
training_router,
)
from auth import storage
from utils.hardware import detect_hardware, get_device, DeviceType
from auth.authentication import get_current_subject
from utils.hardware import (
detect_hardware,
get_device,
DeviceType,
get_backend_visible_gpu_info,
)
import utils.hardware.hardware as _hw_module
from utils.cache_cleanup import clear_unsloth_compiled_cache
def get_unsloth_version() -> str:
try:
return package_version("unsloth")
except PackageNotFoundError:
pass
version_file = (
_Path(__file__).resolve().parents[2] / "unsloth" / "models" / "_utils.py"
)
try:
for line in version_file.read_text(encoding = "utf-8").splitlines():
if line.startswith("__version__ = "):
return line.split("=", 1)[1].strip().strip('"').strip("'")
except OSError:
pass
return "dev"
UNSLOTH_VERSION = get_unsloth_version()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Startup: detect hardware, seed default admin if needed. Shutdown: clean up compiled cache."""
@ -73,6 +115,17 @@ async def lifespan(app: FastAPI):
# Detect hardware first — sets DEVICE global used everywhere
detect_hardware()
from storage.studio_db import cleanup_orphaned_runs
try:
cleanup_orphaned_runs()
except Exception as exc:
import structlog
structlog.get_logger(__name__).warning(
"cleanup_orphaned_runs failed at startup: %s", exc
)
# Pre-cache the helper GGUF model for LLM-assisted dataset detection.
# Runs in a background thread so it doesn't block server startup.
import threading
@ -90,13 +143,13 @@ async def lifespan(app: FastAPI):
if storage.ensure_default_admin():
bootstrap_pw = storage.get_bootstrap_password()
app.state.bootstrap_password = bootstrap_pw
bootstrap_path = storage.DB_PATH.parent / ".bootstrap_password"
print("\n" + "=" * 60)
print("DEFAULT ADMIN ACCOUNT CREATED")
print(
"Sign in with the seeded credentials and change the password immediately:\n"
)
print(f" username: {storage.DEFAULT_ADMIN_USERNAME}")
print(f" password: {bootstrap_pw}\n")
print(f" password saved to: {bootstrap_path}")
print(" Open the Studio UI to sign in and change it.")
print("=" * 60 + "\n")
else:
app.state.bootstrap_password = storage.get_bootstrap_password()
@ -109,7 +162,7 @@ async def lifespan(app: FastAPI):
# Create FastAPI app
app = FastAPI(
title = "Unsloth UI Backend",
version = "1.0.0",
version = UNSLOTH_VERSION,
description = "Backend API for Unsloth UI - Training and Model Management",
lifespan = lifespan,
)
@ -149,6 +202,9 @@ app.include_router(inference_router, prefix = "/v1", tags = ["openai-compat"])
app.include_router(datasets_router, prefix = "/api/datasets", tags = ["datasets"])
app.include_router(data_recipe_router, prefix = "/api/data-recipe", tags = ["data-recipe"])
app.include_router(export_router, prefix = "/api/export", tags = ["export"])
app.include_router(
training_history_router, prefix = "/api/train", tags = ["training-history"]
)
# ============ Health and System Endpoints ============
@ -164,78 +220,53 @@ async def health_check():
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"service": "Unsloth UI Backend",
"version": UNSLOTH_VERSION,
"device_type": device_type,
"chat_only": _hw_module.CHAT_ONLY,
}
@app.post("/api/shutdown")
async def shutdown_server(
request: Request,
current_subject: str = Depends(get_current_subject),
):
"""Gracefully shut down the Unsloth Studio server.
Called by the frontend quit dialog so users can stop the server from the UI
without needing to use the CLI or kill the process manually.
"""
import asyncio
async def _delayed_shutdown():
await asyncio.sleep(0.2) # Let the HTTP response return first
trigger = getattr(request.app.state, "trigger_shutdown", None)
if trigger is not None:
trigger()
else:
# Fallback when not launched via run_server() (e.g. direct uvicorn)
import signal
import os
os.kill(os.getpid(), signal.SIGTERM)
request.app.state._shutdown_task = asyncio.create_task(_delayed_shutdown())
return {"status": "shutting_down"}
@app.get("/api/system")
async def get_system_info():
"""Get system information"""
import platform
import subprocess
import psutil
from utils.hardware import get_device, get_gpu_memory_info, DeviceType
from utils.hardware import get_device
from utils.hardware.hardware import _backend_label
# GPU Info — query nvidia-smi for physical GPUs, filtered by
# CUDA_VISIBLE_DEVICES when set (the frontend uses this for GGUF
# fit estimation and llama-server respects CVD too).
import os
gpu_info: dict = {"available": False, "devices": []}
device = get_device()
if device == DeviceType.CUDA:
# Parse CUDA_VISIBLE_DEVICES allowlist
allowed_indices = None
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is not None and cvd.strip():
try:
allowed_indices = set(int(x.strip()) for x in cvd.split(","))
except ValueError:
pass # Non-numeric (e.g. GPU-uuid), show all
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=index,name,memory.total",
"--format=csv,noheader,nounits",
],
capture_output = True,
text = True,
timeout = 10,
)
if result.returncode == 0:
for line in result.stdout.strip().splitlines():
parts = [p.strip() for p in line.split(",")]
if len(parts) == 3:
idx = int(parts[0])
if allowed_indices is not None and idx not in allowed_indices:
continue
gpu_info["devices"].append(
{
"index": idx,
"name": parts[1],
"memory_total_gb": round(int(parts[2]) / 1024, 2),
}
)
gpu_info["available"] = len(gpu_info["devices"]) > 0
except Exception:
pass
# Fallback to torch-based single-GPU detection
if not gpu_info["available"]:
mem_info = get_gpu_memory_info()
if mem_info.get("available"):
gpu_info["available"] = True
gpu_info["devices"].append(
{
"index": mem_info.get("device", 0),
"name": mem_info.get("device_name", "Unknown"),
"memory_total_gb": round(mem_info.get("total_gb", 0), 2),
}
)
visibility_info = get_backend_visible_gpu_info()
gpu_info = {
"available": visibility_info["available"],
"devices": visibility_info["devices"],
}
# CPU & Memory
memory = psutil.virtual_memory()
@ -243,7 +274,10 @@ async def get_system_info():
return {
"platform": platform.platform(),
"python_version": platform.python_version(),
"device_backend": get_device().value,
# Use the centralized _backend_label helper so the /api/system
# endpoint reports "rocm" on AMD hosts instead of "cuda", matching
# the /api/hardware and /api/gpu-visibility endpoints.
"device_backend": _backend_label(get_device()),
"cpu_count": psutil.cpu_count(),
"memory": {
"total_gb": round(memory.total / 1e9, 2),
@ -254,6 +288,13 @@ async def get_system_info():
}
@app.get("/api/system/gpu-visibility")
async def get_gpu_visibility(
current_subject: str = Depends(get_current_subject),
):
return get_backend_visible_gpu_info()
@app.get("/api/system/hardware")
async def get_hardware_info():
"""Return GPU name, total VRAM, and key ML package versions."""
@ -335,7 +376,7 @@ def setup_frontend(app: FastAPI, build_path: Path):
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
if full_path.startswith("api"):
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
return {"error": "API endpoint not found"}
file_path = (build_path / full_path).resolve()

View file

@ -10,6 +10,11 @@ from .training import (
TrainingJobResponse,
TrainingStatus,
TrainingProgress,
TrainingRunSummary,
TrainingRunListResponse,
TrainingRunMetrics,
TrainingRunDetailResponse,
TrainingRunDeleteResponse,
)
from .models import (
CheckpointInfo,
@ -71,6 +76,11 @@ __all__ = [
"TrainingJobResponse",
"TrainingStatus",
"TrainingProgress",
"TrainingRunSummary",
"TrainingRunListResponse",
"TrainingRunMetrics",
"TrainingRunDetailResponse",
"TrainingRunDeleteResponse",
# Model management schemas
"ModelDetails",
"LocalModelInfo",

View file

@ -5,6 +5,8 @@
Pydantic schemas for Authentication API
"""
from typing import Optional
from pydantic import BaseModel, Field
@ -45,3 +47,44 @@ class ChangePasswordRequest(BaseModel):
new_password: str = Field(
..., min_length = 8, description = "Replacement password (minimum 8 characters)"
)
# ---------------------------------------------------------------------------
# API key schemas
# ---------------------------------------------------------------------------
class CreateApiKeyRequest(BaseModel):
"""Request body to create a new API key."""
name: str = Field(..., description = "Human-readable label for this key")
expires_in_days: Optional[int] = Field(
None, description = "Number of days until the key expires (None = never)"
)
class ApiKeyResponse(BaseModel):
"""Public representation of an API key (never contains the raw key)."""
id: int
name: str
key_prefix: str = Field(
..., description = "First 8 characters after sk-unsloth- for display"
)
created_at: str
last_used_at: Optional[str] = None
expires_at: Optional[str] = None
is_active: bool
class CreateApiKeyResponse(BaseModel):
"""Returned once when a key is created -- ``key`` is never shown again."""
key: str = Field(..., description = "Full API key (shown once)")
api_key: ApiKeyResponse
class ApiKeyListResponse(BaseModel):
"""List of API keys for the authenticated user."""
api_keys: list[ApiKeyResponse]

View file

@ -11,7 +11,7 @@ import time
import uuid
from typing import Annotated, Any, Dict, Literal, Optional, List, Union
from pydantic import BaseModel, Discriminator, Field, Tag
from pydantic import BaseModel, Discriminator, Field, Tag, model_validator
class LoadRequest(BaseModel):
@ -22,7 +22,10 @@ class LoadRequest(BaseModel):
None, description = "HuggingFace token for gated models"
)
max_seq_length: int = Field(
4096, ge = 128, le = 32768, description = "Maximum sequence length"
0,
ge = 0,
le = 1048576,
description = "Maximum sequence length (0 = model default for GGUF)",
)
load_in_4bit: bool = Field(True, description = "Load model in 4-bit quantization")
is_lora: bool = Field(False, description = "Whether this is a LoRA adapter")
@ -41,6 +44,14 @@ class LoadRequest(BaseModel):
None,
description = "KV cache data type for both K and V (e.g. 'f16', 'bf16', 'q8_0', 'q4_1', 'q5_1')",
)
gpu_ids: Optional[List[int]] = Field(
None,
description = "Physical GPU indices to use, for example [0, 1]. Omit or pass [] to use automatic selection. Explicit gpu_ids are unsupported when the parent CUDA_VISIBLE_DEVICES uses UUID/MIG entries. Not supported for GGUF models.",
)
speculative_type: Optional[str] = Field(
None,
description = "Speculative decoding mode for GGUF models (e.g. 'ngram-simple', 'ngram-mod'). Ignored for non-GGUF and vision models.",
)
class UnloadRequest(BaseModel):
@ -83,6 +94,10 @@ class ValidateModelResponse(BaseModel):
is_gguf: bool = Field(False, description = "Whether this is a GGUF model (llama.cpp)")
is_lora: bool = Field(False, description = "Whether this is a LoRA adapter")
is_vision: bool = Field(False, description = "Whether this is a vision-capable model")
requires_trust_remote_code: bool = Field(
False,
description = "Whether the model defaults require trust_remote_code to be enabled for loading.",
)
class GenerateRequest(BaseModel):
@ -126,13 +141,28 @@ class LoadResponse(BaseModel):
inference: dict = Field(
..., description = "Inference parameters (temperature, top_p, top_k, min_p)"
)
requires_trust_remote_code: bool = Field(
False,
description = "Whether the model defaults require trust_remote_code to be enabled for loading.",
)
context_length: Optional[int] = Field(
None, description = "Model's native context length (from GGUF metadata)"
)
max_context_length: Optional[int] = Field(
None, description = "Maximum context length currently available on this hardware"
)
native_context_length: Optional[int] = Field(
None,
description = "Model's native context length from GGUF metadata (not capped by VRAM)",
)
supports_reasoning: bool = Field(
False,
description = "Whether model supports thinking/reasoning mode (enable_thinking)",
)
reasoning_always_on: bool = Field(
False,
description = "Whether reasoning is always on (hardcoded <think> tags, not toggleable)",
)
supports_tools: bool = Field(
False,
description = "Whether model supports tool calling (web search, etc.)",
@ -145,6 +175,10 @@ class LoadResponse(BaseModel):
None,
description = "Jinja2 chat template string (from GGUF metadata or tokenizer)",
)
speculative_type: Optional[str] = Field(
None,
description = "Active speculative decoding mode (e.g. 'ngram-simple', 'ngram-mod'), or None if disabled",
)
class UnloadResponse(BaseModel):
@ -154,6 +188,39 @@ class UnloadResponse(BaseModel):
model: str = Field(..., description = "Model identifier that was unloaded")
class LoadProgressResponse(BaseModel):
"""Progress of the active GGUF load, sampled on demand.
Used by the UI to show a real progress bar during the
post-download warmup window (mmap + CUDA upload), rather than a
generic "Starting model..." spinner that freezes for minutes on
large MoE models.
"""
phase: Optional[str] = Field(
None,
description = (
"Load phase: 'mmap' (weights paging into RAM via mmap), "
"'ready' (llama-server reported healthy), or null when no "
"load is in flight."
),
)
bytes_loaded: int = Field(
0,
description = (
"Bytes of the model already resident in the llama-server "
"process (VmRSS on Linux)."
),
)
bytes_total: int = Field(
0,
description = "Total bytes across all GGUF shards for the active model.",
)
fraction: float = Field(
0.0, description = "bytes_loaded / bytes_total, clamped to 0..1."
)
class InferenceStatusResponse(BaseModel):
"""Current inference backend status"""
@ -187,15 +254,34 @@ class InferenceStatusResponse(BaseModel):
inference: Optional[Dict[str, Any]] = Field(
None, description = "Recommended inference parameters for the active model"
)
requires_trust_remote_code: bool = Field(
False,
description = "Whether the active model requires trust_remote_code to be enabled for loading.",
)
supports_reasoning: bool = Field(
False, description = "Whether the active model supports reasoning/thinking mode"
)
reasoning_always_on: bool = Field(
False, description = "Whether reasoning is always on (not toggleable)"
)
supports_tools: bool = Field(
False, description = "Whether the active model supports tool calling"
)
context_length: Optional[int] = Field(
None, description = "Context length of the active model"
)
max_context_length: Optional[int] = Field(
None,
description = "Maximum context length currently available for the active model",
)
native_context_length: Optional[int] = Field(
None,
description = "Model's native context length from GGUF metadata (not capped by VRAM)",
)
speculative_type: Optional[str] = Field(
None,
description = "Active speculative decoding mode (e.g. 'ngram-simple', 'ngram-mod'), or None if disabled",
)
# =====================================================================
@ -252,14 +338,68 @@ class ChatMessage(BaseModel):
``content`` may be a plain string (text-only) or a list of
content parts for multimodal messages (OpenAI vision format).
Assistant messages that only contain tool calls may set ``content``
to ``None`` with ``tool_calls`` populated. ``role="tool"`` messages
carry the result of a client-executed tool call and require
``tool_call_id`` per the OpenAI spec.
"""
role: Literal["system", "user", "assistant"] = Field(
role: Literal["system", "user", "assistant", "tool"] = Field(
..., description = "Message role"
)
content: Union[str, list[ContentPart]] = Field(
..., description = "Message content (string or multimodal parts)"
content: Optional[Union[str, list[ContentPart]]] = Field(
None, description = "Message content (string or multimodal parts)"
)
tool_call_id: Optional[str] = Field(
None,
description = "OpenAI tool-result messages: id of the tool call this result belongs to.",
)
tool_calls: Optional[list[dict]] = Field(
None,
description = "OpenAI assistant messages: structured tool calls the model decided to make.",
)
name: Optional[str] = Field(
None,
description = "OpenAI tool-result messages: name of the tool whose result this is.",
)
@model_validator(mode = "after")
def _validate_role_shape(self) -> "ChatMessage":
# Enforce the per-role OpenAI spec shape at the request boundary.
# Without this, malformed messages (e.g. user entries with no
# content, tool_calls on a user/system role, role="tool" without
# tool_call_id) would be silently forwarded to llama-server via
# the passthrough path, surfacing as opaque upstream errors or
# broken tool-call reconciliation downstream.
# Tool-call metadata must appear only on the appropriate role.
if self.tool_calls is not None and self.role != "assistant":
raise ValueError('"tool_calls" is only valid on role="assistant" messages.')
if self.tool_call_id is not None and self.role != "tool":
raise ValueError('"tool_call_id" is only valid on role="tool" messages.')
if self.name is not None and self.role != "tool":
raise ValueError('"name" is only valid on role="tool" messages.')
# Per-role content requirements.
if self.role == "tool":
if not self.tool_call_id:
raise ValueError(
'role="tool" messages require "tool_call_id" per the OpenAI spec.'
)
if not self.content:
raise ValueError('role="tool" messages require non-empty "content".')
elif self.role == "assistant":
# Assistant messages may omit content when tool_calls is set.
if not self.content and not self.tool_calls:
raise ValueError(
'role="assistant" messages require either "content" or "tool_calls".'
)
else: # "user" | "system"
if not self.content:
raise ValueError(
f'role="{self.role}" messages require non-empty "content".'
)
return self
class ChatCompletionRequest(BaseModel):
@ -269,18 +409,49 @@ class ChatCompletionRequest(BaseModel):
Extensions (non-OpenAI fields) are marked with 'x-unsloth'.
"""
# Accept unknown fields defensively so future OpenAI fields (seed,
# response_format, logprobs, frequency_penalty, etc.) don't get
# silently dropped by Pydantic before route code runs. Mirrors
# AnthropicMessagesRequest and ResponsesRequest.
model_config = {"extra": "allow"}
model: str = Field(
"default",
description = "Model identifier (informational; the active model is used)",
)
messages: list[ChatMessage] = Field(..., description = "Conversation messages")
stream: bool = Field(True, description = "Whether to stream the response via SSE")
stream: bool = Field(
False,
description = (
"Whether to stream the response via SSE. Default matches OpenAI's "
"spec (`false`); opt into streaming by sending `stream: true`."
),
)
temperature: float = Field(0.6, ge = 0.0, le = 2.0)
top_p: float = Field(0.95, ge = 0.0, le = 1.0)
max_tokens: Optional[int] = Field(
None, ge = 1, description = "Maximum tokens to generate (None = until EOS)"
)
presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty")
stop: Optional[Union[str, list[str]]] = Field(
None,
description = "OpenAI stop sequences: a single string or list of strings at which generation halts.",
)
tools: Optional[list[dict]] = Field(
None,
description = (
"OpenAI function-tool definitions. When provided without `enable_tools=true`, "
"Studio forwards the tools to the backend so the model returns structured "
"tool_calls for the client to execute (standard OpenAI function calling)."
),
)
tool_choice: Optional[Union[str, dict]] = Field(
None,
description = (
"OpenAI tool choice: 'auto' | 'required' | 'none' | "
"{'type': 'function', 'function': {'name': ...}}"
),
)
# ── Unsloth extensions (ignored by standard OpenAI clients) ──
top_k: int = Field(20, ge = -1, le = 100, description = "[x-unsloth] Top-k sampling")
@ -288,7 +459,7 @@ class ChatCompletionRequest(BaseModel):
0.01, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold"
)
repetition_penalty: float = Field(
1.1, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty"
1.0, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty"
)
image_base64: Optional[str] = Field(
None, description = "[x-unsloth] Base64-encoded image for vision models"
@ -323,7 +494,7 @@ class ChatCompletionRequest(BaseModel):
description = "[x-unsloth] Auto-detect and fix malformed tool calls from model output.",
)
max_tool_calls_per_message: Optional[int] = Field(
10,
25,
ge = 0,
description = "[x-unsloth] Maximum number of tool call iterations per message (0 = disabled, 9999 = unlimited).",
)
@ -403,3 +574,434 @@ class ChatCompletion(BaseModel):
model: str = "default"
choices: list[CompletionChoice]
usage: CompletionUsage = Field(default_factory = CompletionUsage)
# =====================================================================
# OpenAI Responses API Models (/v1/responses)
# =====================================================================
# ── Request models ──────────────────────────────────────────────
class ResponsesInputTextPart(BaseModel):
"""Text content part in a Responses API message (type=input_text)."""
type: Literal["input_text"]
text: str
class ResponsesInputImagePart(BaseModel):
"""Image content part in a Responses API message (type=input_image)."""
type: Literal["input_image"]
image_url: str = Field(..., description = "data:image/png;base64,... or https://...")
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ResponsesOutputTextPart(BaseModel):
"""Assistant ``output_text`` content part replayed on subsequent turns.
When a client (OpenAI Codex CLI, OpenAI Python SDK agents) loops on a
stateless Responses endpoint, prior assistant messages are round-tripped
as ``{"role":"assistant","content":[{"type":"output_text","text":...,
"annotations":[],"logprobs":[]}]}``. We preserve the text and ignore
the annotations/logprobs metadata when flattening into Chat Completions.
"""
type: Literal["output_text"]
text: str
annotations: Optional[list] = None
logprobs: Optional[list] = None
model_config = {"extra": "allow"}
class ResponsesUnknownContentPart(BaseModel):
"""Catch-all for content-part types we don't model explicitly.
Keeps validation green when a client sends newer part types (e.g.
``input_audio``, ``input_file``) we haven't mapped; these are silently
skipped during normalisation rather than rejected with a 422.
"""
type: str
model_config = {"extra": "allow"}
ResponsesContentPart = Union[
ResponsesInputTextPart,
ResponsesInputImagePart,
ResponsesOutputTextPart,
ResponsesUnknownContentPart,
]
class ResponsesInputMessage(BaseModel):
"""A single message in the Responses API input array."""
type: Optional[Literal["message"]] = None
role: Literal["system", "user", "assistant", "developer"]
content: Union[str, list[ResponsesContentPart]]
# Codex (gpt-5.3-codex+) attaches a `phase` field ("commentary" |
# "final_answer") to assistant messages and requires clients to preserve
# it on subsequent turns. We accept and round-trip it; llama-server does
# not care about it.
model_config = {"extra": "allow"}
class ResponsesFunctionCallInputItem(BaseModel):
"""A prior assistant function_call being replayed in a multi-turn Responses input.
The Responses API represents tool calls as top-level input items (not
nested inside assistant messages), correlated across turns by ``call_id``.
"""
type: Literal["function_call"]
id: Optional[str] = Field(
None, description = "Item id assigned by the server (e.g. fc_...)"
)
call_id: str = Field(
...,
description = "Correlation id matching a function_call_output on the next turn.",
)
name: str
arguments: str = Field(
..., description = "JSON string of the arguments the model produced."
)
status: Optional[Literal["in_progress", "completed", "incomplete"]] = None
class ResponsesFunctionCallOutputInputItem(BaseModel):
"""A tool result supplied by the client for a prior function_call.
Replaces Chat Completions' ``role="tool"`` message. Correlated to the
originating call by ``call_id``.
"""
type: Literal["function_call_output"]
id: Optional[str] = None
call_id: str
output: Union[str, list] = Field(
..., description = "String or content-array result of the tool call."
)
status: Optional[Literal["in_progress", "completed", "incomplete"]] = None
class ResponsesUnknownInputItem(BaseModel):
"""Catch-all for Responses input item types we don't model explicitly.
Covers ``reasoning`` items (replayed from prior o-series / gpt-5 turns)
and any future item types the client may send. These items are dropped
during normalisation llama-server-backed GGUFs cannot consume them
but keeping them in the request-model union stops unrelated turns from
failing validation with a 422.
"""
type: str
model_config = {"extra": "allow"}
def _responses_input_item_discriminator(v: Any) -> str:
"""Route a Responses input item to the correct tagged variant.
Pydantic's default smart-union matching fails when one variant in the
union is tagged with a strict ``Literal`` (``function_call`` /
``function_call_output``) and the incoming dict uses a different
``type`` the other variants' validation errors are hidden and the
outer ``Union[str, list[...]]`` reports a misleading "Input should be a
valid string" error. An explicit discriminator makes the routing
deterministic and lets us fall through to the catch-all.
"""
if isinstance(v, dict):
t = v.get("type")
r = v.get("role")
else:
t = getattr(v, "type", None)
r = getattr(v, "role", None)
if t == "function_call":
return "function_call"
if t == "function_call_output":
return "function_call_output"
if r is not None or t == "message":
return "message"
return "unknown"
ResponsesInputItem = Annotated[
Union[
Annotated[ResponsesInputMessage, Tag("message")],
Annotated[ResponsesFunctionCallInputItem, Tag("function_call")],
Annotated[ResponsesFunctionCallOutputInputItem, Tag("function_call_output")],
Annotated[ResponsesUnknownInputItem, Tag("unknown")],
],
Discriminator(_responses_input_item_discriminator),
]
class ResponsesFunctionTool(BaseModel):
"""Flat function-tool definition used by the Responses API request.
Unlike Chat Completions (which nests ``{"name": ..., "parameters": ...}``
inside a ``"function"`` key), the Responses API uses a flat shape with
``type``, ``name``, ``description``, ``parameters``, and ``strict`` at the
top level of each tool entry.
"""
type: Literal["function"]
name: str
description: Optional[str] = None
parameters: Optional[dict] = None
strict: Optional[bool] = None
class ResponsesRequest(BaseModel):
"""OpenAI Responses API request."""
model: str = Field("default", description = "Model identifier")
input: Union[str, list[ResponsesInputItem]] = Field(
default = [],
description = "Input text or list of messages / function_call / function_call_output items",
)
instructions: Optional[str] = Field(
None, description = "System / developer instructions"
)
temperature: Optional[float] = Field(None, ge = 0.0, le = 2.0)
top_p: Optional[float] = Field(None, ge = 0.0, le = 1.0)
max_output_tokens: Optional[int] = Field(None, ge = 1)
stream: bool = Field(False, description = "Whether to stream the response via SSE")
# OpenAI function-calling fields — forwarded to llama-server via the
# Chat Completions pass-through (see routes/inference.py). Typed as a
# plain list so built-in tool shapes (``web_search``, ``file_search``,
# ``mcp``, ...) round-trip without validation errors — the translator
# picks out only ``type=="function"`` entries for forwarding.
tools: Optional[list[dict]] = Field(
None,
description = (
"Responses-shape function tool definitions. Entries with "
'`type="function"` are translated to the Chat Completions nested '
"shape before being forwarded to llama-server; other tool types "
"(built-in web_search, file_search, mcp, ...) are accepted for SDK "
"compatibility but ignored on the llama-server passthrough."
),
)
tool_choice: Optional[Any] = Field(
None,
description = (
"'auto' | 'required' | 'none' | {'type': 'function', 'name': ...} — "
"the Responses-shape forcing object is translated to the Chat "
"Completions nested shape internally."
),
)
parallel_tool_calls: Optional[bool] = None
previous_response_id: Optional[str] = None
store: Optional[bool] = None
metadata: Optional[dict] = None
truncation: Optional[Any] = None
user: Optional[str] = None
text: Optional[Any] = None
reasoning: Optional[Any] = None
model_config = {"extra": "allow"}
# ── Response models ─────────────────────────────────────────────
class ResponsesOutputTextContent(BaseModel):
"""A text content block inside an output message."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list = Field(default_factory = list)
class ResponsesOutputMessage(BaseModel):
"""An output message in the Responses API response."""
type: Literal["message"] = "message"
id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:12]}")
status: Literal["completed", "in_progress"] = "completed"
role: Literal["assistant"] = "assistant"
content: list[ResponsesOutputTextContent] = Field(default_factory = list)
class ResponsesOutputFunctionCall(BaseModel):
"""A function-call output item in the Responses API response.
Unlike Chat Completions (which nests tool calls inside the assistant
message), the Responses API emits each tool call as its own top-level
``output`` item so clients can correlate results via ``call_id`` on the
next turn.
"""
type: Literal["function_call"] = "function_call"
id: str = Field(default_factory = lambda: f"fc_{uuid.uuid4().hex[:12]}")
call_id: str
name: str
arguments: str = Field(
..., description = "JSON string of the arguments the model produced."
)
status: Literal["completed", "in_progress", "incomplete"] = "completed"
ResponsesOutputItem = Union[ResponsesOutputMessage, ResponsesOutputFunctionCall]
class ResponsesUsage(BaseModel):
"""Token usage for a Responses API response (input_tokens, not prompt_tokens)."""
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
class ResponsesResponse(BaseModel):
"""Top-level Responses API response object."""
id: str = Field(default_factory = lambda: f"resp_{uuid.uuid4().hex[:12]}")
object: Literal["response"] = "response"
created_at: int = Field(default_factory = lambda: int(time.time()))
status: Literal["completed", "in_progress", "failed"] = "completed"
model: str = "default"
output: list[ResponsesOutputItem] = Field(default_factory = list)
usage: ResponsesUsage = Field(default_factory = ResponsesUsage)
error: Optional[Any] = None
incomplete_details: Optional[Any] = None
instructions: Optional[str] = None
metadata: dict = Field(default_factory = dict)
temperature: Optional[float] = None
top_p: Optional[float] = None
max_output_tokens: Optional[int] = None
previous_response_id: Optional[str] = None
text: Optional[Any] = None
tool_choice: Optional[Any] = None
tools: list = Field(default_factory = list)
truncation: Optional[Any] = None
# =====================================================================
# Anthropic Messages API Models (/v1/messages)
# =====================================================================
# ── Request models ─────────────────────────────────────────────
class AnthropicTextBlock(BaseModel):
type: Literal["text"]
text: str
class AnthropicImageSource(BaseModel):
type: Literal["base64", "url"]
media_type: Optional[str] = None
data: Optional[str] = None
url: Optional[str] = None
class AnthropicImageBlock(BaseModel):
type: Literal["image"]
source: AnthropicImageSource
class AnthropicToolUseBlock(BaseModel):
type: Literal["tool_use"]
id: str
name: str
input: dict
class AnthropicToolResultBlock(BaseModel):
type: Literal["tool_result"]
tool_use_id: str
content: Union[str, list] = ""
AnthropicContentBlock = Union[
AnthropicTextBlock,
AnthropicImageBlock,
AnthropicToolUseBlock,
AnthropicToolResultBlock,
]
class AnthropicMessage(BaseModel):
role: Literal["user", "assistant"]
content: Union[str, list[AnthropicContentBlock]]
class AnthropicTool(BaseModel):
name: str
description: Optional[str] = None
input_schema: dict
class AnthropicMessagesRequest(BaseModel):
model: str = "default"
max_tokens: Optional[int] = None
messages: list[AnthropicMessage]
system: Optional[Union[str, list]] = None
tools: Optional[list[AnthropicTool]] = None
tool_choice: Optional[Any] = None
stream: bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[list[str]] = None
metadata: Optional[dict] = None
# [x-unsloth] extensions — mirror the OpenAI endpoint convenience fields
min_p: Optional[float] = Field(
None, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold"
)
repetition_penalty: Optional[float] = Field(
None, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty"
)
presence_penalty: Optional[float] = Field(
None, ge = 0.0, le = 2.0, description = "[x-unsloth] Presence penalty"
)
enable_tools: Optional[bool] = None
enabled_tools: Optional[list[str]] = None
session_id: Optional[str] = None
model_config = {"extra": "allow"}
# ── Response models ────────────────────────────────────────────
class AnthropicUsage(BaseModel):
input_tokens: int = 0
output_tokens: int = 0
class AnthropicResponseTextBlock(BaseModel):
type: Literal["text"] = "text"
text: str
class AnthropicResponseToolUseBlock(BaseModel):
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: dict
AnthropicResponseBlock = Union[
AnthropicResponseTextBlock, AnthropicResponseToolUseBlock
]
class AnthropicMessagesResponse(BaseModel):
id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:24]}")
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[AnthropicResponseBlock] = Field(default_factory = list)
model: str = "default"
stop_reason: Optional[str] = None
stop_sequence: Optional[str] = None
usage: AnthropicUsage = Field(default_factory = AnthropicUsage)

View file

@ -165,7 +165,7 @@ class LocalModelInfo(BaseModel):
id: str = Field(..., description = "Identifier to use for loading/training")
display_name: str = Field(..., description = "Display label")
path: str = Field(..., description = "Local path where model data was discovered")
source: Literal["models_dir", "hf_cache"] = Field(
source: Literal["models_dir", "hf_cache", "lmstudio", "custom"] = Field(
...,
description = "Discovery source",
)
@ -189,7 +189,92 @@ class LocalModelListResponse(BaseModel):
None,
description = "HF cache root that was scanned",
)
lmstudio_dirs: List[str] = Field(
default_factory = list,
description = "LM Studio model directories that were scanned",
)
models: List[LocalModelInfo] = Field(
default_factory = list,
description = "Discovered local/cached models",
)
class AddScanFolderRequest(BaseModel):
"""Request body for adding a custom scan folder."""
path: str = Field(
..., description = "Absolute or relative directory path to scan for models"
)
class ScanFolderInfo(BaseModel):
"""A registered custom model scan folder."""
id: int = Field(..., description = "Database row ID")
path: str = Field(..., description = "Normalized absolute path")
created_at: str = Field(..., description = "ISO 8601 creation timestamp")
class BrowseEntry(BaseModel):
"""A directory entry surfaced by the folder browser."""
name: str = Field(..., description = "Entry name (basename, not full path)")
has_models: bool = Field(
False,
description = (
"Hint that the directory likely contains models "
"(*.gguf, *.safetensors, config.json, or HF-style "
"`models--*` subfolders). Used by the UI to highlight "
"promising candidates; the scanner itself is authoritative."
),
)
hidden: bool = Field(
False,
description = "Name starts with a dot (e.g. `.cache`)",
)
class BrowseFoldersResponse(BaseModel):
"""Response schema for the folder browser endpoint."""
current: str = Field(..., description = "Absolute path of the directory just listed")
parent: Optional[str] = Field(
None,
description = (
"Parent directory of `current`, or null if `current` is the "
"filesystem root. The frontend uses this to render an `Up` row."
),
)
entries: List[BrowseEntry] = Field(
default_factory = list,
description = (
"Subdirectories of `current`. Sorted with model-bearing "
"directories first, then alphabetically case-insensitive; "
"hidden entries come last within each group."
),
)
suggestions: List[str] = Field(
default_factory = list,
description = (
"Handy starting points (home, HF cache, already-registered "
"scan folders). Rendered as quick-pick chips above the list."
),
)
truncated: bool = Field(
False,
description = (
"True when the listing was capped because the directory had "
"more subfolders than the server is willing to enumerate in "
"one request. The UI should show a hint telling the user to "
"narrow their path."
),
)
model_files_here: int = Field(
0,
description = (
"Count of GGUF/safetensors files immediately inside "
"``current``. Used by the UI to surface a hint on leaf "
"model directories (which otherwise look `empty` because "
"they contain only files, no subdirectories)."
),
)

View file

@ -81,7 +81,7 @@ class TrainingStartRequest(BaseModel):
warmup_ratio: Optional[float] = Field(None, description = "Warmup ratio")
max_steps: Optional[int] = Field(None, description = "Maximum training steps")
save_steps: int = Field(100, description = "Steps between checkpoints")
weight_decay: float = Field(0.01, description = "Weight decay")
weight_decay: float = Field(0.001, description = "Weight decay")
random_seed: int = Field(42, description = "Random seed")
packing: bool = Field(False, description = "Enable sequence packing")
optim: str = Field("adamw_8bit", description = "Optimizer")
@ -128,6 +128,12 @@ class TrainingStartRequest(BaseModel):
enable_tensorboard: bool = Field(False, description = "Enable TensorBoard logging")
tensorboard_dir: Optional[str] = Field(None, description = "TensorBoard directory")
# GPU selection
gpu_ids: Optional[List[int]] = Field(
None,
description = "Physical GPU indices to use, for example [0, 1]. Omit or pass [] to use automatic selection. Explicit gpu_ids are unsupported when the parent CUDA_VISIBLE_DEVICES uses UUID/MIG entries.",
)
class TrainingJobResponse(BaseModel):
"""Immediate response when training is initiated"""
@ -177,8 +183,8 @@ class TrainingProgress(BaseModel):
job_id: str = Field(..., description = "Training job identifier")
step: int = Field(..., description = "Current training step")
total_steps: int = Field(..., description = "Total training steps")
loss: float = Field(..., description = "Current loss value")
learning_rate: float = Field(..., description = "Current learning rate")
loss: Optional[float] = Field(None, description = "Current loss value")
learning_rate: Optional[float] = Field(None, description = "Current learning rate")
progress_percent: float = Field(
..., description = "Progress percentage (0.0 to 100.0)"
)
@ -196,3 +202,59 @@ class TrainingProgress(BaseModel):
eval_loss: Optional[float] = Field(
None, description = "Eval loss from the most recent evaluation step"
)
class TrainingRunSummary(BaseModel):
"""Summary of a training run for list views."""
id: str
status: Literal["running", "completed", "stopped", "error"]
model_name: str
dataset_name: str
started_at: str
ended_at: Optional[str] = None
total_steps: Optional[int] = None
final_step: Optional[int] = None
final_loss: Optional[float] = None
output_dir: Optional[str] = None
duration_seconds: Optional[float] = None
error_message: Optional[str] = None
loss_sparkline: Optional[List[float]] = None
class TrainingRunListResponse(BaseModel):
"""Response for listing training runs."""
runs: List[TrainingRunSummary]
total: int
class TrainingRunMetrics(BaseModel):
"""Metrics arrays for a training run, using paired step arrays per metric."""
step_history: List[int] = Field(default_factory = list)
loss_history: List[float] = Field(default_factory = list)
loss_step_history: List[int] = Field(default_factory = list)
lr_history: List[float] = Field(default_factory = list)
lr_step_history: List[int] = Field(default_factory = list)
grad_norm_history: List[float] = Field(default_factory = list)
grad_norm_step_history: List[int] = Field(default_factory = list)
eval_loss_history: List[float] = Field(default_factory = list)
eval_step_history: List[int] = Field(default_factory = list)
final_epoch: Optional[float] = None
final_num_tokens: Optional[int] = None
class TrainingRunDetailResponse(BaseModel):
"""Response for a single training run with config and metrics."""
run: TrainingRunSummary
config: dict
metrics: TrainingRunMetrics
class TrainingRunDeleteResponse(BaseModel):
"""Response for deleting a training run."""
status: str
message: str

View file

@ -11,7 +11,7 @@ version = "0.1.0"
description = "Local Data Designer unstructured seed reader plugin"
requires-python = ">=3.11"
dependencies = [
"data-designer-engine>=0.5.1,<0.6",
"data-designer-engine>=0.5.4,<0.6",
"pandas>=2,<3",
"pymupdf>=1.24.0",
"pymupdf4llm>=0.0.17",

View file

@ -2,9 +2,13 @@
descript-audio-codec
descript-audiotools
julius
torchcodec
torchcodec==0.10.0
snac
# peft 0.19.0 causes export subprocess shutdown issues in Studio;
# installing with --no-deps to avoid pulling in torch>=0.11.0
peft==0.18.1
# TRL and related packages
trl==0.23.1
git+https://github.com/meta-pytorch/OpenEnv.git
@ -12,3 +16,5 @@ git+https://github.com/meta-pytorch/OpenEnv.git
torch-c-dlpack-ext
sentence_transformers==5.2.0
transformers==4.57.6
pytorch_tokenizers
kernels==0.12.1

View file

@ -0,0 +1,50 @@
# Runtime dependencies for no-torch (GGUF-only) mode.
# Installed with --no-deps to prevent transitive torch resolution
# from packages like accelerate, peft, trl, sentence-transformers.
#
# Includes unsloth's own direct deps (typer, pydantic, pyyaml,
# nest-asyncio) since unsloth is also installed with --no-deps
# (current PyPI metadata still declares torch as a hard dep).
# unsloth direct deps (from pyproject.toml [project].dependencies)
typer
pydantic
pyyaml
nest-asyncio
# HF ecosystem (from [huggingfacenotorch] extras in pyproject.toml)
wheel>=0.42.0
packaging
numpy
tqdm
psutil
tyro
protobuf
sentencepiece>=0.2.0
safetensors>=0.4.3
datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0
accelerate>=0.34.1
peft>=0.18.0,!=0.11.0
huggingface_hub>=0.34.0
hf_transfer
diffusers
# Transitive deps required because this file is installed with --no-deps.
# Without these, `from transformers import AutoConfig` fails at import time.
regex
typing_extensions
filelock
httpx
httpcore
certifi
idna
anyio
sniffio
h11
tokenizers
transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0
trl>=0.18.2,!=0.19.0,<=0.24.0
sentence-transformers
cut_cross_entropy
pillow

View file

@ -1,6 +1,2 @@
# Torch AO overrides (installed with --force-reinstall --no-cache-dir)
torchao==0.14.0
pytorch_tokenizers
# Kernel packages
kernels

View file

@ -1,22 +1,25 @@
# Data Designer runtime deps installed explicitly (single-env mode).
# DuckDB 1.5 removed Relation.record_batch(); keep <1.5 until upstream ships the fix.
# Synced with data-designer-engine==0.5.4 requirements.
anyascii<1,>=0.3.3
duckdb<1.5,>=1.1.3
chardet<6,>=3.0.2
duckdb<2,>=1.5.0
faker<21,>=20.1.0
fsspec<2026,>=2025.3.0
httpx<1,>=0.27.2
httpx-retries<1,>=0.4.2
json-repair<1,>=0.48.0
jsonpath-rust-bindings<2,>=1.0
jsonschema<5,>=4.0.0
litellm<1.80.12,>=1.73.6
lxml<7,>=6.0.2
marko<3,>=2.1.2
mcp<2,>=1.26.0
networkx<4,>=3.0
python-json-logger<4,>=3
python-json-logger>=3,<4
ruff<1,>=0.14.10
scipy<2,>=1.11.0
sqlfluff<4,>=3.2.0
tiktoken<1,>=0.8.0
# Unstructured-seed plugin deps (plugin installed with --no-deps)
pymupdf>=1.24.0
pymupdf4llm>=0.0.17
mammoth>=1.8.0

View file

@ -1,5 +1,5 @@
# Install Data Designer in same env as Unsloth.
data-designer==0.5.2
data-designer-config==0.5.2
data-designer-engine==0.5.2
data-designer==0.5.4
data-designer-config==0.5.4
data-designer-engine==0.5.4
prompt-toolkit>=3,<4

View file

@ -12,6 +12,7 @@ from routes.datasets import router as datasets_router
from routes.auth import router as auth_router
from routes.data_recipe import router as data_recipe_router
from routes.export import router as export_router
from routes.training_history import router as training_history_router
__all__ = [
"training_router",
@ -21,4 +22,5 @@ __all__ = [
"auth_router",
"data_recipe_router",
"export_router",
"training_history_router",
]

View file

@ -7,11 +7,17 @@ Authentication API routes
from fastapi import APIRouter, Depends, HTTPException, status
from datetime import datetime, timedelta, timezone
from models.auth import (
ApiKeyListResponse,
ApiKeyResponse,
AuthLoginRequest,
RefreshTokenRequest,
AuthStatusResponse,
ChangePasswordRequest,
CreateApiKeyRequest,
CreateApiKeyResponse,
RefreshTokenRequest,
)
from models.users import Token
from auth import storage, hashing
@ -131,3 +137,68 @@ async def change_password(
token_type = "bearer",
must_change_password = False,
)
# ---------------------------------------------------------------------------
# API key management
# ---------------------------------------------------------------------------
def _row_to_api_key_response(row: dict) -> ApiKeyResponse:
return ApiKeyResponse(
id = row["id"],
name = row["name"],
key_prefix = row["key_prefix"],
created_at = row["created_at"],
last_used_at = row.get("last_used_at"),
expires_at = row.get("expires_at"),
is_active = bool(row["is_active"]),
)
@router.post("/api-keys", response_model = CreateApiKeyResponse)
async def create_api_key(
payload: CreateApiKeyRequest,
current_subject: str = Depends(get_current_subject),
) -> CreateApiKeyResponse:
"""Create a new API key. The raw key is returned once and cannot be retrieved later."""
expires_at = None
if payload.expires_in_days is not None:
expires_at = (
datetime.now(timezone.utc) + timedelta(days = payload.expires_in_days)
).isoformat()
raw_key, row = storage.create_api_key(
username = current_subject,
name = payload.name,
expires_at = expires_at,
)
return CreateApiKeyResponse(
key = raw_key,
api_key = _row_to_api_key_response(row),
)
@router.get("/api-keys", response_model = ApiKeyListResponse)
async def list_api_keys(
current_subject: str = Depends(get_current_subject),
) -> ApiKeyListResponse:
"""List all API keys for the authenticated user (raw keys are never exposed)."""
rows = storage.list_api_keys(current_subject)
return ApiKeyListResponse(
api_keys = [_row_to_api_key_response(r) for r in rows],
)
@router.delete("/api-keys/{key_id}")
async def revoke_api_key(
key_id: int,
current_subject: str = Depends(get_current_subject),
) -> dict:
"""Revoke (soft-delete) an API key."""
if not storage.revoke_api_key(current_subject, key_id):
raise HTTPException(
status_code = status.HTTP_404_NOT_FOUND,
detail = "API key not found",
)
return {"detail": "API key revoked"}

View file

@ -5,7 +5,9 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any
from urllib.parse import urlparse
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import JSONResponse, StreamingResponse
@ -26,6 +28,161 @@ from models.data_recipe import (
router = APIRouter()
def _resolve_local_v1_endpoint(request: Request) -> str:
"""Return the loopback /v1 URL for the actual backend listen port.
Resolution order:
1. ``app.state.server_port`` - explicitly published by run.py after
the uvicorn server has bound. This is the most reliable source
because it survives reverse proxies, TLS terminators and tunnels.
2. ``request.scope["server"]`` - the real (host, port) tuple uvicorn
sets when the request is dispatched. Used when Studio is started
outside ``run_server`` (e.g. ``uvicorn studio.backend.main:app``).
3. ``request.base_url`` parsed - last resort for test fixtures that
do not route through a live uvicorn server.
"""
port: Any = getattr(request.app.state, "server_port", None)
if not isinstance(port, int) or port <= 0:
server = request.scope.get("server")
if (
isinstance(server, tuple)
and len(server) >= 2
and isinstance(server[1], int)
and server[1] > 0
):
port = server[1]
else:
parsed = urlparse(str(request.base_url))
port = parsed.port if parsed.port is not None else 8888
return f"http://127.0.0.1:{int(port)}/v1"
def _used_llm_model_aliases(recipe: dict[str, Any]) -> set[str]:
"""Return the set of model_aliases that are actually referenced by an
LLM column. Used to narrow the "Chat model loaded" gate so that orphan
model_config nodes on the canvas do not block unrelated recipe runs.
The ``llm-`` prefix matches the existing convention in
``core/data_recipe/service.py::_recipe_has_llm_columns`` and covers all
LLM column types emitted by the frontend (llm-text, llm-code,
llm-structured, llm-judge).
"""
aliases: set[str] = set()
for column in recipe.get("columns", []):
if not isinstance(column, dict):
continue
column_type = column.get("column_type")
if not isinstance(column_type, str) or not column_type.startswith("llm-"):
continue
alias = column.get("model_alias")
if isinstance(alias, str) and alias:
aliases.add(alias)
return aliases
def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
"""
Mutate recipe dict in-place: for any provider with is_local=True,
generate a JWT and fill in the endpoint pointing at this server.
"""
providers = recipe.get("model_providers")
if not providers:
return
# Collect local providers and pop is_local from ALL dicts unconditionally.
# Strict `is True` guard so malformed payloads (is_local: 1,
# is_local: "true") do not accidentally trigger the loopback rewrite.
local_indices: list[int] = []
for i, provider in enumerate(providers):
if not isinstance(provider, dict):
continue
is_local = provider.pop("is_local", None)
if is_local is True:
local_indices.append(i)
if not local_indices:
return
endpoint = _resolve_local_v1_endpoint(request)
# Only gate on model-loaded if a local provider is actually reachable
# from an LLM column through a model_config. Orphan model_config nodes
# that reference a local provider but that no LLM column uses should
# not block runs; the recipe would never call /v1 for them.
local_names = {
providers[i].get("name") for i in local_indices if providers[i].get("name")
}
used_aliases = _used_llm_model_aliases(recipe)
referenced_providers = {
mc.get("provider")
for mc in recipe.get("model_configs", [])
if (
isinstance(mc, dict)
and mc.get("provider")
and mc.get("alias") in used_aliases
)
}
token = ""
if local_names & referenced_providers:
# Verify a model is loaded.
# NOTE: This is a point-in-time check (TOCTOU). The model could be unloaded
# or swapped after this check but before the recipe subprocess calls /v1.
# The inference endpoint returns a clear 400 in that case.
#
# Imports are deferred to avoid circular dependencies with inference modules.
from routes.inference import get_llama_cpp_backend
from core.inference import get_inference_backend
llama = get_llama_cpp_backend()
model_loaded = llama.is_loaded
if not model_loaded:
backend = get_inference_backend()
model_loaded = bool(backend.active_model_name)
if not model_loaded:
raise ValueError(
"No model loaded in Chat. Load a model first, then run the recipe."
)
from auth.authentication import (
create_access_token,
) # deferred: avoids circular import
# Uses the "unsloth" admin subject. If the user changes their password,
# the JWT secret rotates and this token becomes invalid mid-run.
# Acceptable for v1 - recipes typically finish well within one session.
token = create_access_token(
subject = "unsloth",
expires_delta = timedelta(hours = 24),
)
# Defensively strip any stale "external"-only fields the frontend may
# have left on the dict (extra_headers/extra_body/api_key_env). The UI
# hides these inputs in local mode but the payload builder still serializes
# them, so a previously external provider that flipped to local can carry
# invalid JSON or rogue auth headers into the local /v1 call.
for i in local_indices:
providers[i]["endpoint"] = endpoint
providers[i]["api_key"] = token
providers[i]["provider_type"] = "openai"
providers[i].pop("api_key_env", None)
providers[i].pop("extra_headers", None)
providers[i].pop("extra_body", None)
# Force skip_health_check on any model_config that references a local
# provider. The local /v1/models endpoint only lists the real loaded
# model (e.g. "unsloth/llama-3.2-1b") and not the placeholder "local"
# that the recipe sends as the model id, so data_designer's pre-flight
# health check would otherwise fail before the first completion call.
# The backend route ignores the model id field in chat completions, so
# skipping the check is safe.
for mc in recipe.get("model_configs", []):
if not isinstance(mc, dict):
continue
if mc.get("provider") in local_names:
mc["skip_health_check"] = True
def _normalize_run_name(value: Any) -> str | None:
if value is None:
return None
@ -40,7 +197,7 @@ def _normalize_run_name(value: Any) -> str | None:
@router.post("/jobs", response_class = JSONResponse, response_model = JobCreateResponse)
def create_job(payload: RecipePayload):
def create_job(payload: RecipePayload, request: Request):
recipe = payload.recipe
if not recipe.get("columns"):
raise HTTPException(status_code = 400, detail = "Recipe must include columns.")
@ -67,6 +224,11 @@ def create_job(payload: RecipePayload):
status_code = 400, detail = f"invalid run_config: {exc}"
) from exc
try:
_inject_local_providers(recipe, request)
except ValueError as exc:
raise HTTPException(status_code = 400, detail = str(exc)) from exc
mgr = get_job_manager()
try:
job_id = mgr.start(recipe = recipe, run = run)

View file

@ -388,7 +388,7 @@ def _extract_text_from_file(file_path: Path, ext: str) -> str:
import pymupdf4llm
raw = pymupdf4llm.to_markdown(
str(file_path), write_images = False, show_progress = False
str(file_path), write_images = False, show_progress = False, use_ocr = False
)
elif ext == ".docx":
import mammoth

View file

@ -68,6 +68,20 @@ def _collect_validation_errors(recipe: dict[str, Any]) -> list[ValidateError]:
return errors
def _patch_local_providers(recipe: dict[str, Any]) -> None:
"""Strip is_local and fill a dummy endpoint so validation doesn't choke.
Uses a strict `is True` check to match _inject_local_providers in
jobs.py - malformed payloads with truthy but non-boolean is_local
values should not be treated as local.
"""
for provider in recipe.get("model_providers", []):
if not isinstance(provider, dict):
continue
if provider.pop("is_local", None) is True:
provider["endpoint"] = "http://127.0.0.1"
@router.post("/validate", response_model = ValidateResponse)
def validate(payload: RecipePayload) -> ValidateResponse:
recipe = payload.recipe
@ -77,6 +91,8 @@ def validate(payload: RecipePayload) -> ValidateResponse:
errors = [ValidateError(message = "Recipe must include columns.")],
)
_patch_local_providers(recipe)
try:
validate_recipe(recipe)
except RuntimeError as exc:

View file

@ -11,10 +11,55 @@ import json
import sys
from pathlib import Path
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile
import re as _re
import structlog
from loggers import get_logger
_VALID_REPO_ID = _re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
def _is_valid_repo_id(repo_id: str) -> bool:
return bool(_VALID_REPO_ID.fullmatch(repo_id))
_dataset_size_cache: dict[str, int] = {}
def _get_dataset_size_cached(repo_id: str) -> int:
if repo_id in _dataset_size_cache:
return _dataset_size_cache[repo_id]
try:
from huggingface_hub import dataset_info as hf_dataset_info
info = hf_dataset_info(repo_id, token = None, files_metadata = True)
total = sum(s.size for s in info.siblings if getattr(s, "size", None))
_dataset_size_cache[repo_id] = total
return total
except Exception:
return 0
def _resolve_hf_cache_realpath(repo_dir: Path) -> Optional[str]:
"""Pick the most useful on-disk path for a HF cache repo dir.
Mirrors the helper in routes/models.py: prefer the most-recent
snapshot dir, fall back to the cache repo root, return resolved
realpath. Duplicated here to keep routes/datasets.py self-contained.
"""
try:
snapshots_dir = repo_dir / "snapshots"
if snapshots_dir.is_dir():
snaps = [s for s in snapshots_dir.iterdir() if s.is_dir()]
if snaps:
latest = max(snaps, key = lambda s: s.stat().st_mtime)
return str(latest.resolve())
return str(repo_dir.resolve())
except Exception:
return None
# Add backend directory to path
backend_path = Path(__file__).parent.parent.parent
if str(backend_path) not in sys.path:
@ -308,6 +353,89 @@ def list_local_datasets(
return LocalDatasetsResponse(datasets = _build_local_dataset_items())
@router.get("/download-progress")
async def get_dataset_download_progress(
repo_id: str = Query(
..., description = "HuggingFace dataset repo ID, e.g. 'unsloth/LaTeX_OCR'"
),
current_subject: str = Depends(get_current_subject),
):
"""Return download progress for a HuggingFace dataset repo.
Mirrors ``GET /api/models/download-progress`` but scans the
``datasets--owner--name`` cache directory under HF_HUB_CACHE.
Modern ``datasets``/``huggingface_hub`` caches both raw model and
raw dataset blobs in HF_HUB_CACHE; the ``datasets`` library writes
its processed Arrow shards elsewhere, but the in-progress *download*
bytes are observable here. Returns ``cache_path`` so the UI can
show users where the dataset blobs landed on disk.
"""
_empty = {
"downloaded_bytes": 0,
"expected_bytes": 0,
"progress": 0,
"cache_path": None,
}
try:
if not _is_valid_repo_id(repo_id):
return _empty
from huggingface_hub import constants as hf_constants
cache_dir = Path(hf_constants.HF_HUB_CACHE)
target = f"datasets--{repo_id.replace('/', '--')}".lower()
completed_bytes = 0
in_progress_bytes = 0
cache_path: Optional[str] = None
if cache_dir.is_dir():
for entry in cache_dir.iterdir():
if entry.name.lower() != target:
continue
cache_path = _resolve_hf_cache_realpath(entry)
blobs_dir = entry / "blobs"
if not blobs_dir.is_dir():
break
for f in blobs_dir.iterdir():
if not f.is_file():
continue
if f.name.endswith(".incomplete"):
in_progress_bytes += f.stat().st_size
else:
completed_bytes += f.stat().st_size
break
downloaded_bytes = completed_bytes + in_progress_bytes
if downloaded_bytes == 0:
return {**_empty, "cache_path": cache_path}
expected_bytes = _get_dataset_size_cached(repo_id)
if expected_bytes <= 0:
return {
"downloaded_bytes": downloaded_bytes,
"expected_bytes": 0,
"progress": 0,
"cache_path": cache_path,
}
# Same 95% completion threshold as the model endpoint -- HF blob
# dedup makes completed_bytes drift slightly under expected_bytes,
# and inter-file gaps would otherwise look like "done".
if completed_bytes >= expected_bytes * 0.95:
progress = 1.0
else:
progress = min(downloaded_bytes / expected_bytes, 0.99)
return {
"downloaded_bytes": downloaded_bytes,
"expected_bytes": expected_bytes,
"progress": round(progress, 3),
"cache_path": cache_path,
}
except Exception as e:
logger.warning(f"Error checking dataset download progress for {repo_id}: {e}")
return _empty
@router.post("/check-format", response_model = CheckFormatResponse)
def check_format(
request: CheckFormatRequest,

View file

@ -5,9 +5,15 @@
Export API routes: checkpoint discovery and model export operations.
"""
import asyncio
import json
import sys
import time
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
import structlog
from loggers import get_logger
@ -97,7 +103,11 @@ async def load_checkpoint(
logger.warning("Could not stop training: %s", e)
backend = get_export_backend()
success, message = backend.load_checkpoint(
# load_checkpoint spawns and waits on a subprocess and can take
# minutes. Run it in a worker thread so the event loop stays
# free to serve the live log SSE stream concurrently.
success, message = await asyncio.to_thread(
backend.load_checkpoint,
checkpoint_path = request.checkpoint_path,
max_seq_length = request.max_seq_length,
load_in_4bit = request.load_in_4bit,
@ -129,7 +139,7 @@ async def cleanup_export_memory(
"""
try:
backend = get_export_backend()
success = backend.cleanup_memory()
success = await asyncio.to_thread(backend.cleanup_memory)
if not success:
raise HTTPException(
@ -173,6 +183,17 @@ async def get_export_status(
)
def _export_details(output_path: Optional[str]) -> Optional[Dict[str, Any]]:
"""Wrap the resolved on-disk export path into the details dict the
frontend reads to populate the Export Complete screen. Returns None
when the export had no local component (Hub-only push) so the
Pydantic field stays absent rather than ``{"output_path": null}``.
"""
if not output_path:
return None
return {"output_path": output_path}
@router.post("/export/merged", response_model = ExportOperationResponse)
async def export_merged_model(
request: ExportMergedModelRequest,
@ -185,7 +206,8 @@ async def export_merged_model(
"""
try:
backend = get_export_backend()
success, message = backend.export_merged_model(
success, message, output_path = await asyncio.to_thread(
backend.export_merged_model,
save_directory = request.save_directory,
format_type = request.format_type,
push_to_hub = request.push_to_hub,
@ -197,7 +219,11 @@ async def export_merged_model(
if not success:
raise HTTPException(status_code = 400, detail = message)
return ExportOperationResponse(success = True, message = message)
return ExportOperationResponse(
success = True,
message = message,
details = _export_details(output_path),
)
except HTTPException:
raise
except Exception as e:
@ -220,7 +246,8 @@ async def export_base_model(
"""
try:
backend = get_export_backend()
success, message = backend.export_base_model(
success, message, output_path = await asyncio.to_thread(
backend.export_base_model,
save_directory = request.save_directory,
push_to_hub = request.push_to_hub,
repo_id = request.repo_id,
@ -232,7 +259,11 @@ async def export_base_model(
if not success:
raise HTTPException(status_code = 400, detail = message)
return ExportOperationResponse(success = True, message = message)
return ExportOperationResponse(
success = True,
message = message,
details = _export_details(output_path),
)
except HTTPException:
raise
except Exception as e:
@ -255,7 +286,8 @@ async def export_gguf(
"""
try:
backend = get_export_backend()
success, message = backend.export_gguf(
success, message, output_path = await asyncio.to_thread(
backend.export_gguf,
save_directory = request.save_directory,
quantization_method = request.quantization_method,
push_to_hub = request.push_to_hub,
@ -266,7 +298,11 @@ async def export_gguf(
if not success:
raise HTTPException(status_code = 400, detail = message)
return ExportOperationResponse(success = True, message = message)
return ExportOperationResponse(
success = True,
message = message,
details = _export_details(output_path),
)
except HTTPException:
raise
except Exception as e:
@ -289,7 +325,8 @@ async def export_lora_adapter(
"""
try:
backend = get_export_backend()
success, message = backend.export_lora_adapter(
success, message, output_path = await asyncio.to_thread(
backend.export_lora_adapter,
save_directory = request.save_directory,
push_to_hub = request.push_to_hub,
repo_id = request.repo_id,
@ -300,7 +337,11 @@ async def export_lora_adapter(
if not success:
raise HTTPException(status_code = 400, detail = message)
return ExportOperationResponse(success = True, message = message)
return ExportOperationResponse(
success = True,
message = message,
details = _export_details(output_path),
)
except HTTPException:
raise
except Exception as e:
@ -309,3 +350,155 @@ async def export_lora_adapter(
status_code = 500,
detail = f"Failed to export LoRA adapter: {str(e)}",
)
# ─────────────────────────────────────────────────────────────────────
# Live export log stream (Server-Sent Events)
# ─────────────────────────────────────────────────────────────────────
#
# The export worker subprocess redirects its stdout/stderr into a pipe
# that a reader thread forwards to the orchestrator as log entries (see
# core/export/worker.py::_setup_log_capture and
# core/export/orchestrator.py::_append_log). This endpoint streams
# those entries to the browser so the export dialog can show a live
# terminal-style output panel while load_checkpoint / export_merged /
# export_gguf / export_lora / export_base run.
#
# Shape follows the training progress SSE endpoint
# (routes/training.py::stream_training_progress): each event carries
# `id`, `event`, and `data` fields, the stream starts with a `retry:`
# directive, and `Last-Event-ID` is honored on reconnect.
def _format_sse(data: str, event: str, event_id: Optional[int] = None) -> str:
"""Format a single SSE message with id/event/data fields."""
lines = []
if event_id is not None:
lines.append(f"id: {event_id}")
lines.append(f"event: {event}")
lines.append(f"data: {data}")
lines.append("")
lines.append("")
return "\n".join(lines)
@router.get("/logs/stream")
async def stream_export_logs(
request: Request,
since: Optional[int] = Query(
None,
description = "Return log entries with seq strictly greater than this cursor.",
),
current_subject: str = Depends(get_current_subject),
):
"""
Stream live stdout/stderr output from the export worker subprocess
as Server-Sent Events.
Events:
- `log` : a single log line (data: {"stream","line","ts"})
- `heartbeat`: periodic keepalive when no new lines are available
- `complete` : emitted once the export worker is idle and no new
lines arrived for ~1 second. Clients should close.
- `error` : unrecoverable server-side error
The `id:` field on each event is the log entry's monotonic seq
number so the browser can resume via `Last-Event-ID` on reconnect.
"""
backend = get_export_backend()
# Determine starting cursor. Explicit `since` wins, then
# Last-Event-ID header on reconnect, otherwise start from the
# run-start snapshot captured by clear_logs() so the client sees
# every line emitted since the current run began -- even if the
# SSE connection opened after the POST that kicked off the export.
# Using get_current_log_seq() here would lose the early bootstrap
# lines that arrive in the gap between POST and SSE connect.
last_event_id = request.headers.get("last-event-id")
if since is None and last_event_id is not None:
try:
since = int(last_event_id)
except ValueError:
pass
if since is None:
cursor = backend.get_run_start_seq()
else:
cursor = max(0, int(since))
async def event_generator() -> AsyncGenerator[str, None]:
nonlocal cursor
# Tell the browser to reconnect after 3 seconds if the
# connection drops mid-export.
yield "retry: 3000\n\n"
last_yield = time.monotonic()
idle_since: Optional[float] = None
try:
while True:
if await request.is_disconnected():
return
entries, new_cursor = backend.get_logs_since(cursor)
if entries:
for entry in entries:
payload = json.dumps(
{
"stream": entry.get("stream", "stdout"),
"line": entry.get("line", ""),
"ts": entry.get("ts"),
}
)
yield _format_sse(
payload,
event = "log",
event_id = int(entry.get("seq", 0)),
)
cursor = new_cursor
last_yield = time.monotonic()
idle_since = None
else:
now = time.monotonic()
if now - last_yield > 10.0:
yield _format_sse("{}", event = "heartbeat")
last_yield = now
if not backend.is_export_active():
# Give the reader thread a moment to drain any
# trailing lines the worker process printed
# just before signalling done.
if idle_since is None:
idle_since = now
elif now - idle_since > 1.0:
yield _format_sse(
"{}",
event = "complete",
event_id = cursor,
)
return
else:
idle_since = None
await asyncio.sleep(0.1)
except asyncio.CancelledError:
# Client disconnected mid-yield. Don't re-raise, just end
# the generator cleanly so StreamingResponse finalizes.
return
except Exception as exc:
logger.error("Export log stream failed: %s", exc, exc_info = True)
try:
yield _format_sse(
json.dumps({"error": str(exc)}),
event = "error",
)
except Exception:
pass
return StreamingResponse(
event_generator(),
media_type = "text/event-stream",
headers = {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -14,6 +14,7 @@ import structlog
from loggers import get_logger
import asyncio
from datetime import datetime
import uuid as _uuid
# Add backend directory to path
# The backend code should be in the same directory structure
@ -87,14 +88,22 @@ async def get_hardware_utilization(
Get a live snapshot of GPU hardware utilization.
Designed to be polled by the frontend during training.
Returns GPU utilization %, temperature, VRAM usage, and power draw
via nvidia-smi for maximum accuracy.
Returns live GPU memory usage information for the active backend.
"""
from utils.hardware import get_gpu_utilization
return get_gpu_utilization()
@router.get("/hardware/visible")
async def get_visible_hardware_utilization(
current_subject: str = Depends(get_current_subject),
):
from utils.hardware import get_visible_gpu_utilization
return get_visible_gpu_utilization()
@router.post("/start")
async def start_training(
request: TrainingStartRequest,
@ -115,15 +124,11 @@ async def start_training(
backend = get_training_backend()
# Generate job ID and attach to backend for later status/progress calls
job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
backend.current_job_id = job_id
# Check if training is already active
# Check if training is already active (before mutating any state)
if backend.is_training_active():
existing_job_id: Optional[str] = getattr(backend, "current_job_id", "")
return TrainingJobResponse(
job_id = existing_job_id or job_id,
job_id = existing_job_id or "",
status = "error",
message = (
"Training is already in progress. "
@ -132,6 +137,12 @@ async def start_training(
error = "Training already active",
)
# Generate job ID — passed into start_training() which sets it on the
# backend only after confirming the old pump thread is dead.
job_id = (
f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{_uuid.uuid4().hex[:8]}"
)
# Validate dataset paths if provided
if request.local_datasets:
request.local_datasets = _validate_local_dataset_paths(
@ -199,6 +210,7 @@ async def start_training(
"enable_tensorboard": request.enable_tensorboard,
"tensorboard_dir": request.tensorboard_dir or "",
"trust_remote_code": request.trust_remote_code,
"gpu_ids": request.gpu_ids,
}
# Training page has no trust_remote_code toggle — the value comes from
@ -248,12 +260,12 @@ async def start_training(
logger.warning("Could not shut down export subprocess: %s", e)
# start_training now spawns a subprocess (non-blocking)
success = backend.start_training(**training_kwargs)
success = backend.start_training(job_id = job_id, **training_kwargs)
if not success:
progress_error = backend.trainer.training_progress.error
return TrainingJobResponse(
job_id = job_id,
job_id = backend.current_job_id or "",
status = "error",
message = progress_error or "Failed to start training subprocess",
error = progress_error or "subprocess_start_failed",
@ -266,6 +278,9 @@ async def start_training(
error = None,
)
except ValueError as e:
logger.warning("Rejected training GPU selection: %s", e)
raise HTTPException(status_code = 400, detail = str(e))
except Exception as e:
logger.error(f"Error starting training: {e}", exc_info = True)
raise HTTPException(
@ -345,7 +360,7 @@ async def reset_training(
error = None,
status_message = "Ready to train",
step = 0,
loss = 0.0,
loss = None,
epoch = 0,
total_steps = 0,
)
@ -419,8 +434,8 @@ async def get_training_status(
"epoch": getattr(progress, "epoch", 0),
"step": getattr(progress, "step", 0),
"total_steps": getattr(progress, "total_steps", 0),
"loss": getattr(progress, "loss", 0.0),
"learning_rate": getattr(progress, "learning_rate", 0.0),
"loss": getattr(progress, "loss", None),
"learning_rate": getattr(progress, "learning_rate", None),
}
# Build metric history for chart recovery after SSE reconnection
@ -526,8 +541,8 @@ async def stream_training_progress(
# ── Helpers ──────────────────────────────────────────────
def build_progress(
step: int,
loss: float,
learning_rate: float,
loss: Optional[float],
learning_rate: Optional[float],
total_steps: int,
epoch: Optional[float] = None,
progress: Optional[Any] = None,
@ -604,10 +619,10 @@ async def stream_training_progress(
loss_val = (
backend.loss_history[i]
if i < len(backend.loss_history)
else 0.0
else None
)
lr_val = (
backend.lr_history[i] if i < len(backend.lr_history) else 0.0
backend.lr_history[i] if i < len(backend.lr_history) else None
)
tp_replay = getattr(
getattr(backend, "trainer", None), "training_progress", None
@ -645,8 +660,8 @@ async def stream_training_progress(
initial_progress = build_progress(
step = 0,
loss = 0.0,
learning_rate = 0.0,
loss = None,
learning_rate = None,
total_steps = initial_total_steps,
epoch = initial_epoch,
progress = tp,
@ -660,9 +675,9 @@ async def stream_training_progress(
if backend.step_history:
final_step = backend.step_history[-1]
final_loss = (
backend.loss_history[-1] if backend.loss_history else 0.0
backend.loss_history[-1] if backend.loss_history else None
)
final_lr = backend.lr_history[-1] if backend.lr_history else 0.0
final_lr = backend.lr_history[-1] if backend.lr_history else None
final_total_steps = (
getattr(tp, "total_steps", final_step) if tp else final_step
)
@ -680,7 +695,9 @@ async def stream_training_progress(
)
else:
yield format_sse(
build_progress(-1, 0.0, 0.0, 0, progress = tp).model_dump_json(),
build_progress(
-1, None, None, 0, progress = tp
).model_dump_json(),
event = "complete",
event_id = 0,
)
@ -698,9 +715,9 @@ async def stream_training_progress(
if backend.step_history:
current_step = backend.step_history[-1]
current_loss = (
backend.loss_history[-1] if backend.loss_history else 0.0
backend.loss_history[-1] if backend.loss_history else None
)
current_lr = backend.lr_history[-1] if backend.lr_history else 0.0
current_lr = backend.lr_history[-1] if backend.lr_history else None
tp_inner = getattr(
getattr(backend, "trainer", None), "training_progress", None
)
@ -763,8 +780,8 @@ async def stream_training_progress(
)
preparing_payload = build_progress(
0,
0.0,
0.0,
None,
None,
prep_total,
progress = tp_prep,
)
@ -781,7 +798,7 @@ async def stream_training_progress(
getattr(backend, "trainer", None), "training_progress", None
)
timeout_payload = build_progress(
last_step, 0.0, 0.0, 0, progress = tp_timeout
last_step, None, None, 0, progress = tp_timeout
)
yield format_sse(
timeout_payload.model_dump_json(),
@ -797,7 +814,7 @@ async def stream_training_progress(
tp_error = getattr(
getattr(backend, "trainer", None), "training_progress", None
)
error_payload = build_progress(0, 0.0, 0.0, 0, progress = tp_error)
error_payload = build_progress(0, None, None, 0, progress = tp_error)
yield format_sse(
error_payload.model_dump_json(),
event = "error",
@ -807,8 +824,8 @@ async def stream_training_progress(
# ── Final "complete" event ───────────────────────────────
final_step = backend.step_history[-1] if backend.step_history else last_step
final_loss = backend.loss_history[-1] if backend.loss_history else 0.0
final_lr = backend.lr_history[-1] if backend.lr_history else 0.0
final_loss = backend.loss_history[-1] if backend.loss_history else None
final_lr = backend.lr_history[-1] if backend.lr_history else None
final_tp = getattr(getattr(backend, "trainer", None), "training_progress", None)
final_total_steps = (
getattr(final_tp, "total_steps", final_step) if final_tp else final_step

View file

@ -0,0 +1,85 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Training history API routes browse, view, and delete past training runs.
"""
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from loggers import get_logger
from auth.authentication import get_current_subject
from models import (
TrainingRunDeleteResponse,
TrainingRunDetailResponse,
TrainingRunListResponse,
TrainingRunMetrics,
TrainingRunSummary,
)
from storage.studio_db import delete_run, get_run, get_run_metrics, list_runs
logger = get_logger(__name__)
router = APIRouter()
@router.get("/runs", response_model = TrainingRunListResponse)
async def list_training_runs(
limit: int = Query(50, ge = 1, le = 200),
offset: int = Query(0, ge = 0),
current_subject: str = Depends(get_current_subject),
):
"""List training runs, newest first."""
result = list_runs(limit = limit, offset = offset)
return TrainingRunListResponse(
runs = [TrainingRunSummary(**r) for r in result["runs"]],
total = result["total"],
)
@router.get("/runs/{run_id}", response_model = TrainingRunDetailResponse)
async def get_training_run_detail(
run_id: str,
current_subject: str = Depends(get_current_subject),
):
"""Get a single training run with full config and metrics."""
run = get_run(run_id)
if run is None:
raise HTTPException(status_code = 404, detail = f"Run {run_id} not found")
try:
config = json.loads(run.get("config_json", "{}"))
except (json.JSONDecodeError, TypeError):
logger.debug("Failed to parse config_json for run %s", run_id)
config = {}
metrics_data = get_run_metrics(run_id)
return TrainingRunDetailResponse(
run = TrainingRunSummary(**{k: v for k, v in run.items() if k != "config_json"}),
config = config,
metrics = TrainingRunMetrics(**metrics_data),
)
@router.delete("/runs/{run_id}", response_model = TrainingRunDeleteResponse)
async def delete_training_run(
run_id: str,
current_subject: str = Depends(get_current_subject),
):
"""Delete a training run and its metrics (CASCADE)."""
run = get_run(run_id)
if run is None:
raise HTTPException(status_code = 404, detail = f"Run {run_id} not found")
if run["status"] == "running":
raise HTTPException(
status_code = 409, detail = "Cannot delete a running training run"
)
logger.info("Deleting training run %s", run_id)
delete_run(run_id)
return TrainingRunDeleteResponse(
status = "deleted",
message = f"Run {run_id} deleted",
)

View file

@ -24,6 +24,7 @@ if str(backend_dir) not in sys.path:
import _platform_compat # noqa: F401
from loggers import get_logger
from startup_banner import print_studio_access_banner
logger = get_logger(__name__)
@ -73,18 +74,79 @@ def _resolve_external_ip() -> str:
return "0.0.0.0"
def _get_pid_on_port(port: int) -> "tuple[int, str] | None":
"""Return (pid, process_name) of the process listening on *port*, or None.
Uses psutil when available. Falls back gracefully to None so callers
can still report the port conflict without process details.
Works on Windows, macOS, and Linux wherever psutil is installed.
"""
try:
import psutil
except ImportError:
return None
try:
for conn in psutil.net_connections(kind = "tcp"):
if conn.status == "LISTEN" and conn.laddr.port == port:
if conn.pid is None:
return None
try:
proc = psutil.Process(conn.pid)
return (conn.pid, proc.name())
except (psutil.NoSuchProcess, psutil.AccessDenied):
return (conn.pid, "<unknown>")
except (psutil.AccessDenied, OSError) as e:
# psutil.net_connections() needs elevated privileges on some platforms
logger.debug("Failed to scan network connections for port %s: %s", port, e)
return None
def _is_port_free(host: str, port: int) -> bool:
"""Check if a port is available for binding."""
"""Check if a port is available for binding.
When *host* is ``0.0.0.0`` (wildcard), we also check whether anything
is already listening on ``127.0.0.1`` (and ``::1`` when IPv6 is
available). An SSH tunnel or similar process may hold the loopback
address while our wildcard bind still succeeds, making Unsloth Studio
unreachable via ``localhost``.
Works on Windows, macOS, and Linux.
"""
import socket
# 1. Can we bind to the requested address?
# Use getaddrinfo so both IPv4 ("0.0.0.0") and IPv6 ("::") hosts
# resolve to the correct address family automatically.
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
family, socktype, proto, _, sockaddr = addr_info[0]
with socket.socket(family, socktype, proto) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
return True
s.bind(sockaddr)
except OSError:
return False
# 2. When binding to all interfaces, verify that localhost is not
# already claimed by another process (e.g. an SSH -L tunnel).
# We attempt a TCP connect -- if it succeeds something is listening.
if host in ("0.0.0.0", "::"):
for loopback, family in [
("127.0.0.1", socket.AF_INET),
("::1", socket.AF_INET6),
]:
try:
with socket.socket(family, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex((loopback, port)) == 0:
# Connection succeeded -- port is taken on loopback
return False
except OSError:
# IPv6 disabled or other OS-level restriction -- skip
continue
return True
def _find_free_port(host: str, start: int, max_attempts: int = 20) -> int:
"""Find a free port starting from `start`, trying up to max_attempts ports."""
@ -97,6 +159,29 @@ def _find_free_port(host: str, start: int, max_attempts: int = 20) -> int:
)
_PID_FILE = Path.home() / ".unsloth" / "studio" / "studio.pid"
def _write_pid_file():
"""Write the current process PID to the studio PID file."""
try:
_PID_FILE.parent.mkdir(parents = True, exist_ok = True)
_PID_FILE.write_text(str(os.getpid()))
except OSError:
pass
def _remove_pid_file():
"""Remove the PID file if it belongs to this process."""
try:
if _PID_FILE.is_file():
stored = _PID_FILE.read_text().strip()
if stored == str(os.getpid()):
_PID_FILE.unlink(missing_ok = True)
except OSError:
pass
def _graceful_shutdown(server = None):
"""Explicitly shut down all subprocess backends and the uvicorn server.
@ -104,6 +189,7 @@ def _graceful_shutdown(server = None):
before the parent exits. This is critical on Windows where atexit
handlers are unreliable after Ctrl+C.
"""
_remove_pid_file()
logger.info("Graceful shutdown initiated — cleaning up subprocesses...")
# 1. Shut down uvicorn server (releases the listening socket)
@ -149,11 +235,11 @@ def _graceful_shutdown(server = None):
logger.info("All subprocesses cleaned up")
# The uvicorn server instance set by run_server(), used by callers
# The uvicorn server instance -- set by run_server(), used by callers
# that need to tell the server to exit (e.g. signal handlers).
_server = None
# Shutdown event used to wake the main loop on signal
# Shutdown event -- used to wake the main loop on signal
_shutdown_event = None
@ -162,6 +248,7 @@ def run_server(
port: int = 8888,
frontend_path: Path = Path(__file__).resolve().parent.parent / "frontend" / "dist",
silent: bool = False,
llama_parallel_slots: int = 1,
):
"""
Start the FastAPI server.
@ -171,6 +258,7 @@ def run_server(
port: Port to bind to (auto-increments if in use)
frontend_path: Path to frontend build directory (optional)
silent: Suppress startup messages
llama_parallel_slots: Number of parallel slots for llama-server
Note:
Signal handlers are NOT registered here so that embedders
@ -205,18 +293,31 @@ def run_server(
# Auto-find free port if requested port is in use
if not _is_port_free(host, port):
original_port = port
port = _find_free_port(host, port)
blocker = _get_pid_on_port(port)
port = _find_free_port(host, port + 1)
if not silent:
print(f"Port {original_port} is in use, using port {port} instead")
print("")
print("=" * 50)
if blocker:
pid, name = blocker
print(
f"Port {original_port} is already in use by " f"{name} (PID {pid})."
)
else:
print(f"Port {original_port} is already in use.")
print(f"Unsloth Studio will use port {port} instead.")
print(f"Open http://localhost:{port} in your browser.")
print("=" * 50)
print("")
# Setup frontend if path provided
if frontend_path:
if setup_frontend(app, frontend_path):
if not silent:
print(f"✅ Frontend loaded from {frontend_path}")
print(f"[OK] Frontend loaded from {frontend_path}")
else:
if not silent:
print(f"⚠️ Frontend not found at {frontend_path}")
print(f"[WARNING] Frontend not found at {frontend_path}")
# Create the uvicorn server and expose it for signal handlers
config = uvicorn.Config(
@ -225,6 +326,15 @@ def run_server(
_server = uvicorn.Server(config)
_shutdown_event = Event()
# Expose the actual bound port so request-handling code can build
# loopback URLs that point at the real backend, not whatever port a
# reverse proxy or tunnel exposed in the request URL. Only publish
# an explicit value when we know the concrete port; for ephemeral
# binds (port==0) leave it unset and let request handlers fall back
# to the ASGI request scope or request.base_url.
app.state.server_port = port if port and port > 0 else None
app.state.llama_parallel_slots = llama_parallel_slots
# Run server in a daemon thread
def _run():
asyncio.run(_server.serve())
@ -233,21 +343,27 @@ def run_server(
thread.start()
time.sleep(3)
_write_pid_file()
import atexit
atexit.register(_remove_pid_file)
# Expose a shutdown callable via app.state so the /api/shutdown endpoint
# can trigger graceful shutdown without circular imports.
def _trigger_shutdown():
_graceful_shutdown(_server)
if _shutdown_event is not None:
_shutdown_event.set()
app.state.trigger_shutdown = _trigger_shutdown
if not silent:
display_host = _resolve_external_ip() if host == "0.0.0.0" else host
print("")
print("=" * 50)
print(f"🦥 Open your web browser, and enter http://localhost:{port}")
print("=" * 50)
print("")
print("=" * 50)
print(f"🦥 Unsloth Studio is running on port {port}")
print(f" Local Access: http://localhost:{port}")
print(f" Worldwide Web Address: http://{display_host}:{port}")
print(f" API: http://{display_host}:{port}/api")
print(f" Health: http://{display_host}:{port}/api/health")
print("=" * 50)
print_studio_access_banner(
port = port,
bind_host = host,
display_host = display_host,
)
return app
@ -297,7 +413,7 @@ if __name__ == "__main__":
sys.stderr.flush()
sys.exit(1)
# ── Signal handler — ensures subprocess cleanup on Ctrl+C ────
# Signal handler -- ensures subprocess cleanup on Ctrl+C
def _signal_handler(signum, frame):
_graceful_shutdown(_server)
_shutdown_event.set()

View file

@ -0,0 +1,123 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Terminal banner for Studio startup.
Stdlib only safe to import without the rest of the backend (no structlog/uvicorn).
"""
from __future__ import annotations
import os
import sys
def stdout_supports_color() -> bool:
"""True if we should emit ANSI colors."""
if os.environ.get("NO_COLOR", "").strip():
return False
if os.environ.get("FORCE_COLOR", "").strip():
return True
try:
return sys.stdout.isatty()
except (AttributeError, OSError, ValueError):
return False
def print_port_in_use_notice(original_port: int, new_port: int) -> None:
"""Message when the requested port is taken and another is chosen."""
msg = f"Port {original_port} is in use, using port {new_port} instead."
if stdout_supports_color():
print(f"\033[38;5;245m{msg}\033[0m")
else:
print(msg)
def print_studio_access_banner(
*,
port: int,
bind_host: str,
display_host: str,
) -> None:
"""Pretty-print URLs after the server is listening (beginner-friendly)."""
use_color = stdout_supports_color()
dim = "\033[38;5;245m"
title = "\033[38;5;150m"
local_url_style = "\033[38;5;108;1m"
secondary = "\033[38;5;109m"
reset = "\033[0m"
def style(text: str, code: str) -> str:
return f"{code}{text}{reset}" if use_color else text
ipv6_bind = bind_host in ("::", "::1")
if ipv6_bind:
loopback_url = f"http://[::1]:{port}"
alt_local = f"http://localhost:{port}"
else:
loopback_url = f"http://127.0.0.1:{port}"
alt_local = f"http://localhost:{port}"
if ":" in display_host:
external_url = f"http://[{display_host}]:{port}"
else:
external_url = f"http://{display_host}:{port}"
listen_all = bind_host in ("0.0.0.0", "::")
loopback_bind = bind_host in ("127.0.0.1", "localhost", "::1")
# Use loopback URL only when the server is reachable on loopback;
# otherwise show the actual bound address.
primary_url = loopback_url if listen_all or loopback_bind else external_url
tip_url = alt_local if listen_all or loopback_bind else external_url
api_base = primary_url
lines: list[str] = [
"",
style("🦥 Unsloth Studio is running", title),
style("" * 52, dim),
style(" On this machine -- open this in your browser:", dim),
style(f" {primary_url}", local_url_style),
]
if (listen_all or loopback_bind) and primary_url != alt_local:
lines.append(style(f" (same as {alt_local})", dim))
if listen_all and display_host not in (
"127.0.0.1",
"localhost",
"::1",
"0.0.0.0",
"::",
):
lines.extend(
[
"",
style(" From another device on your network / to share:", dim),
style(f" {external_url}", secondary),
]
)
elif not listen_all and not loopback_bind and external_url != primary_url:
lines.extend(
[
"",
style(" Bound address:", dim),
style(f" {external_url}", secondary),
]
)
lines.extend(
[
"",
style(" API & health:", dim),
style(f" {api_base}/api", secondary),
style(f" {api_base}/api/health", secondary),
style("" * 52, dim),
style(
f" Tip: if you are on this computer, open {tip_url}/ in your browser.",
dim,
),
"",
]
)
print("\n".join(lines))

View file

@ -0,0 +1,2 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

View file

@ -0,0 +1,488 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
SQLite storage for training run history and metrics.
Follows the same pattern as auth/storage.py module-level functions,
raw sqlite3, per-function connections. Enhancements over auth:
- WAL mode for concurrent read/write access
- PRAGMA foreign_keys = ON for CASCADE deletes
"""
import json
import logging
import os
import platform
import sqlite3
import threading
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
from typing import Optional
from utils.paths import studio_db_path, ensure_dir
def _denied_path_prefixes() -> list[str]:
"""Platform-aware denylist of system directories."""
system = platform.system()
if system == "Linux":
return ["/proc", "/sys", "/dev", "/etc", "/boot", "/run"]
if system == "Darwin":
# realpath() resolves /etc -> /private/etc, /tmp -> /private/tmp on macOS,
# so include the /private variants to avoid bypasses.
return [
"/System",
"/Library",
"/dev",
"/etc",
"/private/etc",
"/tmp",
"/private/tmp",
"/var",
"/private/var",
]
if system == "Windows":
win = os.environ.get("SystemRoot", r"C:\Windows")
pf = os.environ.get("ProgramFiles", r"C:\Program Files")
pf86 = os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)")
return [os.path.normcase(p) for p in [win, pf, pf86]]
return []
_schema_lock = threading.Lock()
_schema_ready = False
def _ensure_schema(conn: sqlite3.Connection) -> None:
"""Create tables and indexes if they don't exist. Called once per process."""
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS training_runs (
id TEXT NOT NULL PRIMARY KEY,
status TEXT NOT NULL DEFAULT 'running',
model_name TEXT NOT NULL,
dataset_name TEXT NOT NULL,
config_json TEXT NOT NULL,
started_at TEXT NOT NULL,
ended_at TEXT,
total_steps INTEGER,
final_step INTEGER,
final_loss REAL,
output_dir TEXT,
error_message TEXT,
duration_seconds REAL,
loss_sparkline TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS training_metrics (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL REFERENCES training_runs(id) ON DELETE CASCADE,
step INTEGER NOT NULL,
loss REAL,
learning_rate REAL,
grad_norm REAL,
eval_loss REAL,
epoch REAL,
num_tokens INTEGER,
elapsed_seconds REAL,
UNIQUE(run_id, step)
)
"""
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_metrics_run_id ON training_metrics(run_id)"
)
# Use COLLATE NOCASE on Windows so C:\Models and c:\models dedup via the
# UNIQUE constraint. On Linux/macOS (case-sensitive FS) keep the default
# BINARY collation so /Models and /models remain distinct.
collation = "COLLATE NOCASE" if platform.system() == "Windows" else ""
conn.execute(
f"""
CREATE TABLE IF NOT EXISTS scan_folders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL UNIQUE {collation},
created_at TEXT NOT NULL
)
"""
)
def get_connection() -> sqlite3.Connection:
"""Open studio.db with WAL mode, create tables once per process, enable foreign keys."""
global _schema_ready
db_path = studio_db_path()
ensure_dir(db_path.parent)
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
# foreign_keys is session-scoped, must be set per connection
conn.execute("PRAGMA foreign_keys=ON")
if not _schema_ready:
with _schema_lock:
if not _schema_ready:
try:
_ensure_schema(conn)
_schema_ready = True
except Exception:
conn.close()
raise
return conn
def create_run(
id: str,
model_name: str,
dataset_name: str,
config_json: str,
started_at: str,
total_steps: Optional[int],
) -> None:
conn = get_connection()
try:
conn.execute(
"""
INSERT INTO training_runs (id, model_name, dataset_name, config_json, started_at, total_steps)
VALUES (?, ?, ?, ?, ?, ?)
""",
(id, model_name, dataset_name, config_json, started_at, total_steps),
)
conn.commit()
finally:
conn.close()
def update_run_total_steps(id: str, total_steps: int) -> None:
conn = get_connection()
try:
conn.execute(
"UPDATE training_runs SET total_steps = ? WHERE id = ?",
(total_steps, id),
)
conn.commit()
finally:
conn.close()
def update_run_progress(
id: str, step: int, loss: Optional[float], duration_seconds: Optional[float]
) -> None:
"""Update current progress on a running training run (called on each metric flush)."""
conn = get_connection()
try:
conn.execute(
"UPDATE training_runs SET final_step = ?, final_loss = ?, duration_seconds = ? WHERE id = ?",
(step, loss, duration_seconds, id),
)
conn.commit()
finally:
conn.close()
def finish_run(
id: str,
status: str,
ended_at: str,
final_step: Optional[int],
final_loss: Optional[float],
duration_seconds: Optional[float],
loss_sparkline: Optional[str] = None,
output_dir: Optional[str] = None,
error_message: Optional[str] = None,
) -> None:
conn = get_connection()
try:
conn.execute(
"""
UPDATE training_runs
SET status = ?, ended_at = ?, final_step = ?, final_loss = ?,
duration_seconds = ?, loss_sparkline = ?, output_dir = ?,
error_message = ?
WHERE id = ?
""",
(
status,
ended_at,
final_step,
final_loss,
duration_seconds,
loss_sparkline,
output_dir,
error_message,
id,
),
)
conn.commit()
finally:
conn.close()
def insert_metrics_batch(run_id: str, metrics: list[dict]) -> None:
if not metrics:
return
conn = get_connection()
try:
conn.executemany(
"""
INSERT INTO training_metrics
(run_id, step, loss, learning_rate, grad_norm, eval_loss, epoch, num_tokens, elapsed_seconds)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(run_id, step) DO UPDATE SET
loss = COALESCE(excluded.loss, loss),
learning_rate = COALESCE(excluded.learning_rate, learning_rate),
grad_norm = COALESCE(excluded.grad_norm, grad_norm),
eval_loss = COALESCE(excluded.eval_loss, eval_loss),
epoch = COALESCE(excluded.epoch, epoch),
num_tokens = COALESCE(excluded.num_tokens, num_tokens),
elapsed_seconds = COALESCE(excluded.elapsed_seconds, elapsed_seconds)
""",
[
(
run_id,
m.get("step"),
m.get("loss"),
m.get("learning_rate"),
m.get("grad_norm"),
m.get("eval_loss"),
m.get("epoch"),
m.get("num_tokens"),
m.get("elapsed_seconds"),
)
for m in metrics
],
)
conn.commit()
finally:
conn.close()
def list_runs(limit: int = 50, offset: int = 0) -> dict:
conn = get_connection()
try:
total = conn.execute("SELECT COUNT(*) FROM training_runs").fetchone()[0]
rows = conn.execute(
"""
SELECT id, status, model_name, dataset_name, started_at, ended_at,
total_steps, final_step, final_loss, output_dir,
duration_seconds, error_message, loss_sparkline
FROM training_runs
ORDER BY started_at DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
).fetchall()
runs = []
for row in rows:
run = dict(row)
sparkline = run.get("loss_sparkline")
if sparkline:
try:
run["loss_sparkline"] = json.loads(sparkline)
except (json.JSONDecodeError, TypeError):
logger.debug(
"Failed to parse loss_sparkline for run %s", run.get("id")
)
run["loss_sparkline"] = None
runs.append(run)
return {"runs": runs, "total": total}
finally:
conn.close()
def get_run(id: str) -> Optional[dict]:
conn = get_connection()
try:
row = conn.execute("SELECT * FROM training_runs WHERE id = ?", (id,)).fetchone()
if row is None:
return None
run = dict(row)
sparkline = run.get("loss_sparkline")
if sparkline:
try:
run["loss_sparkline"] = json.loads(sparkline)
except (json.JSONDecodeError, TypeError):
logger.debug("Failed to parse loss_sparkline for run %s", id)
run["loss_sparkline"] = None
return run
finally:
conn.close()
def get_run_metrics(id: str) -> dict:
"""Return metric arrays for a run, using paired step arrays per metric."""
conn = get_connection()
try:
rows = conn.execute(
"""
SELECT step, loss, learning_rate, grad_norm, eval_loss, epoch,
num_tokens, elapsed_seconds
FROM training_metrics
WHERE run_id = ?
ORDER BY step
""",
(id,),
).fetchall()
step_history: list[int] = []
loss_history: list[float] = []
loss_step_history: list[int] = []
lr_history: list[float] = []
lr_step_history: list[int] = []
grad_norm_history: list[float] = []
grad_norm_step_history: list[int] = []
eval_loss_history: list[float] = []
eval_step_history: list[int] = []
final_epoch: float | None = None
final_num_tokens: int | None = None
for row in rows:
step = row["step"]
step_history.append(step)
if step > 0 and row["loss"] is not None:
loss_history.append(row["loss"])
loss_step_history.append(step)
if step > 0 and row["learning_rate"] is not None:
lr_history.append(row["learning_rate"])
lr_step_history.append(step)
if step > 0 and row["grad_norm"] is not None:
grad_norm_history.append(row["grad_norm"])
grad_norm_step_history.append(step)
if step > 0 and row["eval_loss"] is not None:
eval_loss_history.append(row["eval_loss"])
eval_step_history.append(step)
if row["epoch"] is not None:
final_epoch = row["epoch"]
if row["num_tokens"] is not None:
final_num_tokens = row["num_tokens"]
return {
"step_history": step_history,
"loss_history": loss_history,
"loss_step_history": loss_step_history,
"lr_history": lr_history,
"lr_step_history": lr_step_history,
"grad_norm_history": grad_norm_history,
"grad_norm_step_history": grad_norm_step_history,
"eval_loss_history": eval_loss_history,
"eval_step_history": eval_step_history,
"final_epoch": final_epoch,
"final_num_tokens": final_num_tokens,
}
finally:
conn.close()
def delete_run(id: str) -> None:
conn = get_connection()
try:
conn.execute("DELETE FROM training_runs WHERE id = ?", (id,))
conn.commit()
finally:
conn.close()
def cleanup_orphaned_runs() -> None:
"""Mark any 'running' rows as errored on startup (server restarted mid-training)."""
conn = get_connection()
try:
conn.execute(
"""
UPDATE training_runs
SET status = 'error',
error_message = 'Server restarted during training',
ended_at = ?
WHERE status = 'running'
""",
(datetime.now(timezone.utc).isoformat(),),
)
conn.commit()
finally:
conn.close()
def list_scan_folders() -> list[dict]:
conn = get_connection()
try:
rows = conn.execute(
"SELECT id, path, created_at FROM scan_folders ORDER BY created_at"
).fetchall()
return [dict(row) for row in rows]
finally:
conn.close()
def add_scan_folder(path: str) -> dict:
"""Add a directory to the custom scan folder list. Returns the row."""
if not path or not path.strip():
raise ValueError("Path cannot be empty")
normalized = os.path.realpath(os.path.expanduser(path.strip()))
# Validate the path is an existing, readable directory before persisting.
if not os.path.exists(normalized):
raise ValueError("Path does not exist")
if not os.path.isdir(normalized):
raise ValueError("Path must be a directory, not a file")
if not os.access(normalized, os.R_OK | os.X_OK):
raise ValueError("Path is not readable")
# On Windows, use normcase for denylist comparison but store the
# original-cased path so downstream consumers see the native
# drive-letter casing the user expects (e.g. C:\Models, not c:\models).
is_win = platform.system() == "Windows"
check = os.path.normcase(normalized) if is_win else normalized
for prefix in _denied_path_prefixes():
if check == prefix or check.startswith(prefix + os.sep):
raise ValueError(f"Path under {prefix} is not allowed")
conn = get_connection()
try:
now = datetime.now(timezone.utc).isoformat()
# On Windows, use case-insensitive lookup so C:\Models and c:\models
# dedup correctly while preserving the originally-stored casing.
if is_win:
existing = conn.execute(
"SELECT id, path, created_at FROM scan_folders WHERE path = ? COLLATE NOCASE",
(normalized,),
).fetchone()
else:
existing = conn.execute(
"SELECT id, path, created_at FROM scan_folders WHERE path = ?",
(normalized,),
).fetchone()
if existing is not None:
return dict(existing)
try:
conn.execute(
"INSERT INTO scan_folders (path, created_at) VALUES (?, ?)",
(normalized, now),
)
conn.commit()
except sqlite3.IntegrityError:
pass # duplicate -- fall through to SELECT
# Use the same collation as the pre-check so we find the row even
# when a concurrent writer stored it with different casing (Windows).
fallback_sql = (
"SELECT id, path, created_at FROM scan_folders WHERE path = ? COLLATE NOCASE"
if is_win
else "SELECT id, path, created_at FROM scan_folders WHERE path = ?"
)
row = conn.execute(fallback_sql, (normalized,)).fetchone()
if row is None:
raise ValueError("Folder was concurrently removed")
return dict(row)
finally:
conn.close()
def remove_scan_folder(id: int) -> None:
conn = get_connection()
try:
conn.execute("DELETE FROM scan_folders WHERE id = ?", (id,))
conn.commit()
finally:
conn.close()

View file

@ -3,14 +3,136 @@
"""
Shared pytest configuration for the backend test suite.
Ensures that the backend root is on sys.path so that
`import utils.utils` (and similar flat imports) resolve correctly.
Responsibilities:
1. Put the backend root on sys.path so `from models.inference import ...`
(and similar flat imports) resolve in test modules mirrors how the
app itself is launched.
2. Provide a hybrid ``studio_server`` session fixture for end-to-end tests
(see ``test_studio_api.py``). The fixture supports two invocation modes:
a. **External server.** If ``UNSLOTH_E2E_BASE_URL`` is set, tests point
at an already-running Studio instance. ``UNSLOTH_E2E_API_KEY`` must
also be set. This is the fast-iteration mode: start the server once
with ``unsloth studio run ...``, then run pytest against it many
times with no per-run GGUF load cost.
b. **Fixture-managed server.** Otherwise, the fixture launches a fresh
server via ``_start_server`` and tears it down at session end. This
is the one-shot mode for CI or a clean-slate verification run.
The model / variant for mode (b) come from ``--unsloth-model`` /
``--unsloth-gguf-variant`` pytest options, then ``UNSLOTH_E2E_MODEL`` /
``UNSLOTH_E2E_VARIANT`` env vars, then the defaults in
``test_studio_api.py``.
"""
import os
import sys
from pathlib import Path
import pytest
# Add backend root to sys.path (mirrors how the app itself is launched)
_backend_root = Path(__file__).resolve().parent.parent
if str(_backend_root) not in sys.path:
sys.path.insert(0, str(_backend_root))
# ── Pytest CLI options ───────────────────────────────────────────────
def pytest_addoption(parser):
group = parser.getgroup(
"unsloth-e2e",
"Unsloth Studio end-to-end test options",
)
group.addoption(
"--unsloth-model",
action = "store",
default = None,
help = (
"GGUF model id used when starting a server for e2e tests. "
"Ignored if UNSLOTH_E2E_BASE_URL is set. Overrides "
"UNSLOTH_E2E_MODEL env var. Defaults to test_studio_api.py's "
"DEFAULT_MODEL."
),
)
group.addoption(
"--unsloth-gguf-variant",
action = "store",
default = None,
help = (
"GGUF variant used when starting a server for e2e tests. "
"Ignored if UNSLOTH_E2E_BASE_URL is set. Overrides "
"UNSLOTH_E2E_VARIANT env var. Defaults to test_studio_api.py's "
"DEFAULT_VARIANT."
),
)
# ── E2E server fixtures ──────────────────────────────────────────────
@pytest.fixture(scope = "session")
def studio_server(request):
"""Yield ``(base_url, api_key)`` for e2e tests.
Resolution order:
1. If ``UNSLOTH_E2E_BASE_URL`` is set point at that server,
require ``UNSLOTH_E2E_API_KEY`` alongside (skip if missing).
2. Otherwise start a fresh ``unsloth studio run`` subprocess via
the existing ``_start_server`` helper in ``test_studio_api.py``
and tear it down on session teardown.
Session-scoped so the expensive GGUF load happens at most once per
pytest invocation. Lazily instantiated tests that don't request
the fixture (e.g. the unit tests in ``test_anthropic_messages.py``
or ``test_help_output``) do not trigger server startup.
"""
external_url = os.environ.get("UNSLOTH_E2E_BASE_URL")
if external_url:
api_key = os.environ.get("UNSLOTH_E2E_API_KEY")
if not api_key:
pytest.skip(
"UNSLOTH_E2E_BASE_URL is set but UNSLOTH_E2E_API_KEY is "
"missing — tests that require auth cannot run against an "
"external server without it.",
)
yield external_url, api_key
return
# Lazy import: pytest has already loaded test_studio_api into
# sys.modules by the time any test requests this fixture, so this
# is a cache hit, not a re-execution.
import test_studio_api as _e2e
model = (
request.config.getoption("--unsloth-model")
or os.environ.get("UNSLOTH_E2E_MODEL")
or _e2e.DEFAULT_MODEL
)
variant = (
request.config.getoption("--unsloth-gguf-variant")
or os.environ.get("UNSLOTH_E2E_VARIANT")
or _e2e.DEFAULT_VARIANT
)
proc, api_key = _e2e._start_server(model, variant)
try:
yield f"http://{_e2e.HOST}:{_e2e.PORT}", api_key
finally:
_e2e._kill_server(proc)
@pytest.fixture
def base_url(studio_server):
"""Base URL for the e2e Studio server (from ``studio_server``)."""
return studio_server[0]
@pytest.fixture
def api_key(studio_server):
"""API key for the e2e Studio server (from ``studio_server``)."""
return studio_server[1]

View file

@ -0,0 +1,774 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""
Tests for the Anthropic Messages API schemas and translation layer.
No running server or GPU required.
"""
import sys
import os
import json
_backend = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, _backend)
from models.inference import (
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicMessage,
AnthropicTextBlock,
AnthropicToolUseBlock,
AnthropicToolResultBlock,
AnthropicTool,
AnthropicUsage,
AnthropicResponseTextBlock,
AnthropicResponseToolUseBlock,
)
from core.inference.anthropic_compat import (
anthropic_messages_to_openai,
anthropic_tools_to_openai,
build_anthropic_sse_event,
AnthropicStreamEmitter,
AnthropicPassthroughEmitter,
)
# =====================================================================
# Pydantic model tests
# =====================================================================
class TestAnthropicModels:
def test_minimal_request(self):
req = AnthropicMessagesRequest(
messages = [{"role": "user", "content": "Hi"}],
)
assert req.max_tokens is None
assert req.model == "default"
assert req.stream is False
def test_max_tokens_optional(self):
req = AnthropicMessagesRequest(
max_tokens = 100,
messages = [{"role": "user", "content": "Hi"}],
)
assert req.max_tokens == 100
def test_system_as_string(self):
req = AnthropicMessagesRequest(
max_tokens = 50,
messages = [{"role": "user", "content": "Hi"}],
system = "You are helpful.",
)
assert req.system == "You are helpful."
def test_tools_field_parses(self):
req = AnthropicMessagesRequest(
max_tokens = 100,
messages = [{"role": "user", "content": "Hi"}],
tools = [{"name": "web_search", "input_schema": {"type": "object"}}],
)
assert len(req.tools) == 1
assert req.tools[0].name == "web_search"
def test_extra_fields_accepted(self):
req = AnthropicMessagesRequest(
max_tokens = 100,
messages = [{"role": "user", "content": "Hi"}],
some_future_field = "hello",
)
assert req.max_tokens == 100
def test_stream_defaults_false(self):
req = AnthropicMessagesRequest(
max_tokens = 100,
messages = [{"role": "user", "content": "Hi"}],
)
assert req.stream is False
def test_enable_tools_shorthand(self):
req = AnthropicMessagesRequest(
messages = [{"role": "user", "content": "Hi"}],
enable_tools = True,
enabled_tools = ["web_search", "python"],
session_id = "my-session",
)
assert req.enable_tools is True
assert req.enabled_tools == ["web_search", "python"]
assert req.session_id == "my-session"
def test_extension_fields_default_none(self):
req = AnthropicMessagesRequest(
messages = [{"role": "user", "content": "Hi"}],
)
assert req.enable_tools is None
assert req.enabled_tools is None
assert req.session_id is None
def test_response_model_defaults(self):
resp = AnthropicMessagesResponse()
assert resp.type == "message"
assert resp.role == "assistant"
assert resp.id.startswith("msg_")
assert resp.content == []
assert resp.usage.input_tokens == 0
# =====================================================================
# Message translation tests
# =====================================================================
class TestAnthropicMessagesToOpenAI:
def test_simple_user_message(self):
msgs = [{"role": "user", "content": "Hello"}]
result = anthropic_messages_to_openai(msgs)
assert result == [{"role": "user", "content": "Hello"}]
def test_system_string_prepended(self):
msgs = [{"role": "user", "content": "Hello"}]
result = anthropic_messages_to_openai(msgs, system = "Be brief.")
assert result[0] == {"role": "system", "content": "Be brief."}
assert result[1] == {"role": "user", "content": "Hello"}
def test_system_as_block_list(self):
system = [
{"type": "text", "text": "Be brief."},
{"type": "text", "text": "Be accurate."},
]
msgs = [{"role": "user", "content": "Hello"}]
result = anthropic_messages_to_openai(msgs, system = system)
assert result[0]["role"] == "system"
assert "Be brief." in result[0]["content"]
assert "Be accurate." in result[0]["content"]
def test_multi_turn_conversation(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
{"role": "user", "content": "How are you?"},
]
result = anthropic_messages_to_openai(msgs)
assert len(result) == 3
assert result[0]["role"] == "user"
assert result[1]["role"] == "assistant"
assert result[2]["role"] == "user"
def test_assistant_tool_use_maps_to_tool_calls(self):
msgs = [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me search."},
{
"type": "tool_use",
"id": "tu_1",
"name": "web_search",
"input": {"query": "test"},
},
],
}
]
result = anthropic_messages_to_openai(msgs)
assert len(result) == 1
m = result[0]
assert m["role"] == "assistant"
assert m["content"] == "Let me search."
assert len(m["tool_calls"]) == 1
tc = m["tool_calls"][0]
assert tc["id"] == "tu_1"
assert tc["function"]["name"] == "web_search"
assert json.loads(tc["function"]["arguments"]) == {"query": "test"}
def test_tool_result_maps_to_tool_role(self):
msgs = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu_1",
"content": "Result text",
},
],
}
]
result = anthropic_messages_to_openai(msgs)
assert len(result) == 1
assert result[0]["role"] == "tool"
assert result[0]["tool_call_id"] == "tu_1"
assert result[0]["content"] == "Result text"
def test_mixed_text_and_tool_use_blocks(self):
msgs = [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Thinking..."},
{
"type": "tool_use",
"id": "tu_1",
"name": "python",
"input": {"code": "1+1"},
},
{
"type": "tool_use",
"id": "tu_2",
"name": "terminal",
"input": {"command": "ls"},
},
],
}
]
result = anthropic_messages_to_openai(msgs)
assert len(result) == 1
m = result[0]
assert m["content"] == "Thinking..."
assert len(m["tool_calls"]) == 2
def test_tool_result_with_list_content(self):
msgs = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu_1",
"content": [
{"type": "text", "text": "Line 1"},
{"type": "text", "text": "Line 2"},
],
},
],
}
]
result = anthropic_messages_to_openai(msgs)
assert result[0]["content"] == "Line 1 Line 2"
# =====================================================================
# Tool translation tests
# =====================================================================
class TestAnthropicToolsToOpenAI:
def test_single_tool(self):
tools = [
{
"name": "web_search",
"description": "Search",
"input_schema": {
"type": "object",
"properties": {"query": {"type": "string"}},
},
}
]
result = anthropic_tools_to_openai(tools)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["function"]["name"] == "web_search"
assert result[0]["function"]["parameters"]["type"] == "object"
def test_multiple_tools(self):
tools = [
{"name": "a", "description": "Tool A", "input_schema": {}},
{"name": "b", "description": "Tool B", "input_schema": {}},
]
result = anthropic_tools_to_openai(tools)
assert len(result) == 2
assert result[0]["function"]["name"] == "a"
assert result[1]["function"]["name"] == "b"
def test_empty_list(self):
assert anthropic_tools_to_openai([]) == []
def test_pydantic_model_input(self):
tool = AnthropicTool(
name = "test", description = "desc", input_schema = {"type": "object"}
)
result = anthropic_tools_to_openai([tool])
assert result[0]["function"]["name"] == "test"
# =====================================================================
# SSE event helper tests
# =====================================================================
class TestBuildAnthropicSSEEvent:
def test_basic_event(self):
result = build_anthropic_sse_event("message_start", {"type": "message_start"})
assert result.startswith("event: message_start\n")
assert "data: " in result
assert result.endswith("\n\n")
def test_data_is_valid_json(self):
result = build_anthropic_sse_event("test", {"key": "value"})
data_line = result.split("\n")[1]
payload = json.loads(data_line.removeprefix("data: "))
assert payload == {"key": "value"}
# =====================================================================
# Stream emitter tests
# =====================================================================
class TestAnthropicStreamEmitter:
def test_start_emits_message_start_and_content_block_start(self):
e = AnthropicStreamEmitter()
events = e.start("msg_123", "test-model")
assert len(events) == 2
assert "message_start" in events[0]
assert "content_block_start" in events[1]
assert '"type": "text"' in events[1]
def test_content_delta_emits_text_delta(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
events = e.feed({"type": "content", "text": "Hello"})
assert len(events) == 1
parsed = json.loads(events[0].split("data: ")[1])
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_cumulative_content_diffs_correctly(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed({"type": "content", "text": "Hel"})
events = e.feed({"type": "content", "text": "Hello"})
parsed = json.loads(events[0].split("data: ")[1])
assert parsed["delta"]["text"] == "lo"
def test_empty_content_diff_no_event(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed({"type": "content", "text": "Hi"})
events = e.feed({"type": "content", "text": "Hi"})
assert events == []
def test_tool_start_closes_text_opens_tool_block(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed({"type": "content", "text": "Thinking"})
events = e.feed(
{
"type": "tool_start",
"tool_name": "web_search",
"tool_call_id": "tc_1",
"arguments": {"query": "test"},
}
)
# content_block_stop + content_block_start(tool_use) + content_block_delta(input_json)
assert len(events) == 3
assert "content_block_stop" in events[0]
assert "tool_use" in events[1]
assert "input_json_delta" in events[2]
def test_tool_end_closes_tool_opens_new_text_block(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed(
{
"type": "tool_start",
"tool_name": "t",
"tool_call_id": "tc_1",
"arguments": {},
}
)
events = e.feed(
{
"type": "tool_end",
"tool_name": "t",
"tool_call_id": "tc_1",
"result": "done",
}
)
# content_block_stop (tool) + tool_result + content_block_start (new text)
assert len(events) == 3
assert "content_block_stop" in events[0]
assert "tool_result" in events[1]
parsed = json.loads(events[1].split("data: ")[1])
assert parsed["content"] == "done"
assert parsed["tool_use_id"] == "tc_1"
assert "content_block_start" in events[2]
assert '"type": "text"' in events[2]
def test_finish_emits_stop_events(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
events = e.finish("end_turn")
# content_block_stop + message_delta + message_stop
assert len(events) == 3
assert "content_block_stop" in events[0]
assert "message_delta" in events[1]
assert "end_turn" in events[1]
assert "message_stop" in events[2]
def test_metadata_captured_in_finish_usage(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed(
{
"type": "metadata",
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
)
events = e.finish("end_turn")
delta_event = [ev for ev in events if "message_delta" in ev][0]
parsed = json.loads(delta_event.split("data: ")[1])
assert parsed["usage"]["output_tokens"] == 20
def test_status_events_ignored(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
events = e.feed({"type": "status", "text": "Searching..."})
assert events == []
def test_no_tool_calls_simple_text_flow(self):
e = AnthropicStreamEmitter()
start_events = e.start("msg_1", "m")
content_events = e.feed({"type": "content", "text": "Hello world"})
meta_events = e.feed(
{"type": "metadata", "usage": {"prompt_tokens": 5, "completion_tokens": 2}}
)
end_events = e.finish("end_turn")
assert len(start_events) == 2
assert len(content_events) == 1
assert meta_events == []
assert len(end_events) == 3
def test_block_index_increments(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
assert e.block_index == 0
e.feed(
{
"type": "tool_start",
"tool_name": "t",
"tool_call_id": "tc_1",
"arguments": {},
}
)
assert e.block_index == 1
e.feed(
{
"type": "tool_end",
"tool_name": "t",
"tool_call_id": "tc_1",
"result": "ok",
}
)
assert e.block_index == 2
def test_text_after_tool_resets_prev_text(self):
e = AnthropicStreamEmitter()
e.start("msg_1", "m")
e.feed({"type": "content", "text": "Before tool"})
e.feed(
{
"type": "tool_start",
"tool_name": "t",
"tool_call_id": "tc_1",
"arguments": {},
}
)
e.feed(
{
"type": "tool_end",
"tool_name": "t",
"tool_call_id": "tc_1",
"result": "ok",
}
)
# After tool_end, prev_text should be reset
events = e.feed({"type": "content", "text": "After tool"})
parsed = json.loads(events[0].split("data: ")[1])
assert parsed["delta"]["text"] == "After tool"
# =====================================================================
# Pass-through emitter tests (client-side tool execution path)
# =====================================================================
class TestAnthropicPassthroughEmitter:
def _parse(self, event_str):
return json.loads(event_str.split("data: ")[1])
def test_start_emits_message_start_only(self):
e = AnthropicPassthroughEmitter()
events = e.start("msg_1", "test-model")
assert len(events) == 1
assert "message_start" in events[0]
parsed = self._parse(events[0])
assert parsed["message"]["id"] == "msg_1"
assert parsed["message"]["model"] == "test-model"
def test_text_chunk_opens_text_block_and_emits_delta(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
chunk = {"choices": [{"delta": {"content": "Hello"}}]}
events = e.feed_chunk(chunk)
# content_block_start + content_block_delta
assert len(events) == 2
assert "content_block_start" in events[0]
assert '"type": "text"' in events[0]
delta = self._parse(events[1])
assert delta["delta"]["type"] == "text_delta"
assert delta["delta"]["text"] == "Hello"
def test_sequential_text_chunks_single_block(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
events1 = e.feed_chunk({"choices": [{"delta": {"content": "Hello"}}]})
events2 = e.feed_chunk({"choices": [{"delta": {"content": " world"}}]})
# First chunk opens the block, second only emits delta
assert len(events1) == 2
assert len(events2) == 1
assert self._parse(events2[0])["delta"]["text"] == " world"
def test_tool_call_opens_tool_use_block(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
chunk = {
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 0,
"id": "call_1",
"type": "function",
"function": {"name": "Bash", "arguments": ""},
}
]
}
}
]
}
events = e.feed_chunk(chunk)
assert len(events) == 1
parsed = self._parse(events[0])
assert parsed["type"] == "content_block_start"
assert parsed["content_block"]["type"] == "tool_use"
assert parsed["content_block"]["id"] == "call_1"
assert parsed["content_block"]["name"] == "Bash"
def test_tool_call_arguments_streamed_as_input_json_delta(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
# Open the tool call
e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 0,
"id": "c1",
"type": "function",
"function": {"name": "Bash", "arguments": ""},
}
]
}
}
]
}
)
# Stream argument fragments
events1 = e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{"index": 0, "function": {"arguments": '{"cmd'}}
]
}
}
]
}
)
events2 = e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{"index": 0, "function": {"arguments": '": "ls"}'}}
]
}
}
]
}
)
parsed1 = self._parse(events1[0])
parsed2 = self._parse(events2[0])
assert parsed1["delta"]["type"] == "input_json_delta"
assert parsed1["delta"]["partial_json"] == '{"cmd'
assert parsed2["delta"]["partial_json"] == '": "ls"}'
def test_text_then_tool_closes_text_block(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk({"choices": [{"delta": {"content": "Let me check."}}]})
events = e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 0,
"id": "c1",
"type": "function",
"function": {"name": "Bash", "arguments": ""},
}
]
}
}
]
}
)
# Should close text block and open tool_use block
assert "content_block_stop" in events[0]
assert "content_block_start" in events[1]
assert '"type": "tool_use"' in events[1]
def test_finish_reason_tool_calls_sets_tool_use_stop(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 0,
"id": "c1",
"type": "function",
"function": {"name": "Bash", "arguments": "{}"},
}
]
}
}
]
}
)
e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "tool_calls"}]})
events = e.finish()
delta_event = [ev for ev in events if "message_delta" in ev][0]
parsed = self._parse(delta_event)
assert parsed["delta"]["stop_reason"] == "tool_use"
def test_finish_reason_stop_sets_end_turn(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]})
e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "stop"}]})
events = e.finish()
delta_event = [ev for ev in events if "message_delta" in ev][0]
parsed = self._parse(delta_event)
assert parsed["delta"]["stop_reason"] == "end_turn"
def test_finish_reason_length_sets_max_tokens(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]})
e.feed_chunk({"choices": [{"delta": {}, "finish_reason": "length"}]})
events = e.finish()
delta_event = [ev for ev in events if "message_delta" in ev][0]
parsed = self._parse(delta_event)
assert parsed["delta"]["stop_reason"] == "max_tokens"
def test_finish_closes_current_block(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]})
events = e.finish()
assert "content_block_stop" in events[0]
assert "message_delta" in events[1]
assert "message_stop" in events[2]
def test_usage_chunk_captured(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
e.feed_chunk({"choices": [{"delta": {"content": "Hi"}}]})
e.feed_chunk(
{
"choices": [],
"usage": {"prompt_tokens": 10, "completion_tokens": 5},
}
)
events = e.finish()
delta_event = [ev for ev in events if "message_delta" in ev][0]
parsed = self._parse(delta_event)
assert parsed["usage"]["output_tokens"] == 5
def test_empty_chunk_returns_no_events(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
events = e.feed_chunk({"choices": []})
assert events == []
def test_no_blocks_at_all_still_produces_valid_finish(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
events = e.finish()
# No content_block_stop because no block was opened
assert not any("content_block_stop" in ev for ev in events)
assert any("message_delta" in ev for ev in events)
assert any("message_stop" in ev for ev in events)
def test_multiple_tool_calls_distinct_blocks(self):
e = AnthropicPassthroughEmitter()
e.start("msg_1", "m")
# First tool call
e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 0,
"id": "c1",
"type": "function",
"function": {"name": "Bash", "arguments": "{}"},
}
]
}
}
]
}
)
# Second tool call (different index)
events = e.feed_chunk(
{
"choices": [
{
"delta": {
"tool_calls": [
{
"index": 1,
"id": "c2",
"type": "function",
"function": {"name": "Read", "arguments": "{}"},
}
]
}
}
]
}
)
# Should close block 0, open block 1
assert "content_block_stop" in events[0]
assert "content_block_start" in events[1]
parsed = self._parse(events[1])
assert parsed["content_block"]["name"] == "Read"
assert parsed["content_block"]["id"] == "c2"

View file

@ -0,0 +1,86 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import os
import sys
import types
from pathlib import Path
import pytest
from fastapi import HTTPException
# Keep this test runnable in lightweight environments where optional logging
# deps are not installed.
if "structlog" not in sys.modules:
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
sys.modules["structlog"] = types.SimpleNamespace(
BoundLogger = _DummyLogger,
get_logger = lambda *args, **kwargs: _DummyLogger(),
)
import routes.models as models_route
def test_resolve_browse_target_returns_allowed_directory(tmp_path):
allowed = tmp_path / "allowed"
target = allowed / "models" / "nested"
target.mkdir(parents = True)
resolved = models_route._resolve_browse_target(str(target), [allowed])
assert resolved == target.resolve()
def test_resolve_browse_target_rejects_outside_allowlist(tmp_path):
allowed = tmp_path / "allowed"
disallowed = tmp_path / "disallowed"
allowed.mkdir()
disallowed.mkdir()
with pytest.raises(HTTPException) as exc_info:
models_route._resolve_browse_target(str(disallowed), [allowed])
assert exc_info.value.status_code == 403
def test_resolve_browse_target_rejects_file_path(tmp_path):
allowed = tmp_path / "allowed"
allowed.mkdir()
model_file = allowed / "model.gguf"
model_file.write_text("gguf")
with pytest.raises(HTTPException) as exc_info:
models_route._resolve_browse_target(str(model_file), [allowed])
assert exc_info.value.status_code == 400
def test_resolve_browse_target_allows_symlink_into_other_allowed_root(tmp_path):
home_root = tmp_path / "home"
scan_root = tmp_path / "scan"
target = scan_root / "nested"
home_root.mkdir()
target.mkdir(parents = True)
(home_root / "scan-link").symlink_to(scan_root, target_is_directory = True)
resolved = models_route._resolve_browse_target(
str(home_root / "scan-link" / "nested"),
[home_root, scan_root],
)
assert resolved == target.resolve()
@pytest.mark.skipif(os.altsep is not None, reason = "POSIX-only path semantics")
def test_resolve_browse_target_allows_backslash_in_posix_segment(tmp_path):
allowed = tmp_path / "allowed"
target = allowed / r"dir\name"
target.mkdir(parents = True)
resolved = models_route._resolve_browse_target(str(target), [allowed])
assert resolved == target.resolve()

View file

@ -0,0 +1,120 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
from pathlib import Path
import sys
import types
# Keep this test runnable in lightweight environments where optional logging
# deps are not installed.
if "structlog" not in sys.modules:
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
sys.modules["structlog"] = types.SimpleNamespace(
BoundLogger = _DummyLogger,
get_logger = lambda *args, **kwargs: _DummyLogger(),
)
from utils.paths.path_utils import (
resolve_cached_repo_id_case,
get_cache_case_resolution_stats,
reset_cache_case_resolution_state,
)
import utils.paths.path_utils as path_utils
def _mk_cache_repo(cache_root: Path, repo_id: str) -> Path:
repo_dir = cache_root / f"models--{repo_id.replace('/', '--')}"
repo_dir.mkdir(parents = True, exist_ok = True)
return repo_dir
def test_resolve_cached_repo_id_case_exact_hit(tmp_path, monkeypatch):
reset_cache_case_resolution_state()
_mk_cache_repo(tmp_path, "Org/Model")
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
resolved = resolve_cached_repo_id_case("Org/Model")
assert resolved == "Org/Model"
stats = get_cache_case_resolution_stats()
assert stats["calls"] == 1
assert stats["exact_hits"] == 1
assert stats["variant_hits"] == 0
def test_resolve_cached_repo_id_case_variant_hit(tmp_path, monkeypatch):
reset_cache_case_resolution_state()
_mk_cache_repo(tmp_path, "Org/Model")
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
resolved = resolve_cached_repo_id_case("org/model")
assert resolved == "Org/Model"
stats = get_cache_case_resolution_stats()
assert stats["variant_hits"] == 1
assert stats["tie_breaks"] == 0
def test_resolve_cached_repo_id_case_tie_break_deterministic(tmp_path, monkeypatch):
reset_cache_case_resolution_state()
_mk_cache_repo(tmp_path, "Org/Model")
_mk_cache_repo(tmp_path, "org/model")
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
resolved = resolve_cached_repo_id_case("oRg/mOdEl")
# Deterministic rule: lexical sort of candidate repo ids.
assert resolved == "Org/Model"
stats = get_cache_case_resolution_stats()
assert stats["variant_hits"] == 1
assert stats["tie_breaks"] == 1
def test_resolve_cached_repo_id_case_no_cache_fallback(tmp_path, monkeypatch):
reset_cache_case_resolution_state()
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
resolved = resolve_cached_repo_id_case("Org/Missing")
assert resolved == "Org/Missing"
stats = get_cache_case_resolution_stats()
assert stats["fallbacks"] == 1
assert stats["variant_hits"] == 0
assert stats["exact_hits"] == 0
def test_resolve_cached_repo_id_case_memoization(tmp_path, monkeypatch):
reset_cache_case_resolution_state()
_mk_cache_repo(tmp_path, "Org/Model")
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
first = resolve_cached_repo_id_case("org/model")
second = resolve_cached_repo_id_case("org/model")
assert first == "Org/Model"
assert second == "Org/Model"
stats = get_cache_case_resolution_stats()
assert stats["calls"] == 2
assert stats["variant_hits"] == 1
assert stats["memo_hits"] == 1
def test_resolve_cached_repo_id_case_late_cache_population(tmp_path, monkeypatch):
"""Regression guard: memoized fallback should not hide a later cache variant."""
reset_cache_case_resolution_state()
monkeypatch.setattr(path_utils, "_hf_hub_cache_dir", lambda: tmp_path)
first = resolve_cached_repo_id_case("org/model")
assert first == "org/model"
# Simulate cache being populated after first miss (e.g. another code path/download).
_mk_cache_repo(tmp_path, "Org/Model")
second = resolve_cached_repo_id_case("org/model")
# Desired behavior: second lookup should pick up the now-existing variant.
assert second == "Org/Model"

View file

@ -0,0 +1,398 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import asyncio
import sys
import types
from pathlib import Path
from types import SimpleNamespace
# Keep this test runnable in lightweight environments where optional logging
# deps are not installed.
if "structlog" not in sys.modules:
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
sys.modules["structlog"] = types.SimpleNamespace(
BoundLogger = _DummyLogger,
get_logger = lambda *args, **kwargs: _DummyLogger(),
)
import routes.models as models_route
def _repo(
repo_id: str,
files: list[SimpleNamespace],
repo_path: Path,
*,
revisions: list[SimpleNamespace] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
repo_id = repo_id,
repo_type = "model",
repo_path = repo_path,
revisions = revisions or [SimpleNamespace(files = files)],
)
def _file(
name: str,
size_on_disk: int,
*,
blob_path: str | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
file_name = name,
size_on_disk = size_on_disk,
blob_path = blob_path,
)
def test_iter_gguf_paths_matches_extension_case_insensitively(tmp_path):
nested = tmp_path / "snapshots" / "rev"
nested.mkdir(parents = True)
lower = nested / "Q4_K_M.gguf"
upper = nested / "Q8_0.GGUF"
other = nested / "README.md"
lower.write_text("a")
upper.write_text("b")
other.write_text("c")
result = sorted(path.name for path in models_route._iter_gguf_paths(tmp_path))
assert result == ["Q4_K_M.gguf", "Q8_0.GGUF"]
def test_list_cached_gguf_includes_non_suffix_repo_when_cache_contains_gguf(
monkeypatch, tmp_path
):
repo = _repo(
"HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive",
[_file("Q4_K_M.gguf", 5_000), _file("README.md", 10)],
tmp_path / "models--HauhauCS--Gemma",
)
scan = SimpleNamespace(repos = [repo])
monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan])
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive",
"size_bytes": 5_000,
"cache_path": str(repo.repo_path),
}
]
def test_list_cached_gguf_matches_extension_case_insensitively(monkeypatch, tmp_path):
repo = _repo(
"Org/Model-Without-Suffix",
[_file("Q8_0.GGUF", 7_000)],
tmp_path / "models--Org--Model-Without-Suffix",
)
scan = SimpleNamespace(repos = [repo])
monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan])
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/Model-Without-Suffix",
"size_bytes": 7_000,
"cache_path": str(repo.repo_path),
}
]
def test_list_cached_gguf_skips_repos_without_positive_gguf_size(monkeypatch, tmp_path):
missing = _repo(
"Org/ReadmeOnly",
[_file("README.md", 10)],
tmp_path / "models--Org--ReadmeOnly",
)
zero = _repo(
"Org/ZeroSize",
[_file("Q4_K_M.gguf", 0)],
tmp_path / "models--Org--ZeroSize",
)
scan = SimpleNamespace(repos = [missing, zero])
monkeypatch.setattr(models_route, "_all_hf_cache_scans", lambda: [scan])
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == []
def test_list_cached_gguf_keeps_largest_duplicate_repo_across_scans(
monkeypatch, tmp_path
):
smaller = _repo(
"Org/Dupe",
[_file("Q4_K_M.gguf", 2_000)],
tmp_path / "models--Org--Dupe-a",
)
larger = _repo(
"org/dupe",
[_file("Q4_K_M.gguf", 5_000), _file("Q6_K.gguf", 1_000)],
tmp_path / "models--Org--Dupe-b",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [
SimpleNamespace(repos = [smaller]),
SimpleNamespace(repos = [larger]),
],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "org/dupe",
"size_bytes": 6_000,
"cache_path": str(larger.repo_path),
}
]
def test_list_cached_gguf_dedupes_shared_blobs_across_revisions(monkeypatch, tmp_path):
shared = "blobs/shared-q4"
repo = _repo(
"Org/SharedBlobRepo",
[],
tmp_path / "models--Org--SharedBlobRepo",
revisions = [
SimpleNamespace(files = [_file("Q4_K_M.gguf", 5_000, blob_path = shared)]),
SimpleNamespace(files = [_file("Q4_K_M.gguf", 5_000, blob_path = shared)]),
],
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [repo])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/SharedBlobRepo",
"size_bytes": 5_000,
"cache_path": str(repo.repo_path),
}
]
def test_list_cached_models_skips_non_suffix_repo_when_gguf_files_exist(
monkeypatch, tmp_path
):
mixed = _repo(
"Org/MixedRepo",
[
_file("Q4_K_M.gguf", 5_000),
_file("model.safetensors", 10_000),
],
tmp_path / "models--Org--MixedRepo",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [mixed])],
)
result = asyncio.run(models_route.list_cached_models(current_subject = "test-user"))
assert result["cached"] == []
def test_list_cached_gguf_includes_mixed_repo_with_gguf_and_safetensors(
monkeypatch, tmp_path
):
"""Mirror of the _skips_ test: the mixed repo should still surface in
cached-gguf so the picker can show it as a GGUF download."""
mixed = _repo(
"Org/MixedRepo",
[
_file("Q4_K_M.gguf", 5_000),
_file("model.safetensors", 10_000),
],
tmp_path / "models--Org--MixedRepo",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [mixed])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/MixedRepo",
"size_bytes": 5_000,
"cache_path": str(mixed.repo_path),
}
]
def test_list_cached_gguf_handles_none_size_on_disk(monkeypatch, tmp_path):
"""A partial/interrupted GGUF download has ``size_on_disk = None``. The
route must treat the unknown bytes as zero instead of raising TypeError
out of ``sum()`` and wiping the entire response."""
partial = _repo(
"Org/PartialDownload",
[_file("Q4_K_M.gguf", None), _file("Q6_K.gguf", 5_000)],
tmp_path / "models--Org--PartialDownload",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [partial])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/PartialDownload",
"size_bytes": 5_000,
"cache_path": str(partial.repo_path),
}
]
def test_list_cached_gguf_skips_malformed_repo_without_wiping_response(
monkeypatch, tmp_path
):
"""One repo raising during classification must not poison the response
for every other repo in the scan."""
class _ExplodingRepo:
repo_id = "Org/Broken"
repo_type = "model"
repo_path = tmp_path / "models--Org--Broken"
@property
def revisions(self):
raise RuntimeError("boom")
healthy = _repo(
"Org/Healthy",
[_file("Q4_K_M.gguf", 5_000)],
tmp_path / "models--Org--Healthy",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [_ExplodingRepo(), healthy])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/Healthy",
"size_bytes": 5_000,
"cache_path": str(healthy.repo_path),
}
]
def test_list_cached_gguf_skips_repo_with_only_mmproj_gguf(monkeypatch, tmp_path):
"""A repo whose only ``.gguf`` artifact is an mmproj vision adapter
must not be classified as a GGUF repo: the variant selector filters
mmproj out and the picker would otherwise show zero variants."""
mmproj_only = _repo(
"Org/MmprojOnly",
[
_file("mmproj-Q8_0.gguf", 5_000),
_file("model.safetensors", 10_000),
],
tmp_path / "models--Org--MmprojOnly",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [mmproj_only])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == []
def test_list_cached_models_includes_repo_with_only_mmproj_gguf(monkeypatch, tmp_path):
"""Mirror of the cached-gguf skip: a safetensors repo with an
auxiliary mmproj vision adapter must still surface in cached-models
so the user can load it as a normal model."""
mmproj_aux = _repo(
"Org/MmprojAux",
[
_file("mmproj-Q8_0.gguf", 5_000),
_file("model.safetensors", 10_000),
],
tmp_path / "models--Org--MmprojAux",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [mmproj_aux])],
)
result = asyncio.run(models_route.list_cached_models(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/MmprojAux",
"size_bytes": 15_000,
}
]
def test_list_cached_gguf_includes_vision_repo_with_main_gguf_and_mmproj(
monkeypatch, tmp_path
):
"""A vision-capable GGUF repo (main weight + mmproj adapter) is still
a GGUF repo. The reported size is the main weight size; mmproj is
excluded from the GGUF-size accounting because it is filtered out at
classification time."""
vision_repo = _repo(
"Org/VisionGguf",
[
_file("Q4_K_M.gguf", 5_000),
_file("mmproj-Q8_0.gguf", 1_000),
],
tmp_path / "models--Org--VisionGguf",
)
monkeypatch.setattr(
models_route,
"_all_hf_cache_scans",
lambda: [SimpleNamespace(repos = [vision_repo])],
)
result = asyncio.run(models_route.list_cached_gguf(current_subject = "test-user"))
assert result["cached"] == [
{
"repo_id": "Org/VisionGguf",
"size_bytes": 5_000,
"cache_path": str(vision_repo.repo_path),
}
]

View file

@ -0,0 +1,179 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Regression tests for the export log ring-buffer cursor semantics.
Context: the live export log SSE stream has a race where the frontend
opens the SSE connection AFTER the POST that starts the export. Any
lines the worker subprocess emits during the gap between POST and SSE
connect get buffered with seqs 1..k, and then the SSE default cursor
`get_current_log_seq()` returns k -- so lines 1..k are forever
unreachable to that client.
Fix: `clear_logs()` snapshots the pre-run seq into `_run_start_seq`
(exposed via `get_run_start_seq()`), and `routes/export.py` defaults
the SSE cursor to that snapshot instead of the current seq. Every line
appended during the current run has seq strictly greater than the
snapshot, so the client sees the full run regardless of when it
connects.
These tests exercise the orchestrator-side contract only (no
subprocess, no FastAPI, no frontend). The routes-level integration
with get_run_start_seq() is a one-line edit covered by manual testing
and the frontend build.
"""
from __future__ import annotations
import sys
import types
from pathlib import Path
import pytest
# Backend root on sys.path so `from core.export.orchestrator import ...`
# and friends resolve without the studio app bootstrap.
_BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(_BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(_BACKEND_DIR))
# ExportOrchestrator imports structlog and a few heavy modules at the
# top of orchestrator.py. Stub the ones we don't need in these unit
# tests so the import succeeds on machines without the full studio
# venv.
_loggers_stub = types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog is only used for a module-level import; a bare stub is
# enough because we never call into it in these tests.
sys.modules.setdefault("structlog", types.ModuleType("structlog"))
# utils.paths.outputs_root is only called inside scan_checkpoints which
# we don't hit in these tests. Provide a stub module so the top-level
# import in orchestrator.py resolves.
_utils_pkg = types.ModuleType("utils")
_utils_pkg.__path__ = [] # mark as package
_utils_paths_stub = types.ModuleType("utils.paths")
_utils_paths_stub.outputs_root = lambda: Path("/tmp")
sys.modules.setdefault("utils", _utils_pkg)
sys.modules.setdefault("utils.paths", _utils_paths_stub)
@pytest.fixture
def orchestrator():
"""Fresh ExportOrchestrator with only the log-buffer state exercised."""
from core.export.orchestrator import ExportOrchestrator
return ExportOrchestrator()
def _append(orch, line: str, stream: str = "stdout") -> None:
"""Shortcut for simulating a worker log message."""
orch._append_log({"type": "log", "stream": stream, "line": line, "ts": 0.0})
# ---------------------------------------------------------------------------
# clear_logs() semantics
# ---------------------------------------------------------------------------
def test_run_start_seq_is_zero_before_any_logs(orchestrator) -> None:
"""A brand-new orchestrator must report run_start_seq == 0 so a
first SSE connection picks up every line from seq 1 onward."""
assert orchestrator.get_run_start_seq() == 0
def test_clear_logs_snapshots_current_seq(orchestrator) -> None:
"""clear_logs() must capture _log_seq BEFORE clearing the buffer,
so subsequent runs can anchor their SSE cursor at the snapshot."""
_append(orchestrator, "old run line 1")
_append(orchestrator, "old run line 2")
_append(orchestrator, "old run line 3")
assert orchestrator.get_current_log_seq() == 3
orchestrator.clear_logs()
assert orchestrator.get_run_start_seq() == 3
assert orchestrator.get_current_log_seq() == 3 # seq counter preserved
# ---------------------------------------------------------------------------
# Race regression: SSE connects AFTER lines have been emitted
# ---------------------------------------------------------------------------
def test_sse_default_cursor_catches_all_current_run_lines(orchestrator) -> None:
"""Simulate the POST-then-SSE race: worker starts emitting lines
immediately after clear_logs(), SSE connects several lines later.
Using get_run_start_seq() as the default cursor MUST return every
line emitted since clear_logs() ran.
Pre-fix, the SSE defaulted to get_current_log_seq() at connect
time, which would return the last-seen seq and miss lines N+1..M.
"""
# Previous run leaves some buffered lines.
_append(orchestrator, "previous run line A")
_append(orchestrator, "previous run line B")
# New run starts: orchestrator clears the buffer and snapshots seq.
orchestrator.clear_logs()
run_start = orchestrator.get_run_start_seq()
# Worker emits early lines BEFORE the SSE connects.
_append(orchestrator, "Importing Unsloth...")
_append(orchestrator, "Loading checkpoint: /foo/bar")
_append(orchestrator, "Starting export...")
# SSE connects now and asks "give me everything after the run
# start cursor".
entries, new_cursor = orchestrator.get_logs_since(run_start)
# All three early lines must be present. Pre-fix this was [].
lines = [e["line"] for e in entries]
assert lines == [
"Importing Unsloth...",
"Loading checkpoint: /foo/bar",
"Starting export...",
]
assert new_cursor == entries[-1]["seq"]
def test_sse_default_cursor_excludes_previous_run(orchestrator) -> None:
"""After clear_logs(), lines from the PREVIOUS run must not leak
into the new run's SSE stream. Pre-fix this worked correctly
(clear_logs cleared the deque); the fix must preserve it.
"""
_append(orchestrator, "previous run line 1")
_append(orchestrator, "previous run line 2")
_append(orchestrator, "previous run line 3")
assert orchestrator.get_current_log_seq() == 3
orchestrator.clear_logs()
run_start = orchestrator.get_run_start_seq()
_append(orchestrator, "new run line")
entries, _ = orchestrator.get_logs_since(run_start)
assert [e["line"] for e in entries] == ["new run line"]
def test_clear_logs_twice_advances_run_start(orchestrator) -> None:
"""Back-to-back clear_logs() calls (e.g. cleanup -> load ->
export in the same dialog session) must each re-anchor run_start
at the current seq, so successive runs each start with a fresh
low-water mark."""
_append(orchestrator, "run 1 line a")
_append(orchestrator, "run 1 line b")
orchestrator.clear_logs()
assert orchestrator.get_run_start_seq() == 2
_append(orchestrator, "run 2 line a")
_append(orchestrator, "run 2 line b")
_append(orchestrator, "run 2 line c")
orchestrator.clear_logs()
assert orchestrator.get_run_start_seq() == 5

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,544 @@
#!/usr/bin/env python3
"""
Sandbox test for multi-GPU selection logic.
Tests the core GPU selection, memory estimation, and device_map logic
in an isolated environment. Can be run on Linux, macOS, and Windows
without requiring actual GPUs -- all hardware calls are mocked.
Usage:
python -m pytest studio/backend/tests/test_gpu_selection_sandbox.py -v
# or directly:
python studio/backend/tests/test_gpu_selection_sandbox.py
"""
import os
import sys
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock
# Ensure backend is on sys.path
_backend_root = Path(__file__).resolve().parent.parent
if str(_backend_root) not in sys.path:
sys.path.insert(0, str(_backend_root))
def _make_fake_config(
vocab_size = 32000,
hidden_size = 4096,
intermediate_size = 11008,
num_hidden_layers = 32,
num_attention_heads = 32,
num_key_value_heads = 8,
tie_word_embeddings = False,
):
"""Create a fake HF config-like object for estimation tests."""
from types import SimpleNamespace
return SimpleNamespace(
vocab_size = vocab_size,
hidden_size = hidden_size,
intermediate_size = intermediate_size,
num_hidden_layers = num_hidden_layers,
num_attention_heads = num_attention_heads,
num_key_value_heads = num_key_value_heads,
tie_word_embeddings = tie_word_embeddings,
)
class TestEstimateFP16ModelSizeFromConfig(unittest.TestCase):
"""Test the config-based model size estimation."""
def test_llama_8b_size_reasonable(self):
from utils.hardware.hardware import _estimate_fp16_model_size_bytes_from_config
config = _make_fake_config(
vocab_size = 128256,
hidden_size = 4096,
intermediate_size = 14336,
num_hidden_layers = 32,
num_attention_heads = 32,
num_key_value_heads = 8,
tie_word_embeddings = False,
)
size = _estimate_fp16_model_size_bytes_from_config(config)
self.assertIsNotNone(size)
size_gb = size / (1024**3)
# Llama 3.1 8B should be ~15GB in fp16
self.assertGreater(size_gb, 12)
self.assertLess(size_gb, 20)
def test_small_model(self):
from utils.hardware.hardware import _estimate_fp16_model_size_bytes_from_config
config = _make_fake_config(
vocab_size = 32000,
hidden_size = 2048,
intermediate_size = 5504,
num_hidden_layers = 22,
num_attention_heads = 32,
num_key_value_heads = 4,
)
size = _estimate_fp16_model_size_bytes_from_config(config)
self.assertIsNotNone(size)
size_gb = size / (1024**3)
# ~1B model should be ~2GB in fp16
self.assertGreater(size_gb, 1)
self.assertLess(size_gb, 5)
def test_returns_none_for_incomplete_config(self):
from utils.hardware.hardware import _estimate_fp16_model_size_bytes_from_config
from types import SimpleNamespace
config = SimpleNamespace(vocab_size = 32000) # Missing most fields
size = _estimate_fp16_model_size_bytes_from_config(config)
self.assertIsNone(size)
def test_moe_model(self):
from utils.hardware.hardware import _estimate_fp16_model_size_bytes_from_config
from types import SimpleNamespace
config = SimpleNamespace(
vocab_size = 152064,
hidden_size = 3584,
intermediate_size = 18944,
num_hidden_layers = 28,
num_attention_heads = 28,
num_key_value_heads = 4,
tie_word_embeddings = False,
num_local_experts = 64,
moe_intermediate_size = 2560,
)
size = _estimate_fp16_model_size_bytes_from_config(config)
self.assertIsNotNone(size)
size_gb = size / (1024**3)
# MoE model with 64 experts should be large
self.assertGreater(size_gb, 50)
class TestEstimateRequiredModelMemory(unittest.TestCase):
"""Test memory requirement estimation."""
def test_inference_fp16_uses_1_3x(self):
from utils.hardware.hardware import estimate_required_model_memory_gb
with patch(
"utils.hardware.hardware.estimate_fp16_model_size_bytes",
return_value = (10 * (1024**3), "config"), # 10GB model
):
required, meta = estimate_required_model_memory_gb(
"test/model",
training_type = None, # inference
load_in_4bit = False,
)
self.assertIsNotNone(required)
self.assertAlmostEqual(required, 13.0, places = 0)
self.assertEqual(meta["mode"], "inference")
def test_inference_4bit_uses_reduced_estimate(self):
from utils.hardware.hardware import estimate_required_model_memory_gb
with patch(
"utils.hardware.hardware.estimate_fp16_model_size_bytes",
return_value = (30 * (1024**3), "config"), # 30GB fp16 model
):
required, meta = estimate_required_model_memory_gb(
"test/model",
training_type = None, # inference
load_in_4bit = True,
)
self.assertIsNotNone(required)
# 4bit base = 30/3.2 = 9.375GB, required = 9.375 + max(9.375*0.3, 2) = 12.19GB
self.assertAlmostEqual(required, 12.2, places = 0)
def test_4bit_training_reduces_base(self):
from utils.hardware.hardware import estimate_required_model_memory_gb
with patch(
"utils.hardware.hardware.estimate_fp16_model_size_bytes",
return_value = (30 * (1024**3), "config"), # 30GB fp16 model
):
required, meta = estimate_required_model_memory_gb(
"test/model",
training_type = "LoRA/QLoRA",
load_in_4bit = True,
)
self.assertIsNotNone(required)
# fallback: base=30/3.2=9.375, lora=30*0.04=1.2, act=30*0.15=4.5, cuda=1.4
self.assertAlmostEqual(required, 16.5, places = 0)
def test_full_finetune_uses_3_5x(self):
from utils.hardware.hardware import estimate_required_model_memory_gb
with patch(
"utils.hardware.hardware.estimate_fp16_model_size_bytes",
return_value = (10 * (1024**3), "config"), # 10GB model
):
required, meta = estimate_required_model_memory_gb(
"test/model",
training_type = "Full Finetuning",
)
self.assertIsNotNone(required)
# fallback: 10 * 3.5 + 1.4 cuda overhead = 36.4
self.assertAlmostEqual(required, 36.4, places = 0)
def test_returns_none_when_unavailable(self):
from utils.hardware.hardware import estimate_required_model_memory_gb
with patch(
"utils.hardware.hardware.estimate_fp16_model_size_bytes",
return_value = (None, "unavailable"),
):
required, meta = estimate_required_model_memory_gb("test/model")
self.assertIsNone(required)
class TestAutoSelectGpuIds(unittest.TestCase):
"""Test automatic GPU selection based on model size and free memory."""
def _make_utilization(self, devices):
"""Create a fake utilization response."""
return {
"available": True,
"devices": [
{
"index": idx,
"vram_total_gb": total,
"vram_used_gb": total - free,
}
for idx, total, free in devices
],
}
def test_single_gpu_sufficient(self):
from utils.hardware.hardware import auto_select_gpu_ids
import utils.hardware.hardware as hw
with (
patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA),
patch.object(
hw,
"estimate_required_model_memory_gb",
return_value = (
10.0,
{
"mode": "inference",
"required_gb": 10.0,
"model_size_source": "config",
"model_size_gb": 7.7,
},
),
),
patch.object(
hw,
"_get_parent_visible_gpu_spec",
return_value = {
"raw": "0,1,2,3",
"numeric_ids": [0, 1, 2, 3],
"supports_explicit_gpu_ids": True,
},
),
patch.object(hw, "get_parent_visible_gpu_ids", return_value = [0, 1, 2, 3]),
patch.object(
hw,
"get_visible_gpu_utilization",
return_value = self._make_utilization(
[
(0, 80.0, 75.0),
(1, 80.0, 78.0),
(2, 80.0, 70.0),
(3, 80.0, 72.0),
]
),
),
):
selected, meta = auto_select_gpu_ids("test/model")
# Should pick GPU 1 (most free memory: 78GB) -- enough for 10GB
self.assertEqual(len(selected), 1)
self.assertEqual(selected[0], 1)
def test_two_gpus_needed(self):
from utils.hardware.hardware import auto_select_gpu_ids
import utils.hardware.hardware as hw
with (
patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA),
patch.object(
hw,
"estimate_required_model_memory_gb",
return_value = (
50.0,
{
"mode": "inference",
"required_gb": 50.0,
"model_size_source": "config",
"model_size_gb": 38.0,
},
),
),
patch.object(
hw,
"_get_parent_visible_gpu_spec",
return_value = {
"raw": "0,1",
"numeric_ids": [0, 1],
"supports_explicit_gpu_ids": True,
},
),
patch.object(hw, "get_parent_visible_gpu_ids", return_value = [0, 1]),
patch.object(
hw,
"get_visible_gpu_utilization",
return_value = self._make_utilization(
[
(0, 40.0, 30.0), # 30GB free
(1, 40.0, 35.0), # 35GB free
]
),
),
):
selected, meta = auto_select_gpu_ids("test/model")
# 35GB (first) + 30*0.85 (second) = 60.5GB > 50GB
self.assertEqual(len(selected), 2)
def test_non_cuda_returns_none(self):
from utils.hardware.hardware import auto_select_gpu_ids
import utils.hardware.hardware as hw
with patch.object(hw, "get_device", return_value = hw.DeviceType.CPU):
selected, meta = auto_select_gpu_ids("test/model")
self.assertIsNone(selected)
self.assertEqual(meta["selection_mode"], "non_cuda")
class TestGetDeviceMap(unittest.TestCase):
"""Test device_map string generation."""
def test_single_gpu_returns_sequential(self):
from utils.hardware.hardware import get_device_map
import utils.hardware.hardware as hw
with (
patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA),
patch.object(
hw,
"_get_parent_visible_gpu_spec",
return_value = {
"raw": "0",
"numeric_ids": [0],
"supports_explicit_gpu_ids": True,
},
),
patch.object(hw, "get_visible_gpu_count", return_value = 1),
):
dm = get_device_map(gpu_ids = [0])
self.assertEqual(dm, "sequential")
def test_multi_gpu_returns_balanced(self):
from utils.hardware.hardware import get_device_map
import utils.hardware.hardware as hw
with patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA):
dm = get_device_map(gpu_ids = [0, 1])
self.assertEqual(dm, "balanced")
def test_cpu_returns_sequential(self):
from utils.hardware.hardware import get_device_map
import utils.hardware.hardware as hw
with patch.object(hw, "get_device", return_value = hw.DeviceType.CPU):
dm = get_device_map(gpu_ids = None)
self.assertEqual(dm, "sequential")
class TestResolveRequestedGpuIds(unittest.TestCase):
"""Test GPU ID validation."""
def test_none_returns_parent_visible(self):
from utils.hardware.hardware import resolve_requested_gpu_ids
with (
patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "2,3"}, clear = False),
patch("utils.hardware.hardware.get_physical_gpu_count", return_value = 8),
):
result = resolve_requested_gpu_ids(None)
self.assertEqual(result, [2, 3])
def test_empty_list_returns_parent_visible(self):
from utils.hardware.hardware import resolve_requested_gpu_ids
with (
patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "2,3"}, clear = False),
patch("utils.hardware.hardware.get_physical_gpu_count", return_value = 8),
):
result = resolve_requested_gpu_ids([])
self.assertEqual(result, [2, 3])
def test_duplicates_rejected(self):
from utils.hardware.hardware import resolve_requested_gpu_ids
with (
patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"}, clear = False),
patch("utils.hardware.hardware.get_physical_gpu_count", return_value = 8),
):
with self.assertRaises(ValueError):
resolve_requested_gpu_ids([1, 1])
def test_out_of_range_rejected(self):
from utils.hardware.hardware import resolve_requested_gpu_ids
with (
patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}, clear = False),
patch("utils.hardware.hardware.get_physical_gpu_count", return_value = 4),
):
with self.assertRaises(ValueError):
resolve_requested_gpu_ids([5])
def test_uuid_env_var_rejects_explicit_ids(self):
from utils.hardware.hardware import resolve_requested_gpu_ids
with (
patch.dict(
os.environ, {"CUDA_VISIBLE_DEVICES": "GPU-abc,GPU-def"}, clear = False
),
patch("utils.hardware.hardware.get_physical_gpu_count", return_value = 8),
):
with self.assertRaises(ValueError):
resolve_requested_gpu_ids([0])
class TestApplyGpuIds(unittest.TestCase):
"""Test CUDA_VISIBLE_DEVICES environment variable setting."""
def test_apply_list(self):
from utils.hardware.hardware import apply_gpu_ids
with patch.dict(os.environ, {}, clear = False):
apply_gpu_ids([3, 5])
self.assertEqual(os.environ.get("CUDA_VISIBLE_DEVICES"), "3,5")
def test_apply_none_does_nothing(self):
from utils.hardware.hardware import apply_gpu_ids
original = os.environ.get("CUDA_VISIBLE_DEVICES")
apply_gpu_ids(None)
self.assertEqual(os.environ.get("CUDA_VISIBLE_DEVICES"), original)
class TestMultiGpuOverheadAccounting(unittest.TestCase):
"""Test that multi-GPU overhead is applied correctly.
The first GPU should keep its full free memory, and only
additional GPUs should have the overhead factor applied.
"""
def _make_utilization(self, devices):
return {
"available": True,
"devices": [
{
"index": idx,
"vram_total_gb": total,
"vram_used_gb": total - free,
}
for idx, total, free in devices
],
}
def test_first_gpu_not_penalized(self):
"""A model that just fits on 1 GPU should not require 2 GPUs."""
from utils.hardware.hardware import auto_select_gpu_ids
import utils.hardware.hardware as hw
# Model requires 79GB, GPU has 80GB free
with (
patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA),
patch.object(
hw,
"estimate_required_model_memory_gb",
return_value = (
79.0,
{
"mode": "inference",
"required_gb": 79.0,
"model_size_source": "config",
"model_size_gb": 60.0,
},
),
),
patch.object(
hw,
"_get_parent_visible_gpu_spec",
return_value = {
"raw": "0,1",
"numeric_ids": [0, 1],
"supports_explicit_gpu_ids": True,
},
),
patch.object(hw, "get_parent_visible_gpu_ids", return_value = [0, 1]),
patch.object(
hw,
"get_visible_gpu_utilization",
return_value = self._make_utilization(
[
(0, 80.0, 80.0),
(1, 80.0, 80.0),
]
),
),
):
selected, meta = auto_select_gpu_ids("test/model")
# Should fit on 1 GPU (80GB >= 79GB)
self.assertEqual(len(selected), 1)
def test_second_gpu_has_overhead(self):
"""When 2 GPUs are needed, the second one's contribution is reduced."""
from utils.hardware.hardware import auto_select_gpu_ids
import utils.hardware.hardware as hw
# Model requires 110GB. First GPU has 80GB, second has 40GB.
# With overhead: 80 + 40*0.85 = 114GB -- just enough
with (
patch.object(hw, "get_device", return_value = hw.DeviceType.CUDA),
patch.object(
hw,
"estimate_required_model_memory_gb",
return_value = (
110.0,
{
"mode": "inference",
"required_gb": 110.0,
"model_size_source": "config",
"model_size_gb": 85.0,
},
),
),
patch.object(
hw,
"_get_parent_visible_gpu_spec",
return_value = {
"raw": "0,1",
"numeric_ids": [0, 1],
"supports_explicit_gpu_ids": True,
},
),
patch.object(hw, "get_parent_visible_gpu_ids", return_value = [0, 1]),
patch.object(
hw,
"get_visible_gpu_utilization",
return_value = self._make_utilization(
[
(0, 80.0, 80.0),
(1, 80.0, 40.0),
]
),
),
):
selected, meta = auto_select_gpu_ids("test/model")
# Should use both GPUs
self.assertEqual(len(selected), 2)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,929 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""Tests for 5-path architecture-aware KV cache VRAM estimation.
Covers the GGUF metadata parser, _can_estimate_kv gate, all 5 estimation
paths (MLA, Hybrid Mamba, Sliding Window, Standard GQA, Legacy), KV cache
quantization, edge cases, and lifecycle (init/unload/reparse).
Requires no GPU, network, or external libraries beyond pytest.
Cross-platform: Linux, macOS, Windows, WSL.
"""
import io
import struct
import sys
import types as _types
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_native_context_length.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# loggers
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
# httpx
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_gguf_bytes(arch: str, kv_pairs: dict) -> bytes:
"""Build a minimal GGUF v3 binary blob with the given KV metadata.
Only supports UINT32 (type 4), UINT64 (type 10), and STRING (type 8)
values, which is all the metadata parser reads.
"""
buf = io.BytesIO()
# Header: magic, version, tensor_count, kv_count
buf.write(struct.pack("<I", 0x46554747)) # GGUF magic
buf.write(struct.pack("<I", 3)) # version 3
buf.write(struct.pack("<Q", 0)) # tensor_count
buf.write(struct.pack("<Q", len(kv_pairs)))
for key, val in kv_pairs.items():
key_bytes = key.encode("utf-8")
buf.write(struct.pack("<Q", len(key_bytes)))
buf.write(key_bytes)
if isinstance(val, str):
buf.write(struct.pack("<I", 8)) # STRING
val_bytes = val.encode("utf-8")
buf.write(struct.pack("<Q", len(val_bytes)))
buf.write(val_bytes)
elif isinstance(val, int):
if val <= 0xFFFFFFFF:
buf.write(struct.pack("<I", 4)) # UINT32
buf.write(struct.pack("<I", val))
else:
buf.write(struct.pack("<I", 10)) # UINT64
buf.write(struct.pack("<Q", val))
else:
raise TypeError(f"Unsupported value type: {type(val)}")
return buf.getvalue()
def _backend_from_gguf(arch: str, fields: dict) -> LlamaCppBackend:
"""Create a LlamaCppBackend with parsed GGUF metadata from given fields."""
kv = {"general.architecture": arch}
for k, v in fields.items():
kv[f"{arch}.{k}"] = v
import tempfile, os
data = _make_gguf_bytes(arch, kv)
fd, path = tempfile.mkstemp(suffix = ".gguf")
try:
os.write(fd, data)
os.close(fd)
b = LlamaCppBackend()
b._read_gguf_metadata(path)
return b
finally:
os.unlink(path)
# ---------------------------------------------------------------------------
# A. GGUF Parser Tests
# ---------------------------------------------------------------------------
class TestGGUFParserNewFields:
"""Verify that the 8 new architecture-aware fields are correctly parsed."""
@pytest.mark.parametrize(
"field,gguf_key,value",
[
("_kv_key_length", "attention.key_length", 128),
("_kv_value_length", "attention.value_length", 128),
("_sliding_window", "attention.sliding_window", 1024),
("_full_attention_interval", "full_attention_interval", 4),
("_kv_lora_rank", "attention.kv_lora_rank", 512),
("_key_length_mla", "attention.key_length_mla", 256),
("_ssm_inner_size", "ssm.inner_size", 6144),
("_ssm_state_size", "ssm.state_size", 128),
],
)
def test_field_parsed(self, field, gguf_key, value):
b = _backend_from_gguf("testarch", {gguf_key: value})
assert getattr(b, field) == value
def test_missing_fields_are_none(self):
b = _backend_from_gguf("testarch", {"block_count": 10})
for attr in [
"_kv_key_length",
"_kv_value_length",
"_sliding_window",
"_full_attention_interval",
"_kv_lora_rank",
"_key_length_mla",
"_ssm_inner_size",
"_ssm_state_size",
]:
assert getattr(b, attr) is None
def test_all_13_fields_parsed_together(self):
fields = {
"context_length": 131072,
"block_count": 62,
"attention.head_count_kv": 16,
"attention.head_count": 32,
"embedding_length": 5376,
"attention.key_length": 128,
"attention.value_length": 128,
"attention.sliding_window": 1024,
"full_attention_interval": 6,
"attention.kv_lora_rank": 512,
"attention.key_length_mla": 256,
"ssm.inner_size": 4096,
"ssm.state_size": 128,
}
b = _backend_from_gguf("testarch", fields)
assert b._context_length == 131072
assert b._n_layers == 62
assert b._n_kv_heads == 16
assert b._n_heads == 32
assert b._embedding_length == 5376
assert b._kv_key_length == 128
assert b._kv_value_length == 128
assert b._sliding_window == 1024
assert b._full_attention_interval == 6
assert b._kv_lora_rank == 512
assert b._key_length_mla == 256
assert b._ssm_inner_size == 4096
assert b._ssm_state_size == 128
class TestGGUFParserReset:
"""Verify that fields are properly reset between parses."""
def test_reset_between_parses(self):
# First parse with all fields
b = _backend_from_gguf(
"arch1",
{
"block_count": 32,
"attention.key_length": 128,
"attention.kv_lora_rank": 512,
"ssm.inner_size": 4096,
},
)
assert b._kv_key_length == 128
assert b._kv_lora_rank == 512
assert b._ssm_inner_size == 4096
# Second parse without those fields -- they should be None
kv = {"general.architecture": "arch2", "arch2.block_count": 64}
import tempfile, os
data = _make_gguf_bytes("arch2", kv)
fd, path = tempfile.mkstemp(suffix = ".gguf")
os.write(fd, data)
os.close(fd)
try:
b._read_gguf_metadata(path)
finally:
os.unlink(path)
assert b._kv_key_length is None
assert b._kv_lora_rank is None
assert b._ssm_inner_size is None
assert b._n_layers == 64
# ---------------------------------------------------------------------------
# B. _can_estimate_kv Gate Tests
# ---------------------------------------------------------------------------
class TestCanEstimateKV:
"""Verify gate logic for all field combinations."""
def test_no_layers_returns_false(self):
b = LlamaCppBackend()
b._n_layers = None
b._kv_key_length = 128
assert not b._can_estimate_kv()
def test_explicit_both_dims_sufficient(self):
b = LlamaCppBackend()
b._n_layers = 32
b._kv_key_length = 128
b._kv_value_length = 128
assert b._can_estimate_kv()
def test_key_length_alone_insufficient(self):
"""key_length without value_length should NOT be enough."""
b = LlamaCppBackend()
b._n_layers = 32
b._kv_key_length = 128
assert not b._can_estimate_kv()
def test_kv_lora_rank_sufficient(self):
b = LlamaCppBackend()
b._n_layers = 61
b._kv_lora_rank = 512
assert b._can_estimate_kv()
def test_legacy_embed_plus_heads(self):
b = LlamaCppBackend()
b._n_layers = 28
b._embedding_length = 1024
b._n_heads = 16
assert b._can_estimate_kv()
def test_legacy_embed_plus_kv_heads(self):
b = LlamaCppBackend()
b._n_layers = 28
b._embedding_length = 1024
b._n_kv_heads = 8
assert b._can_estimate_kv()
def test_legacy_no_embed_returns_false(self):
b = LlamaCppBackend()
b._n_layers = 28
b._n_heads = 16
# No embedding_length, no new-style fields
assert not b._can_estimate_kv()
def test_fresh_backend_returns_false(self):
b = LlamaCppBackend()
assert not b._can_estimate_kv()
# ---------------------------------------------------------------------------
# C. Path 1: MLA Estimation
# ---------------------------------------------------------------------------
class TestMLAEstimation:
"""MLA: K-only cache using compressed KV latent + RoPE."""
def _mla_backend(self, **overrides):
defaults = {
"_n_layers": 61,
"_n_kv_heads": 1,
"_n_heads": 128,
"_embedding_length": 7168,
"_kv_key_length": 576,
"_kv_value_length": 512,
"_kv_lora_rank": 512,
"_key_length_mla": 192,
}
defaults.update(overrides)
b = LlamaCppBackend()
for k, v in defaults.items():
setattr(b, k, v)
return b
def test_deepseek_v3_f16(self):
b = self._mla_backend()
# 61 layers * 163840 ctx * 1 head * 576 key_len * 2 bpe
expected = 61 * 163840 * 1 * 576 * 2
assert b._estimate_kv_cache_bytes(163840, "f16") == expected
def test_mla_ignores_value_length(self):
"""MLA should NOT add value_length -- V is reconstructed from the latent."""
b = self._mla_backend()
result = b._estimate_kv_cache_bytes(1000, "f16")
# Should be n_layers * ctx * 1 * key_len(576) * 2
expected = 61 * 1000 * 1 * 576 * 2
assert result == expected
def test_mla_fallback_when_no_key_length(self):
"""If key_length is missing, fallback to kv_lora_rank + key_length_mla."""
b = self._mla_backend(_kv_key_length = None)
# _key_length_mla=192 in default, so rope_dim=192
result = b._estimate_kv_cache_bytes(1000, "f16")
expected = 61 * 1000 * 1 * (512 + 192) * 2 # 704
assert result == expected
def test_mla_fallback_no_key_length_mla(self):
"""If both key_length and key_length_mla are missing, fallback to +64."""
b = self._mla_backend(_kv_key_length = None, _key_length_mla = None)
result = b._estimate_kv_cache_bytes(1000, "f16")
expected = 61 * 1000 * 1 * (512 + 64) * 2 # 576
assert result == expected
def test_mla_defaults_n_kv_to_1_when_heads_absent(self):
"""MLA should use n_kv=1 even if n_kv_heads is None (not n_heads)."""
b = self._mla_backend(_n_kv_heads = None) # n_heads=128 still set
result = b._estimate_kv_cache_bytes(1000, "f16")
# Should use n_kv_mla=1, NOT n_heads=128
expected = 61 * 1000 * 1 * 576 * 2
assert result == expected
def test_mla_q4_quantization(self):
b = self._mla_backend()
result_f16 = b._estimate_kv_cache_bytes(1000, "f16")
result_q4 = b._estimate_kv_cache_bytes(1000, "q4_0")
assert result_q4 < result_f16
# q4_0 bpe = 0.5625, f16 bpe = 2.0
assert result_q4 == int(61 * 1000 * 1 * 576 * 0.5625)
# ---------------------------------------------------------------------------
# D. Path 2: Hybrid Mamba Estimation
# ---------------------------------------------------------------------------
class TestHybridMambaEstimation:
"""Hybrid Mamba: only attention layers (1 in N) need KV cache."""
def _hybrid_backend(self, **overrides):
defaults = {
"_n_layers": 64,
"_n_kv_heads": 4,
"_n_heads": 24,
"_embedding_length": 5120,
"_kv_key_length": 256,
"_kv_value_length": 256,
"_full_attention_interval": 4,
"_ssm_inner_size": 6144,
"_ssm_state_size": 128,
}
defaults.update(overrides)
b = LlamaCppBackend()
for k, v in defaults.items():
setattr(b, k, v)
return b
def test_qwen35_27b(self):
b = self._hybrid_backend()
# n_attn = 64 // 4 = 16
expected = 16 * 262144 * 4 * (256 + 256) * 2
assert b._estimate_kv_cache_bytes(262144, "f16") == expected
def test_qwen35_35b_a3b(self):
b = self._hybrid_backend(
_n_layers = 40,
_n_kv_heads = 2,
_n_heads = 16,
_embedding_length = 2048,
_ssm_inner_size = 4096,
)
# n_attn = 40 // 4 = 10
expected = 10 * 262144 * 2 * (256 + 256) * 2
assert b._estimate_kv_cache_bytes(262144, "f16") == expected
def test_hybrid_without_explicit_dims(self):
"""Fallback to head_dim when key_length/value_length are missing."""
b = self._hybrid_backend(_kv_key_length = None, _kv_value_length = None)
head_dim = 5120 // 24 # 213
expected = 16 * 4096 * 4 * 2 * head_dim * 2
assert b._estimate_kv_cache_bytes(4096, "f16") == expected
def test_fai_zero_safety(self):
"""full_attention_interval=0 should not cause ZeroDivisionError."""
b = self._hybrid_backend(_full_attention_interval = 0)
result = b._estimate_kv_cache_bytes(4096, "f16")
# fai=0 -> n_attn = n_layers (all layers)
expected = 64 * 4096 * 4 * (256 + 256) * 2
assert result == expected
# ---------------------------------------------------------------------------
# E. Path 3: Sliding Window Estimation
# ---------------------------------------------------------------------------
class TestSlidingWindowEstimation:
"""SWA: half global (full ctx) + half sliding window."""
def _swa_backend(self, **overrides):
defaults = {
"_n_layers": 62,
"_n_kv_heads": 16,
"_n_heads": 32,
"_embedding_length": 5376,
"_kv_key_length": 128,
"_kv_value_length": 128,
"_sliding_window": 1024,
}
defaults.update(overrides)
b = LlamaCppBackend()
for k, v in defaults.items():
setattr(b, k, v)
return b
def test_gemma3(self):
b = self._swa_backend()
# 1/4 heuristic: 62 // 4 = 15 global, 47 SWA
n_global = max(1, 62 // 4) # 15
n_swa = 62 - n_global # 47
kv_per = 16 * (128 + 128) * 2
expected = int(n_global * 131072 * kv_per + n_swa * min(131072, 1024) * kv_per)
assert b._estimate_kv_cache_bytes(131072, "f16") == expected
def test_gpt_oss(self):
b = self._swa_backend(
_n_layers = 24,
_n_kv_heads = 8,
_n_heads = 64,
_embedding_length = 2880,
_kv_key_length = 64,
_kv_value_length = 64,
_sliding_window = 128,
)
# 1/4 heuristic: 24 // 4 = 6 global, 18 SWA
n_global = max(1, 24 // 4) # 6
n_swa = 24 - n_global # 18
kv_per = 8 * (64 + 64) * 2
expected = int(n_global * 131072 * kv_per + n_swa * min(131072, 128) * kv_per)
assert b._estimate_kv_cache_bytes(131072, "f16") == expected
def test_ctx_smaller_than_window(self):
"""When context < sliding_window, SWA layers use full context anyway."""
b = self._swa_backend(_sliding_window = 8192)
n_global = max(1, 62 // 4) # 15
n_swa = 62 - n_global # 47
kv_per = 16 * (128 + 128) * 2
ctx = 4096
expected = int(n_global * ctx * kv_per + n_swa * min(ctx, 8192) * kv_per)
# min(4096, 8192) = 4096, so both pools use full ctx
assert b._estimate_kv_cache_bytes(ctx, "f16") == expected
def test_odd_layer_count(self):
"""Odd layer count: n_global = max(1, n//4), n_swa = n - n_global."""
b = self._swa_backend(_n_layers = 63)
n_global = max(1, 63 // 4) # 15
n_swa = 63 - n_global # 48
kv_per = 16 * (128 + 128) * 2
expected = int(n_global * 1000 * kv_per + n_swa * min(1000, 1024) * kv_per)
assert b._estimate_kv_cache_bytes(1000, "f16") == expected
# ---------------------------------------------------------------------------
# F. Path 4: Standard GQA Estimation
# ---------------------------------------------------------------------------
class TestStandardGQAEstimation:
"""Standard GQA with explicit key_length/value_length."""
def _gqa_backend(self, **overrides):
defaults = {
"_n_layers": 28,
"_n_kv_heads": 8,
"_n_heads": 16,
"_embedding_length": 1024,
"_kv_key_length": 128,
"_kv_value_length": 128,
}
defaults.update(overrides)
b = LlamaCppBackend()
for k, v in defaults.items():
setattr(b, k, v)
return b
def test_qwen3_06b(self):
b = self._gqa_backend()
expected = 28 * 40960 * 8 * (128 + 128) * 2
assert b._estimate_kv_cache_bytes(40960, "f16") == expected
def test_asymmetric_kv_dims(self):
"""key_length != value_length (some architectures have this)."""
b = self._gqa_backend(_kv_key_length = 192, _kv_value_length = 64)
expected = 28 * 4096 * 8 * (192 + 64) * 2
assert b._estimate_kv_cache_bytes(4096, "f16") == expected
def test_differs_from_legacy(self):
"""GQA path should differ from legacy when key_length != embed//n_heads."""
b = self._gqa_backend()
head_dim = 1024 // 16 # 64
gqa_result = b._estimate_kv_cache_bytes(4096, "f16")
# Legacy would use: 2 * 8 * 64 * 28 * 4096 * 2
legacy_result = int(2 * 8 * head_dim * 28 * 4096 * 2)
# GQA: 28 * 4096 * 8 * (128+128) * 2 -- uses actual key_length=128
assert gqa_result != legacy_result
assert gqa_result > legacy_result # key_length (128) > head_dim (64)
# ---------------------------------------------------------------------------
# G. Path 5: Legacy Fallback Estimation
# ---------------------------------------------------------------------------
class TestLegacyEstimation:
"""Legacy: embed // n_heads, for old GGUFs without new fields."""
def _legacy_backend(self, **overrides):
defaults = {
"_n_layers": 32,
"_n_kv_heads": 8,
"_n_heads": 32,
"_embedding_length": 4096,
}
defaults.update(overrides)
b = LlamaCppBackend()
for k, v in defaults.items():
setattr(b, k, v)
return b
def test_basic_legacy(self):
b = self._legacy_backend()
head_dim = 4096 // 32 # 128
expected = int(2 * 8 * 128 * 32 * 4096 * 2)
assert b._estimate_kv_cache_bytes(4096, "f16") == expected
def test_legacy_with_only_n_heads(self):
"""n_kv_heads is None, falls back to n_heads."""
b = self._legacy_backend(_n_kv_heads = None)
head_dim = 4096 // 32
expected = int(2 * 32 * head_dim * 32 * 4096 * 2)
assert b._estimate_kv_cache_bytes(4096, "f16") == expected
def test_legacy_identical_to_old_formula(self):
"""Confirm legacy path produces the same result as the pre-PR formula."""
b = self._legacy_backend()
n_layers = 32
n_kv_heads = 8
head_dim = 4096 // 32
n_ctx = 8192
bpe = 2.0
old_formula = int(2 * n_kv_heads * head_dim * n_layers * n_ctx * bpe)
assert b._estimate_kv_cache_bytes(n_ctx, "f16") == old_formula
# ---------------------------------------------------------------------------
# H. Path Priority (selection order)
# ---------------------------------------------------------------------------
class TestPathPriority:
"""Confirm: MLA > Hybrid Mamba > SWA > GQA > Legacy."""
def test_mla_takes_priority_over_all(self):
"""If kv_lora_rank is set, MLA path is used even if other fields are present."""
b = LlamaCppBackend()
b._n_layers = 61
b._n_kv_heads = 1
b._n_heads = 128
b._embedding_length = 7168
b._kv_key_length = 576
b._kv_value_length = 512
b._kv_lora_rank = 512
b._ssm_inner_size = 4096 # Would trigger Hybrid
b._full_attention_interval = 4
b._sliding_window = 1024 # Would trigger SWA
# MLA: 61 * 1000 * 1 * 576 * 2
expected_mla = int(61 * 1000 * 1 * 576 * 2)
assert b._estimate_kv_cache_bytes(1000, "f16") == expected_mla
def test_hybrid_over_swa(self):
"""Hybrid takes priority over SWA when both fields present."""
b = LlamaCppBackend()
b._n_layers = 64
b._n_kv_heads = 4
b._n_heads = 24
b._embedding_length = 5120
b._kv_key_length = 256
b._kv_value_length = 256
b._ssm_inner_size = 6144
b._full_attention_interval = 4
b._sliding_window = 1024 # Would trigger SWA
n_attn = 64 // 4
expected_hybrid = int(n_attn * 1000 * 4 * (256 + 256) * 2)
assert b._estimate_kv_cache_bytes(1000, "f16") == expected_hybrid
def test_all_paths_produce_different_values(self):
"""With carefully chosen params, each path should yield a distinct value."""
# Use embedding_length=768 so legacy head_dim (768//16=48) differs from
# key_length (256), and MLA key_len (256) != legacy K+V (2*48=96).
params = {
"_n_layers": 40,
"_n_kv_heads": 4,
"_n_heads": 16,
"_embedding_length": 768,
"_kv_key_length": 256,
"_kv_value_length": 256,
}
ctx = 4096
# Path 4: Standard GQA
b_gqa = LlamaCppBackend()
for k, v in params.items():
setattr(b_gqa, k, v)
gqa_val = b_gqa._estimate_kv_cache_bytes(ctx, "f16")
# Path 1: MLA
b_mla = LlamaCppBackend()
for k, v in params.items():
setattr(b_mla, k, v)
b_mla._kv_lora_rank = 512
mla_val = b_mla._estimate_kv_cache_bytes(ctx, "f16")
# Path 2: Hybrid Mamba
b_hybrid = LlamaCppBackend()
for k, v in params.items():
setattr(b_hybrid, k, v)
b_hybrid._ssm_inner_size = 4096
b_hybrid._full_attention_interval = 4
hybrid_val = b_hybrid._estimate_kv_cache_bytes(ctx, "f16")
# Path 3: SWA
b_swa = LlamaCppBackend()
for k, v in params.items():
setattr(b_swa, k, v)
b_swa._sliding_window = 512
swa_val = b_swa._estimate_kv_cache_bytes(ctx, "f16")
# Path 5: Legacy (no key_length/value_length)
b_legacy = LlamaCppBackend()
b_legacy._n_layers = 40
b_legacy._n_kv_heads = 4
b_legacy._n_heads = 16
b_legacy._embedding_length = 768
legacy_val = b_legacy._estimate_kv_cache_bytes(ctx, "f16")
values = [mla_val, hybrid_val, swa_val, gqa_val, legacy_val]
assert len(set(values)) == 5, f"Expected 5 distinct values, got {values}"
# ---------------------------------------------------------------------------
# I. KV Cache Quantization
# ---------------------------------------------------------------------------
class TestQuantization:
"""Verify all supported cache_type_kv values produce correct scaling."""
@pytest.mark.parametrize(
"cache_type,expected_bpe",
[
("f32", 4.0),
("f16", 2.0),
("bf16", 2.0),
("q8_0", 34 / 32),
("q5_1", 0.75),
("q5_0", 0.6875),
("q4_1", 0.625),
("q4_0", 0.5625),
("iq4_nl", 0.5625),
(None, 2.0), # default is f16
("unknown", 2.0), # unknown falls back to f16
],
)
def test_quantization_scaling(self, cache_type, expected_bpe):
b = LlamaCppBackend()
b._n_layers = 10
b._n_kv_heads = 1
b._n_heads = 8
b._embedding_length = 512
b._kv_key_length = 64
b._kv_value_length = 64
result = b._estimate_kv_cache_bytes(1000, cache_type)
expected = int(10 * 1000 * 1 * (64 + 64) * expected_bpe)
assert result == expected
# ---------------------------------------------------------------------------
# J. Edge Cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
"""Boundary conditions and degenerate inputs."""
def test_zero_context(self):
b = LlamaCppBackend()
b._n_layers = 32
b._kv_key_length = 128
assert b._estimate_kv_cache_bytes(0, "f16") == 0
def test_negative_context(self):
b = LlamaCppBackend()
b._n_layers = 32
b._kv_key_length = 128
assert b._estimate_kv_cache_bytes(-1, "f16") == 0
def test_context_of_one(self):
b = LlamaCppBackend()
b._n_layers = 10
b._n_kv_heads = 1
b._kv_key_length = 64
b._kv_value_length = 64
result = b._estimate_kv_cache_bytes(1, "f16")
assert result == int(10 * 1 * 1 * (64 + 64) * 2)
def test_very_large_context(self):
"""1M context should not overflow or crash."""
b = LlamaCppBackend()
b._n_layers = 10
b._n_kv_heads = 1
b._kv_key_length = 128
b._kv_value_length = 128
result = b._estimate_kv_cache_bytes(1_000_000, "f16")
assert result > 0
assert isinstance(result, int)
def test_n_kv_heads_none_falls_to_n_heads(self):
b = LlamaCppBackend()
b._n_layers = 10
b._n_kv_heads = None
b._n_heads = 8
b._kv_key_length = 64
b._kv_value_length = 64
result = b._estimate_kv_cache_bytes(100, "f16")
expected = int(10 * 100 * 8 * (64 + 64) * 2)
assert result == expected
def test_both_heads_none_falls_to_one(self):
b = LlamaCppBackend()
b._n_layers = 10
b._n_kv_heads = None
b._n_heads = None
b._kv_key_length = 64
b._kv_value_length = 64
result = b._estimate_kv_cache_bytes(100, "f16")
expected = int(10 * 100 * 1 * (64 + 64) * 2)
assert result == expected
# ---------------------------------------------------------------------------
# K. Lifecycle Tests
# ---------------------------------------------------------------------------
class TestLifecycle:
"""Init, unload, and reparse field management."""
def test_init_fields_none(self):
b = LlamaCppBackend()
for attr in [
"_kv_key_length",
"_kv_value_length",
"_sliding_window",
"_full_attention_interval",
"_kv_lora_rank",
"_key_length_mla",
"_ssm_inner_size",
"_ssm_state_size",
]:
assert getattr(b, attr) is None
def test_unload_resets_fields(self):
b = LlamaCppBackend()
b._n_layers = 32
b._kv_key_length = 128
b._kv_lora_rank = 512
b._sliding_window = 1024
b._ssm_inner_size = 4096
b._full_attention_interval = 4
b.unload_model()
for attr in [
"_kv_key_length",
"_kv_value_length",
"_sliding_window",
"_full_attention_interval",
"_kv_lora_rank",
"_key_length_mla",
"_ssm_inner_size",
"_ssm_state_size",
]:
assert getattr(b, attr) is None
def test_end_to_end_synthetic_mla(self):
"""Full round-trip: write GGUF -> parse -> estimate."""
b = _backend_from_gguf(
"deepseek2",
{
"context_length": 163840,
"block_count": 61,
"attention.head_count_kv": 1,
"attention.head_count": 128,
"embedding_length": 7168,
"attention.key_length": 576,
"attention.value_length": 512,
"attention.kv_lora_rank": 512,
"attention.key_length_mla": 192,
},
)
assert b._can_estimate_kv()
result = b._estimate_kv_cache_bytes(163840, "f16")
expected = 61 * 163840 * 1 * 576 * 2
assert result == expected
def test_end_to_end_synthetic_hybrid(self):
b = _backend_from_gguf(
"qwen35",
{
"context_length": 262144,
"block_count": 64,
"attention.head_count_kv": 4,
"attention.head_count": 24,
"embedding_length": 5120,
"attention.key_length": 256,
"attention.value_length": 256,
"full_attention_interval": 4,
"ssm.inner_size": 6144,
"ssm.state_size": 128,
},
)
assert b._can_estimate_kv()
result = b._estimate_kv_cache_bytes(262144, "f16")
n_attn = 64 // 4
expected = n_attn * 262144 * 4 * (256 + 256) * 2
assert result == expected
def test_end_to_end_synthetic_swa(self):
b = _backend_from_gguf(
"gemma3",
{
"context_length": 131072,
"block_count": 62,
"attention.head_count_kv": 16,
"attention.head_count": 32,
"embedding_length": 5376,
"attention.key_length": 128,
"attention.value_length": 128,
"attention.sliding_window": 1024,
},
)
assert b._can_estimate_kv()
result = b._estimate_kv_cache_bytes(131072, "f16")
n_global = max(1, 62 // 4) # 15
n_swa = 62 - n_global # 47
kv_per = 16 * 256 * 2
expected = int(n_global * 131072 * kv_per + n_swa * 1024 * kv_per)
assert result == expected
def test_end_to_end_synthetic_gqa(self):
b = _backend_from_gguf(
"qwen3",
{
"context_length": 40960,
"block_count": 28,
"attention.head_count_kv": 8,
"attention.head_count": 16,
"embedding_length": 1024,
"attention.key_length": 128,
"attention.value_length": 128,
},
)
assert b._can_estimate_kv()
result = b._estimate_kv_cache_bytes(40960, "f16")
expected = 28 * 40960 * 8 * 256 * 2
assert result == expected
def test_end_to_end_synthetic_legacy(self):
b = _backend_from_gguf(
"llama",
{
"context_length": 4096,
"block_count": 32,
"attention.head_count_kv": 8,
"attention.head_count": 32,
"embedding_length": 4096,
},
)
assert b._can_estimate_kv()
result = b._estimate_kv_cache_bytes(4096, "f16")
head_dim = 4096 // 32
expected = int(2 * 8 * head_dim * 32 * 4096 * 2)
assert result == expected

View file

@ -0,0 +1,243 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for the cache-aware disk-space preflight in
``LlamaCppBackend.load_model``.
The preflight used to compare the repo's total GGUF download size against
free disk without accounting for bytes already present in the Hugging
Face cache. That made re-loading a cached large model (e.g.
``unsloth/MiniMax-M2.7-GGUF`` at 131 GB) fail cold whenever free disk was
below the full weight footprint, even though nothing needed
downloading.
These tests exercise the preflight arithmetic in isolation by driving
``get_paths_info`` and ``try_to_load_from_cache`` through ``mock.patch``.
No network, GPU, or subprocess use.
Cross-platform: Linux, macOS, Windows, WSL.
"""
from __future__ import annotations
import sys
import tempfile
import types as _types
from pathlib import Path
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# loggers
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
# httpx
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
GIB = 1024**3
class _FakePathInfo:
"""Mimics huggingface_hub's RepoFile-ish return type from get_paths_info."""
def __init__(self, path: str, size: int):
self.path = path
self.size = size
def _preflight(
repo_files,
cached_files,
free_bytes,
hf_repo = "unsloth/Example-GGUF",
hf_token = None,
):
"""Run the preflight arithmetic as written in llama_cpp.py and return
the decision outcome as a dict.
``repo_files``: list of (filename, remote_bytes).
``cached_files``: dict {filename: on_disk_bytes} for files already in cache.
``free_bytes``: value returned by shutil.disk_usage(cache_dir).free.
"""
import os
import shutil
path_infos = [_FakePathInfo(name, size) for name, size in repo_files]
with tempfile.TemporaryDirectory() as tmp:
# Create SPARSE files for the cached ones so os.path.exists /
# os.path.getsize pass without actually allocating bytes on disk.
# This is critical when simulating multi-GB models.
cache_paths = {}
for name, sz in cached_files.items():
p = Path(tmp) / name.replace("/", "_")
with open(p, "wb") as fh:
if sz > 0:
fh.truncate(sz) # sparse allocation: no data blocks written
cache_paths[name] = str(p)
def fake_try_to_load_from_cache(repo_id, filename):
return cache_paths.get(filename)
# Mirror the same variable names and control flow as the real code
# so behavioral drift is caught immediately.
total_bytes = sum((p.size or 0) for p in path_infos)
already_cached_bytes = 0
for p in path_infos:
if not p.size:
continue
cached_path = fake_try_to_load_from_cache(hf_repo, p.path)
if isinstance(cached_path, str) and os.path.exists(cached_path):
try:
on_disk = os.path.getsize(cached_path)
except OSError:
on_disk = 0
if on_disk >= p.size:
already_cached_bytes += p.size
total_download_bytes = max(0, total_bytes - already_cached_bytes)
needed_download = total_download_bytes > free_bytes
return {
"total_bytes": total_bytes,
"already_cached_bytes": already_cached_bytes,
"total_download_bytes": total_download_bytes,
"would_raise_disk_error": (needed_download and total_download_bytes > 0),
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestCacheAwarePreflight:
def test_fully_cached_model_does_not_require_disk(self):
"""The MiniMax case: 131 GB weights cached, only 36 GB free.
Preflight must not raise."""
shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)]
cached = {name: size for name, size in shards}
out = _preflight(
repo_files = shards,
cached_files = cached,
free_bytes = 36 * GIB,
)
assert out["total_download_bytes"] == 0
assert out["already_cached_bytes"] == 140 * GIB
assert out["would_raise_disk_error"] is False
def test_partial_cache_only_counts_remaining_bytes(self):
"""Two of four shards cached: preflight against remaining 70 GB."""
shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)]
cached = {
shards[0][0]: shards[0][1],
shards[1][0]: shards[1][1],
}
out = _preflight(
repo_files = shards,
cached_files = cached,
free_bytes = 80 * GIB,
)
assert out["already_cached_bytes"] == 70 * GIB
assert out["total_download_bytes"] == 70 * GIB
assert out["would_raise_disk_error"] is False
def test_partial_cache_insufficient_disk_for_rest_still_raises(self):
"""Two of four shards cached; remaining 70 GB still bigger than
free disk -> preflight correctly wants to raise."""
shards = [(f"UD-Q4_K_XL/shard-{i}.gguf", 35 * GIB) for i in range(4)]
cached = {
shards[0][0]: shards[0][1],
shards[1][0]: shards[1][1],
}
out = _preflight(
repo_files = shards,
cached_files = cached,
free_bytes = 50 * GIB,
)
assert out["total_download_bytes"] == 70 * GIB
assert out["would_raise_disk_error"] is True
def test_nothing_cached_preserves_existing_behavior(self):
"""Cold-cache path still compares full download vs free disk."""
shards = [("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)]
out = _preflight(
repo_files = shards,
cached_files = {},
free_bytes = 50 * GIB,
)
assert out["already_cached_bytes"] == 0
assert out["total_download_bytes"] == 40 * GIB
assert out["would_raise_disk_error"] is False
def test_incomplete_cached_blob_is_not_credited(self):
"""A partial file on disk (e.g. interrupted download) is not
counted as cached -- we still require bytes for it."""
shards = [("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)]
partial = {"UD-Q4_K_XL/shard-0.gguf": 10 * GIB}
out = _preflight(
repo_files = shards,
cached_files = partial,
free_bytes = 50 * GIB,
)
assert out["already_cached_bytes"] == 0
assert out["total_download_bytes"] == 40 * GIB
assert out["would_raise_disk_error"] is False
def test_zero_size_path_infos_do_not_crash(self):
"""A path_info with size=0 should not be credited or break the
arithmetic."""
shards = [("mmproj.gguf", 0), ("UD-Q4_K_XL/shard-0.gguf", 40 * GIB)]
out = _preflight(
repo_files = shards,
cached_files = {},
free_bytes = 50 * GIB,
)
assert out["already_cached_bytes"] == 0
assert out["total_bytes"] == 40 * GIB

View file

@ -0,0 +1,389 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for the GGUF load-time context auto-fit decision.
Guards two regressions in ``LlamaCppBackend.load_model``:
1. **Auto mode on weights-exceed-VRAM** (``n_ctx == 0``): when the model
weights alone exceed 90% of every GPU subset's free memory, the
auto-pick loop used to exit without matching, leaving
``effective_ctx`` at the model's native context (e.g. 196608 for
MiniMax-M2.7). The intended default per Studio's UI spec is 4096 so
the slider lands on a usable value; the user can still drag higher
and trigger ``--fit on`` with a warning.
2. **Explicit ctx silently shrunk when KV overflows**: with fittable
weights but a requested ctx whose KV cache pushes total memory over
90% of VRAM, the old code binary-searched a smaller ctx and emitted
``-c <capped> -ngl -1`` without informing the caller. The UI had
already surfaced its "might be slower" warning and expects the user's
explicit ctx to be honored with ``--fit on`` flexing ``-ngl`` instead.
Tests avoid GPU probing, subprocess spawning, and GGUF I/O by driving the
post-metadata decision block directly against a stubbed instance.
Requires no GPU, network, or external libraries beyond pytest.
Cross-platform: Linux, macOS, Windows, WSL.
"""
from __future__ import annotations
import sys
import types as _types
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# loggers
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
# httpx
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
GIB = 1024**3
FALLBACK_CTX = 4096
def _make_backend(
native_ctx = 131072,
n_layers = 80,
n_kv_heads = 8,
n_heads = 64,
kv_key_length = 128,
kv_value_length = 128,
):
"""Create a LlamaCppBackend instance with GGUF metadata fields set and
the helpers used by the decision block stubbed out."""
inst = LlamaCppBackend.__new__(LlamaCppBackend)
inst._context_length = native_ctx
inst._n_layers = n_layers
inst._n_kv_heads = n_kv_heads
inst._n_heads = n_heads
inst._embedding_length = 8192
inst._kv_key_length = kv_key_length
inst._kv_value_length = kv_value_length
inst._kv_lora_rank = None
inst._sliding_window = None
inst._ssm_inner_size = None
inst._full_attention_interval = None
inst._key_length_mla = None
return inst
def _drive(
n_ctx,
model_gib,
gpus,
native_ctx = 131072,
kv_per_token_bytes = 325_000,
can_estimate_kv = True,
):
"""Drive the post-metadata portion of load_model with stubbed inputs.
Mirrors the decision block at llama_cpp.py:1137-1296 so we can assert
the command that would be built, without subprocesses or GPU probes.
"""
inst = _make_backend(native_ctx = native_ctx)
model_size = int(model_gib * GIB)
cache_type_kv = None
def fake_estimate(n_ctx_, _type = None):
return 0 if n_ctx_ <= 0 else n_ctx_ * kv_per_token_bytes
inst._estimate_kv_cache_bytes = fake_estimate
inst._can_estimate_kv = lambda: can_estimate_kv
context_length = inst._context_length
effective_ctx = n_ctx if n_ctx > 0 else (context_length or 0)
max_available_ctx = context_length or effective_ctx
if n_ctx > 0:
effective_ctx = n_ctx
elif context_length is not None:
effective_ctx = context_length
else:
effective_ctx = 0
original_ctx = effective_ctx
max_available_ctx = context_length or effective_ctx
gpu_indices, use_fit = None, True
explicit_ctx = n_ctx > 0
if gpus and inst._can_estimate_kv() and effective_ctx > 0:
native_ctx_for_cap = context_length or effective_ctx
if native_ctx_for_cap > 0:
ranked_for_cap = sorted(gpus, key = lambda g: g[1], reverse = True)
best_cap = 0
for n_gpus in range(1, len(ranked_for_cap) + 1):
subset = ranked_for_cap[:n_gpus]
pool_mib = sum(free for _, free in subset)
capped = inst._fit_context_to_vram(
native_ctx_for_cap,
pool_mib,
model_size,
cache_type_kv,
)
kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv)
total_mib = (model_size + kv) / (1024 * 1024)
if total_mib <= pool_mib * 0.90:
best_cap = max(best_cap, capped)
if best_cap > 0:
max_available_ctx = best_cap
if explicit_ctx:
requested_total = model_size + inst._estimate_kv_cache_bytes(
effective_ctx, cache_type_kv
)
gpu_indices, use_fit = inst._select_gpus(requested_total, gpus)
else:
ranked = sorted(gpus, key = lambda g: g[1], reverse = True)
matched = False
for n_gpus in range(1, len(ranked) + 1):
subset = ranked[:n_gpus]
pool_mib = sum(free for _, free in subset)
capped = inst._fit_context_to_vram(
effective_ctx,
pool_mib,
model_size,
cache_type_kv,
)
kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv)
total_mib = (model_size + kv) / (1024 * 1024)
if total_mib <= pool_mib * 0.90:
effective_ctx = capped
gpu_indices = sorted(idx for idx, _ in subset)
use_fit = False
matched = True
break
if not matched:
effective_ctx = min(FALLBACK_CTX, effective_ctx)
elif gpus:
gpu_indices, use_fit = inst._select_gpus(model_size, gpus)
if use_fit and not explicit_ctx:
effective_ctx = (
min(FALLBACK_CTX, effective_ctx) if effective_ctx > 0 else FALLBACK_CTX
)
return {
"c_arg": effective_ctx if effective_ctx > 0 else 0,
"use_fit": use_fit,
"gpu_indices": gpu_indices,
"max_available_ctx": max_available_ctx,
"original_ctx": original_ctx,
}
# ---------------------------------------------------------------------------
# Auto mode, model weights exceed VRAM (Bug A guard)
# ---------------------------------------------------------------------------
class TestAutoModeWeightsExceedVRAM:
"""``n_ctx == 0`` on a model whose weights don't fit anywhere."""
def test_minimax_like_single_gpu(self):
plan = _drive(
n_ctx = 0,
model_gib = 131,
gpus = [(0, 97_000)],
native_ctx = 196608,
)
assert plan["c_arg"] == FALLBACK_CTX
assert plan["use_fit"] is True
assert plan["gpu_indices"] is None
# UI slider ceiling stays at native: user can still drag higher
# and get the "might be slower" path.
assert plan["max_available_ctx"] == 196608
def test_multi_gpu_all_subsets_fail(self):
plan = _drive(
n_ctx = 0,
model_gib = 400,
gpus = [(0, 80_000), (1, 80_000), (2, 80_000), (3, 80_000)],
native_ctx = 131072,
)
assert plan["c_arg"] == FALLBACK_CTX
assert plan["use_fit"] is True
assert plan["gpu_indices"] is None
def test_no_kv_metadata_auto(self):
"""File-size-only fallback path also defaults to 4096."""
plan = _drive(
n_ctx = 0,
model_gib = 131,
gpus = [(0, 97_000)],
native_ctx = 196608,
can_estimate_kv = False,
)
assert plan["c_arg"] == FALLBACK_CTX
assert plan["use_fit"] is True
# ---------------------------------------------------------------------------
# Explicit ctx, KV overflows fittable weights (Bug B guard)
# ---------------------------------------------------------------------------
class TestExplicitCtxRespectsUser:
"""``n_ctx > 0`` must never be silently shrunk."""
def test_fittable_weights_oversized_kv(self):
# 8 GB weights + 131k ctx KV on 24 GB VRAM.
# Budget = 21.6 GB, KV at 131k >> 13.6 GB remaining, so
# _select_gpus flips use_fit=True.
plan = _drive(
n_ctx = 131072,
model_gib = 8,
gpus = [(0, 24_000)],
native_ctx = 131072,
)
assert plan["c_arg"] == 131072
assert plan["use_fit"] is True
assert plan["gpu_indices"] is None
def test_explicit_that_fits_uses_ngl(self):
plan = _drive(
n_ctx = 8192,
model_gib = 8,
gpus = [(0, 24_000)],
native_ctx = 131072,
)
assert plan["c_arg"] == 8192
assert plan["use_fit"] is False
assert plan["gpu_indices"] == [0]
def test_explicit_on_weights_exceed_vram(self):
# User drags the slider to 32k on a too-big model: honored.
plan = _drive(
n_ctx = 32768,
model_gib = 131,
gpus = [(0, 97_000)],
native_ctx = 196608,
)
assert plan["c_arg"] == 32768
assert plan["use_fit"] is True
def test_explicit_at_fallback_on_too_big(self):
plan = _drive(
n_ctx = FALLBACK_CTX,
model_gib = 131,
gpus = [(0, 97_000)],
native_ctx = 196608,
)
assert plan["c_arg"] == FALLBACK_CTX
assert plan["use_fit"] is True
def test_explicit_below_floor_honored(self):
# 2048 is below --fit-ctx default; still honored since user set it.
plan = _drive(
n_ctx = 2048,
model_gib = 8,
gpus = [(0, 24_000)],
)
assert plan["c_arg"] == 2048
# ---------------------------------------------------------------------------
# Non-regression: fittable + auto still auto-picks largest fitting ctx
# ---------------------------------------------------------------------------
class TestFittableAutoPickRegressions:
def test_small_model_one_gpu(self):
plan = _drive(
n_ctx = 0,
model_gib = 8,
gpus = [(0, 24_000)],
native_ctx = 131072,
kv_per_token_bytes = 8192,
)
assert plan["use_fit"] is False
assert plan["gpu_indices"] == [0]
assert plan["c_arg"] > FALLBACK_CTX
def test_medium_model_needs_multi_gpu(self):
plan = _drive(
n_ctx = 0,
model_gib = 60,
gpus = [(0, 40_000), (1, 40_000)],
native_ctx = 131072,
kv_per_token_bytes = 8192,
)
assert plan["use_fit"] is False
assert plan["gpu_indices"] == [0, 1]
def test_no_kv_metadata_fittable_auto(self):
plan = _drive(
n_ctx = 0,
model_gib = 8,
gpus = [(0, 24_000)],
native_ctx = 131072,
can_estimate_kv = False,
)
assert plan["use_fit"] is False
assert plan["gpu_indices"] == [0]
# ---------------------------------------------------------------------------
# Platform-agnostic input shape
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("platform_tag", ["linux", "windows", "mac", "rocm"])
def test_identical_decision_across_platforms(platform_tag):
"""The decision function takes ``[(gpu_idx, free_mib), ...]`` regardless
of how upstream (nvidia-smi / nvidia-smi.exe / Metal / rocm-smi) produced
it. Identical inputs must yield identical plans."""
plan_a = _drive(n_ctx = 0, model_gib = 8, gpus = [(0, 24_000)])
plan_b = _drive(n_ctx = 0, model_gib = 8, gpus = [(0, 24_000)])
assert plan_a == plan_b, platform_tag

View file

@ -0,0 +1,258 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for ``LlamaCppBackend.load_progress()``.
The chat settings flow and the training overlay both show a generic
"Starting model..." spinner during the window after a GGUF download
finishes and before llama-server reports healthy. For small models
that window is a second or two and nobody notices. For large MoE GGUFs
(MiniMax-M2.7, Qwen3.5-397B-A17B, etc.) the llama-server process spends
minutes in kernel state D, paging tens or hundreds of GB of shards
into the page cache. The UI has no way to show a real progress bar,
rate, or ETA during that window.
``load_progress()`` samples ``/proc/<pid>/status VmRSS`` (what the
kernel has actually paged in) against the total shard file size on
disk, so the frontend can render a real bar plus rate/ETA. This
module pins that contract:
* returns ``None`` when no load is in flight
* returns ``{"phase": "mmap", ...}`` while the subprocess is alive
but ``_healthy`` is False
* returns ``{"phase": "ready", ...}`` once ``_healthy`` flips
* ``bytes_total`` is derived from the resolved on-disk path
(which the paired fix assigns to ``self._gguf_path`` on both the
local-GGUF and HF-download code paths)
* ``bytes_loaded`` is VmRSS in bytes, capped by total, rounded
* ``fraction`` is clamped to 0..1 and rounded to 4 decimal places
Linux-only via ``/proc``. On platforms without ``/proc`` the method
returns ``None`` instead of raising.
Cross-platform test: skips cleanly on macOS / Windows if ``/proc`` is
not available.
"""
from __future__ import annotations
import os
import sys
import tempfile
import types as _types
from pathlib import Path
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_instance():
inst = LlamaCppBackend.__new__(LlamaCppBackend)
inst._process = None
inst._gguf_path = None
inst._healthy = False
return inst
class _FakeProc:
"""Minimal stand-in for subprocess.Popen that just carries a pid."""
def __init__(self, pid: int):
self.pid = pid
def _write_sparse_file(path: Path, size_bytes: int) -> None:
"""Create a sparse file of the given size without allocating blocks."""
with open(path, "wb") as fh:
if size_bytes > 0:
fh.truncate(size_bytes)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestLoadProgressEmptyStates:
def test_returns_none_when_no_process(self):
inst = _make_instance()
assert inst.load_progress() is None
def test_returns_none_when_process_has_no_pid(self):
inst = _make_instance()
inst._process = _FakeProc(pid = None) # type: ignore[arg-type]
assert inst.load_progress() is None
class TestLoadProgressSingleShard:
def test_mmap_phase_for_alive_but_unhealthy(self, tmp_path):
"""VmRSS below total -> phase='mmap', fraction reflects progress."""
gguf = tmp_path / "model.gguf"
_write_sparse_file(gguf, 40 * 1024**3) # 40 GB
inst = _make_instance()
inst._process = _FakeProc(pid = os.getpid()) # use our own pid
inst._gguf_path = str(gguf)
inst._healthy = False
# Patch /proc read to claim 10 GB RSS.
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
import io
return io.StringIO(f"Name:\ttest\nVmRSS:\t{10 * 1024 ** 2}\tkB\n")
return open(path, *args, **kwargs) # fall through
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is not None
assert out["phase"] == "mmap"
assert out["bytes_total"] == 40 * 1024**3
assert out["bytes_loaded"] == 10 * 1024**3
assert 0.24 < out["fraction"] < 0.26 # ~25%
def test_ready_phase_when_healthy(self, tmp_path):
gguf = tmp_path / "model.gguf"
_write_sparse_file(gguf, 8 * 1024**3)
inst = _make_instance()
inst._process = _FakeProc(pid = os.getpid())
inst._gguf_path = str(gguf)
inst._healthy = True
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
import io
return io.StringIO(f"VmRSS:\t{8 * 1024 ** 2}\tkB\n")
return open(path, *args, **kwargs)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is not None
assert out["phase"] == "ready"
assert out["bytes_total"] == 8 * 1024**3
assert out["bytes_loaded"] == 8 * 1024**3
assert out["fraction"] == 1.0
class TestLoadProgressMultiShard:
"""Shard-aware total: for ``*-00001-of-00004.gguf`` primaries the
method sums sibling files with the same prefix."""
def test_sharded_total_aggregates_siblings(self, tmp_path):
for i in range(1, 5):
_write_sparse_file(
tmp_path / f"model-{i:05d}-of-00004.gguf",
size_bytes = 20 * 1024**3,
)
# Drop an unrelated .gguf in the same folder -- must not be counted.
_write_sparse_file(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3)
inst = _make_instance()
inst._process = _FakeProc(pid = os.getpid())
inst._gguf_path = str(tmp_path / "model-00001-of-00004.gguf")
inst._healthy = False
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
import io
return io.StringIO("VmRSS:\t0\tkB\n")
return open(path, *args, **kwargs)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is not None
assert out["bytes_total"] == 80 * 1024**3 # 4 x 20 GB, no mmproj
class TestLoadProgressDegradation:
"""Broken / unusual inputs never raise; they produce best-effort output."""
def test_missing_gguf_path_still_reports_rss(self, tmp_path):
inst = _make_instance()
inst._process = _FakeProc(pid = os.getpid())
inst._gguf_path = None
inst._healthy = False
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
import io
return io.StringIO("VmRSS:\t1024\tkB\n")
return open(path, *args, **kwargs)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is not None
assert out["phase"] == "mmap"
assert out["bytes_total"] == 0
assert out["bytes_loaded"] == 1024 * 1024
assert out["fraction"] == 0.0
def test_unreadable_proc_returns_none(self, tmp_path):
inst = _make_instance()
# Pid that doesn't exist -> /proc read fails.
inst._process = _FakeProc(pid = 999_999_999)
inst._gguf_path = str(tmp_path / "model.gguf") # doesn't need to exist
inst._healthy = False
out = inst.load_progress()
# FileNotFoundError on /proc path -> load_progress returns None.
assert out is None

View file

@ -0,0 +1,202 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Live, no-mock integration test for ``LlamaCppBackend.load_progress()``.
The companion files (``test_llama_cpp_load_progress.py`` and
``test_llama_cpp_load_progress_matrix.py``) patch ``builtins.open`` to
feed synthetic VmRSS values. This file is the opposite: it uses **real**
subprocesses, **real** file sizes, and the **real** ``/proc``
interface. It is the sanity check that the contract we keep in the
mocked tests still maps to what the kernel actually returns on a live
Linux system.
Why both: the mocked tests can be fooled by a buggy implementation that
parses ``/proc`` output in a format the kernel no longer uses, or that
makes assumptions about ``Path.stat()`` vs ``os.path.getsize``. This
file hits the real APIs so any format drift gets caught.
Skipped cleanly on non-Linux (no ``/proc``).
"""
from __future__ import annotations
import os
import subprocess
import sys
import time
import types as _types
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Same stubs as the matrix file (keep self-contained so the file can be
# run standalone as well as via the full suite).
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
_httpx_stub = _types.ModuleType("httpx")
for _exc in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc, type(_exc, (Exception,), {}))
_httpx_stub.Timeout = type("Timeout", (), {"__init__": lambda self, *a, **k: None})
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
pytestmark = pytest.mark.skipif(
not Path("/proc").exists(),
reason = "live /proc test is Linux-only",
)
def _make_backend(pid: int, gguf_path: str, healthy: bool = False):
inst = LlamaCppBackend.__new__(LlamaCppBackend)
inst._process = type("P", (), {"pid": pid})()
inst._gguf_path = gguf_path
inst._healthy = healthy
return inst
def test_live_rss_matches_kernel_vmrss(tmp_path):
"""Spawn a real child, let it allocate real bytes, confirm
``bytes_loaded`` tracks the kernel's VmRSS within a sane tolerance."""
# Child that allocates ~100 MB of zero'd bytes and then idles.
script = tmp_path / "burn.py"
script.write_text(
"import time, sys\n"
"buf = bytearray(100 * 1024 * 1024)\n" # 100 MB
"# touch every page so RSS actually grows\n"
"for i in range(0, len(buf), 4096):\n"
" buf[i] = 1\n"
"sys.stdout.write('ready\\n')\n"
"sys.stdout.flush()\n"
"time.sleep(10)\n"
)
proc = subprocess.Popen(
[sys.executable, str(script)],
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
)
try:
# Wait for the child to finish touching pages.
ready = proc.stdout.readline()
assert ready.strip() == b"ready"
# Create a fake 200 MB sparse gguf so bytes_total is concrete.
gguf = tmp_path / "model.gguf"
with open(gguf, "wb") as f:
f.truncate(200 * 1024 * 1024)
inst = _make_backend(proc.pid, str(gguf), healthy = False)
out = inst.load_progress()
assert out is not None, "load_progress returned None for live pid"
assert out["phase"] == "mmap"
assert out["bytes_total"] == 200 * 1024 * 1024
# VmRSS for the Python child includes the interpreter + the 100MB
# buffer, so a realistic floor is 50 MB and ceiling is 200 MB.
assert (
out["bytes_loaded"] >= 50 * 1024 * 1024
), f"bytes_loaded unexpectedly low: {out['bytes_loaded']}"
assert out["bytes_loaded"] <= 200 * 1024 * 1024
assert 0.0 < out["fraction"] <= 1.0
finally:
proc.terminate()
try:
proc.wait(timeout = 5)
except subprocess.TimeoutExpired:
proc.kill()
def test_live_ready_phase_when_healthy(tmp_path):
gguf = tmp_path / "m.gguf"
with open(gguf, "wb") as f:
f.truncate(1 * 1024 * 1024)
inst = _make_backend(os.getpid(), str(gguf), healthy = True)
out = inst.load_progress()
assert out is not None
assert out["phase"] == "ready"
assert out["bytes_total"] == 1 * 1024 * 1024
# Self-pid RSS is well above 1 MiB for CPython; fraction caps at 1.
assert out["fraction"] == 1.0
def test_live_dead_pid_returns_none(tmp_path):
"""A recently-dead pid may linger in /proc for ms; use a clearly
invalid id so the read reliably fails."""
gguf = tmp_path / "m.gguf"
gguf.touch()
inst = _make_backend(9_999_999_999, str(gguf), healthy = False)
out = inst.load_progress()
assert out is None
def test_live_shard_aggregation_counts_real_files(tmp_path):
"""With 4 real sibling shards on disk, ``bytes_total`` equals their
summed size to the byte."""
shard_size = 7 * 1024 * 1024 # 7 MB each
for i in range(1, 5):
f = tmp_path / f"model-{i:05d}-of-00004.gguf"
with open(f, "wb") as fh:
fh.truncate(shard_size)
# Unrelated file in same dir -- must not be counted.
with open(tmp_path / "config.json", "wb") as fh:
fh.truncate(123)
inst = _make_backend(
os.getpid(),
str(tmp_path / "model-00001-of-00004.gguf"),
healthy = False,
)
out = inst.load_progress()
assert out is not None
assert out["bytes_total"] == 4 * shard_size
def test_live_repeated_polling_stays_sane(tmp_path):
"""Sampling the same backend 20 times should not raise or produce
non-numeric output, even under normal kernel RSS jitter."""
gguf = tmp_path / "m.gguf"
with open(gguf, "wb") as f:
f.truncate(500 * 1024 * 1024)
inst = _make_backend(os.getpid(), str(gguf), healthy = False)
seen = []
for _ in range(20):
out = inst.load_progress()
assert out is not None
assert isinstance(out["bytes_loaded"], int)
assert isinstance(out["bytes_total"], int)
assert 0.0 <= out["fraction"] <= 1.0
seen.append(out["bytes_loaded"])
time.sleep(0.01)
# RSS of a healthy Python process doesn't go below ~5 MB.
assert min(seen) > 1 * 1024 * 1024

View file

@ -0,0 +1,473 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Extended test matrix for ``LlamaCppBackend.load_progress()``.
Companion to ``test_llama_cpp_load_progress.py`` (which pins the basic
contract). This file widens coverage to the edge cases that bit users
or were hypothesized to bite them on cross-platform installs:
* Platform matrix macOS/Windows simulation via ``/proc`` absence.
* ``VmRSS`` parsing tab vs space delimiter, missing line, malformed
integer.
* Filesystem edges HF-cache symlinks, broken symlinks, nonexistent
paths, relative paths.
* Shard aggregation partial multi-shard downloads where some shards
are still ``.incomplete``, two shard series in the same dir,
``mmproj-*.gguf`` sibling exclusion for non-sharded primaries,
single-file models.
* Lifecycle races process set before ``_gguf_path`` is assigned,
process dead mid-sample, ``_healthy`` flipped to True.
* Concurrent sampling 10 threads × 50 iterations against a single
backend, hitting real ``/proc`` (no mocks see the note in
``TestConcurrentSampling`` for why).
* Fraction bounds capped at 1.0 when RSS exceeds total; 0.0 when
total is zero.
All tests are Linux-only in practice (we stub ``/proc`` where needed).
The stable subset runs in well under a second.
"""
from __future__ import annotations
import io
import os
import sys
import threading
import types as _types
from pathlib import Path
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_llama_cpp_load_progress.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make():
inst = LlamaCppBackend.__new__(LlamaCppBackend)
inst._process = None
inst._gguf_path = None
inst._healthy = False
return inst
class _Proc:
def __init__(self, pid):
self.pid = pid
def _sparse(path, size):
with open(path, "wb") as f:
if size > 0:
f.truncate(size)
def _fake_proc_reader(rss_kb):
"""Return an ``open()`` replacement that fakes /proc reads with a VmRSS line."""
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
return io.StringIO(f"VmRSS:\t{rss_kb}\tkB\n")
return open(path, *args, **kwargs)
return fake_open
# ---------------------------------------------------------------------------
# A. Platform matrix
# ---------------------------------------------------------------------------
class TestPlatformMatrix:
"""The method is Linux-first via /proc. On macOS/Windows it must
degrade to None rather than crash."""
def test_linux_live_proc_is_self_pid(self, tmp_path):
"""Self-pid /proc read uses the real kernel interface."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(gguf)
inst._healthy = False
out = inst.load_progress()
assert out is not None
assert out["phase"] == "mmap"
assert out["bytes_total"] == 1 * 1024**3
# Our Python process has some RSS -- just sanity-check positive.
assert out["bytes_loaded"] > 0
def test_macos_no_proc_returns_none(self, tmp_path):
"""Simulate macOS: /proc open fails with FileNotFoundError."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 1 * 1024**3)
inst = _make()
inst._process = _Proc(pid = 12345)
inst._gguf_path = str(gguf)
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
raise FileNotFoundError(f"No such file: {path}")
return open(path, *args, **kwargs)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is None
def test_windows_no_proc_returns_none(self, tmp_path):
"""Simulate Windows: opening /proc raises PermissionError or OSError."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 1 * 1024**3)
inst = _make()
inst._process = _Proc(pid = 4567)
inst._gguf_path = str(gguf)
def fake_open(path, *args, **kwargs):
if str(path).startswith("/proc/"):
raise PermissionError("access denied")
return open(path, *args, **kwargs)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is None
# ---------------------------------------------------------------------------
# B. VmRSS parsing edge cases
# ---------------------------------------------------------------------------
class TestVmRSSParsing:
def test_standard_tab_delimited(self, tmp_path):
gguf = tmp_path / "m.gguf"
_sparse(gguf, 4 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(gguf)
with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)):
out = inst.load_progress()
assert out["bytes_loaded"] == 2 * 1024**3
def test_space_separated_fallback(self, tmp_path):
"""Some kernels emit single-space rather than tab."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 4 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(gguf)
def fake_open(path, *a, **kw):
if str(path).startswith("/proc/"):
return io.StringIO("VmRSS: 4194304 kB\n")
return open(path, *a, **kw)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out["bytes_loaded"] == 4 * 1024**3
def test_missing_vmrss_line(self, tmp_path):
"""Kernel with VmRSS stripped (zombie / kthread) -> 0."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(gguf)
def fake_open(path, *a, **kw):
if str(path).startswith("/proc/"):
return io.StringIO("Name:\ttest\nState:\tZ (zombie)\n")
return open(path, *a, **kw)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
assert out is not None
assert out["bytes_loaded"] == 0
assert out["fraction"] == 0.0
def test_malformed_vmrss_value(self, tmp_path):
"""Non-integer VmRSS value should be treated as if the line were
absent (early ValueError caught)."""
gguf = tmp_path / "m.gguf"
_sparse(gguf, 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(gguf)
def fake_open(path, *a, **kw):
if str(path).startswith("/proc/"):
return io.StringIO("VmRSS:\tXXXX\tkB\n")
return open(path, *a, **kw)
with patch("builtins.open", side_effect = fake_open):
out = inst.load_progress()
# The implementation catches ValueError on int() and returns None.
assert out is None
# ---------------------------------------------------------------------------
# C. Filesystem edge cases
# ---------------------------------------------------------------------------
class TestFilesystemEdges:
def test_symlink_primary_follows_to_blob(self, tmp_path):
"""HF cache stores blobs under blobs/ and symlinks them from
snapshots/. The method must follow the symlink."""
blob = tmp_path / "blob"
_sparse(blob, 12 * 1024**3)
snap = tmp_path / "snap"
snap.mkdir()
link = snap / "m.gguf"
link.symlink_to(blob)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(link)
with patch("builtins.open", side_effect = _fake_proc_reader(6 * 1024**2)):
out = inst.load_progress()
assert out["bytes_total"] == 12 * 1024**3
def test_broken_symlink_skipped(self, tmp_path):
snap = tmp_path / "snap"
snap.mkdir()
link = snap / "m.gguf"
link.symlink_to(tmp_path / "missing-blob")
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(link)
with patch("builtins.open", side_effect = _fake_proc_reader(1024)):
out = inst.load_progress()
assert out["bytes_total"] == 0
assert out["bytes_loaded"] == 1024 * 1024
def test_nonexistent_path_skipped(self, tmp_path):
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "ghost.gguf")
with patch("builtins.open", side_effect = _fake_proc_reader(1024)):
out = inst.load_progress()
assert out["bytes_total"] == 0
def test_relative_gguf_path(self, tmp_path):
"""Relative paths shouldn't crash; behaviour depends on CWD but
the method must not raise."""
cwd = os.getcwd()
try:
os.chdir(tmp_path)
_sparse(Path("rel.gguf"), 8 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = "rel.gguf"
with patch("builtins.open", side_effect = _fake_proc_reader(0)):
out = inst.load_progress()
assert out is not None
assert out["bytes_total"] == 8 * 1024**3
finally:
os.chdir(cwd)
# ---------------------------------------------------------------------------
# D. Shard aggregation
# ---------------------------------------------------------------------------
class TestShardAggregation:
def test_partial_multi_shard_download(self, tmp_path):
"""Primary present but shards 2..N still downloading as
``.incomplete``. Sums only the fully-arrived ``.gguf`` files."""
_sparse(tmp_path / "m-00001-of-00004.gguf", 30 * 1024**3)
_sparse(tmp_path / "m-00002-of-00004.gguf", 30 * 1024**3)
# 3 and 4 still downloading as .incomplete
_sparse(tmp_path / "m-00003-of-00004.gguf.incomplete", 5 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m-00001-of-00004.gguf")
with patch("builtins.open", side_effect = _fake_proc_reader(0)):
out = inst.load_progress()
assert out["bytes_total"] == 60 * 1024**3 # only the .gguf siblings
def test_two_shard_series_in_same_dir(self, tmp_path):
"""Defensive: if two quant series share a dir, prefix filter
only sums siblings of the chosen primary."""
for i in range(1, 3):
_sparse(tmp_path / f"m_q4-{i:05d}-of-00002.gguf", 10 * 1024**3)
_sparse(tmp_path / f"m_q8-{i:05d}-of-00002.gguf", 20 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m_q8-00001-of-00002.gguf")
with patch("builtins.open", side_effect = _fake_proc_reader(0)):
out = inst.load_progress()
assert out["bytes_total"] == 40 * 1024**3 # just q8 series
def test_mmproj_sibling_not_counted(self, tmp_path):
"""Vision models drop an ``mmproj-*.gguf`` alongside. For a
single-file (non-sharded) primary we only count the primary."""
_sparse(tmp_path / "m.gguf", 8 * 1024**3)
_sparse(tmp_path / "mmproj-BF16.gguf", 2 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m.gguf")
with patch("builtins.open", side_effect = _fake_proc_reader(0)):
out = inst.load_progress()
# Non-sharded primary: only the primary is counted.
assert out["bytes_total"] == 8 * 1024**3
def test_single_file_model(self, tmp_path):
"""Non-sharded model: primary only."""
_sparse(tmp_path / "small.gguf", 4 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "small.gguf")
with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)):
out = inst.load_progress()
assert out["bytes_total"] == 4 * 1024**3
assert out["bytes_loaded"] == 2 * 1024**3
# ---------------------------------------------------------------------------
# E. Lifecycle races
# ---------------------------------------------------------------------------
class TestLifecycleRaces:
def test_process_set_but_gguf_path_not_yet(self, tmp_path):
"""Moment between Popen and self._gguf_path=model_path."""
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = None
with patch("builtins.open", side_effect = _fake_proc_reader(1024)):
out = inst.load_progress()
assert out is not None
assert out["phase"] == "mmap"
assert out["bytes_total"] == 0
assert out["bytes_loaded"] == 1024 * 1024
def test_process_died_mid_sample(self, tmp_path):
"""/proc/<pid> disappears -> None."""
_sparse(tmp_path / "m.gguf", 1 * 1024**3)
inst = _make()
inst._process = _Proc(pid = 999_999_999)
inst._gguf_path = str(tmp_path / "m.gguf")
assert inst.load_progress() is None
def test_healthy_true_ready_phase(self, tmp_path):
_sparse(tmp_path / "m.gguf", 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m.gguf")
inst._healthy = True
with patch("builtins.open", side_effect = _fake_proc_reader(1024)):
out = inst.load_progress()
assert out["phase"] == "ready"
# ---------------------------------------------------------------------------
# F. Concurrent sampling (simulates multiple browser tabs polling)
# ---------------------------------------------------------------------------
class TestConcurrentSampling:
def test_parallel_invocations_never_raise(self, tmp_path):
"""Many concurrent samplers hitting the same backend must not raise.
We intentionally do NOT patch ``builtins.open`` here because
``unittest.mock.patch`` is not thread-safe: interleaved
enter/exit across threads can leak a Mock into ``builtins.open``
and poison every subsequent test in the session. Instead, we
let each thread hit the real ``/proc/self/status`` of the test
process, which is exactly the code path that matters in prod.
"""
_sparse(tmp_path / "m.gguf", 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m.gguf")
errors = []
def run():
try:
for _ in range(50):
inst.load_progress()
except Exception as e: # pragma: no cover
errors.append(e)
threads = [threading.Thread(target = run) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors, errors
# ---------------------------------------------------------------------------
# G. Fraction bounds
# ---------------------------------------------------------------------------
class TestFractionBounds:
def test_fraction_capped_at_one(self, tmp_path):
_sparse(tmp_path / "m.gguf", 1 * 1024**3)
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = str(tmp_path / "m.gguf")
# RSS > total (post-paged-in + extra structures)
with patch("builtins.open", side_effect = _fake_proc_reader(2 * 1024**2)):
out = inst.load_progress()
assert 0.0 <= out["fraction"] <= 1.0
def test_fraction_zero_when_total_zero(self):
inst = _make()
inst._process = _Proc(os.getpid())
inst._gguf_path = None
with patch("builtins.open", side_effect = _fake_proc_reader(1024**2)):
out = inst.load_progress()
assert out["fraction"] == 0.0

View file

@ -0,0 +1,244 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for the ``max_context_length`` warning-threshold semantics.
``/api/inference/status.max_context_length`` is what the ctx slider in
the chat settings sheet reads to decide when to render the "Exceeds
estimated VRAM capacity. The model may use system RAM." warning:
ctxDisplayValue > ggufMaxContextLength show warning
For models whose weights fit on some GPU subset, the warning threshold
is the largest ctx that fits fully in VRAM (the binary-search cap from
``_fit_context_to_vram``). For models whose weights exceed 90% of every
GPU subset's free memory, the warning must fire as soon as the user
drags above the 4096 spec default (otherwise a user loading e.g.
MiniMax-M2.7 on a 97 GB GPU sees a slider up to 196608 with no
indication that any value above 4096 will trigger ``--fit on`` and
degrade performance).
These tests pin both cases. No GPU probing, no subprocess, no GGUF I/O.
Cross-platform: Linux, macOS, Windows, WSL.
"""
from __future__ import annotations
import sys
import types as _types
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# loggers
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
# httpx
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
GIB = 1024**3
def _make_backend(native_ctx = 131072):
inst = LlamaCppBackend.__new__(LlamaCppBackend)
inst._context_length = native_ctx
inst._n_layers = 80
inst._n_kv_heads = 8
inst._n_heads = 64
inst._embedding_length = 8192
inst._kv_key_length = 128
inst._kv_value_length = 128
inst._kv_lora_rank = None
inst._sliding_window = None
inst._ssm_inner_size = None
inst._full_attention_interval = None
inst._key_length_mla = None
return inst
def _compute_max_available_ctx(native_ctx, model_gib, gpus, kv_per_token_bytes = 325_000):
"""Run the ceiling-probe block from load_model and return the final
``max_available_ctx`` value the backend would assign to
``_max_context_length``.
"""
inst = _make_backend(native_ctx = native_ctx)
model_size = int(model_gib * GIB)
inst._estimate_kv_cache_bytes = (
lambda n, _t = None: 0 if n <= 0 else n * kv_per_token_bytes
)
inst._can_estimate_kv = lambda: True
context_length = inst._context_length
effective_ctx = context_length
max_available_ctx = context_length
cache_type_kv = None
native_ctx_for_cap = context_length
ranked_for_cap = sorted(gpus, key = lambda g: g[1], reverse = True)
best_cap = 0
for n_gpus in range(1, len(ranked_for_cap) + 1):
subset = ranked_for_cap[:n_gpus]
pool_mib = sum(free for _, free in subset)
capped = inst._fit_context_to_vram(
native_ctx_for_cap,
pool_mib,
model_size,
cache_type_kv,
)
kv = inst._estimate_kv_cache_bytes(capped, cache_type_kv)
total_mib = (model_size + kv) / (1024 * 1024)
if total_mib <= pool_mib * 0.90:
best_cap = max(best_cap, capped)
if best_cap > 0:
max_available_ctx = best_cap
else:
max_available_ctx = min(4096, native_ctx_for_cap)
return max_available_ctx
# ---------------------------------------------------------------------------
# Weights exceed every GPU subset's VRAM (MiniMax-M2.7-like)
# ---------------------------------------------------------------------------
class TestMaxContextLengthForWeightsExceedVRAM:
"""The UI ``max_context_length`` threshold must fall back to 4096 so
the warning fires as soon as the user drags above the spec default.
"""
def test_minimax_like(self):
"""131 GB weights, single 97 GB GPU, native ctx 196608."""
got = _compute_max_available_ctx(
native_ctx = 196608,
model_gib = 131,
gpus = [(0, 97_000)],
)
assert got == 4096
def test_multi_gpu_all_subsets_fail(self):
"""400 GB weights across a 4x80 GB pool (320 GB total, still too small)."""
got = _compute_max_available_ctx(
native_ctx = 131072,
model_gib = 400,
gpus = [(0, 80_000), (1, 80_000), (2, 80_000), (3, 80_000)],
)
assert got == 4096
def test_native_below_fallback_is_preserved(self):
"""If the model's native ctx is itself smaller than 4096, do not
advertise a larger value than the model supports."""
got = _compute_max_available_ctx(
native_ctx = 2048,
model_gib = 200,
gpus = [(0, 80_000)],
)
assert got == 2048
# ---------------------------------------------------------------------------
# Fittable models (regression guard)
# ---------------------------------------------------------------------------
class TestMaxContextLengthForFittableModels:
"""The existing best-cap behaviour must be unchanged."""
def test_small_model_fits_easily(self):
"""8 GB model on 24 GB GPU: should auto-pick a large ctx."""
got = _compute_max_available_ctx(
native_ctx = 131072,
model_gib = 8,
gpus = [(0, 24_000)],
kv_per_token_bytes = 8192,
)
assert got > 4096
assert got <= 131072
def test_medium_model_multi_gpu(self):
"""60 GB model split across 2 GPUs: picks a fitting ctx."""
got = _compute_max_available_ctx(
native_ctx = 131072,
model_gib = 60,
gpus = [(0, 40_000), (1, 40_000)],
kv_per_token_bytes = 8192,
)
assert got > 4096
def test_tiny_model_on_huge_gpu_near_native(self):
"""2 GB model, 80 GB GPU, negligible KV: should approach native."""
got = _compute_max_available_ctx(
native_ctx = 131072,
model_gib = 2,
gpus = [(0, 80_000)],
kv_per_token_bytes = 64,
)
assert got >= 131072 - 256 # rounded to 256 boundary
# ---------------------------------------------------------------------------
# Property plumbing
# ---------------------------------------------------------------------------
class TestMaxContextLengthProperty:
def test_falls_back_to_native_when_unset(self):
inst = _make_backend(native_ctx = 131072)
inst._max_context_length = None
assert inst.max_context_length == 131072
def test_returns_stored_value_when_set(self):
inst = _make_backend(native_ctx = 131072)
inst._max_context_length = 4096
assert inst.max_context_length == 4096

View file

@ -0,0 +1,137 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""``--no-context-shift`` launch-flag contract.
When llama-server runs with its default context-shift behavior, the UI
has no way to tell the user that the KV cache has been rotated --
earlier turns silently vanish from the conversation. The Studio
backend always passes ``--no-context-shift`` so the server returns a
clean error instead, and the chat adapter can point the user at the
``Context Length`` input in the settings panel.
This file is a static read of the launch command: we ask
``LlamaCppBackend`` to assemble its ``cmd`` list and assert the flag
is always present. Testing via the real subprocess would require an
actual GGUF on disk, which is out of scope for the fast test suite.
"""
from __future__ import annotations
import inspect
import sys
import types as _types
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# Same external-dep stubs as the other llama_cpp tests.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
_httpx_stub = _types.ModuleType("httpx")
for _exc in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc, type(_exc, (Exception,), {}))
_httpx_stub.Timeout = type("T", (), {"__init__": lambda s, *a, **k: None})
_httpx_stub.Client = type(
"C",
(),
{
"__init__": lambda s, **kw: None,
"__enter__": lambda s: s,
"__exit__": lambda s, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference import llama_cpp as llama_cpp_module
def _load_model_source() -> str:
"""Return the source of ``LlamaCppBackend.load_model``.
Using ``inspect.getsource`` instead of reading the file directly
scopes the assertions to the function that actually launches
llama-server, so neither the presence check nor the location check
can be fooled by a stray occurrence of ``"--no-context-shift"``
elsewhere in the module.
"""
return inspect.getsource(llama_cpp_module.LlamaCppBackend.load_model)
def test_no_context_shift_is_in_load_model():
"""The flag is part of the static launch-command template.
We check the source of ``load_model`` rather than mocking the whole
call chain (GPU probing, GGUF stat, etc.): the flag is written as
a literal in one place and any regression has to delete it, which
a text search will catch.
"""
assert '"--no-context-shift"' in _load_model_source(), (
"llama-server must be launched with --no-context-shift so the "
"UI can surface a clean 'context full' error instead of silently "
"losing old turns to a KV-cache rotation."
)
def test_flag_sits_inside_the_base_cmd_list():
"""Pin the flag's location so a future refactor can't accidentally
move it into a branch that only fires on some code paths.
We slice from ``cmd = [`` to the first ``]`` at the same indent.
Using ``inspect.getsource`` means the function lives in its own
string and there are no siblings to worry about, so a plain
bracket search would also work -- anchoring on the trailing indent
just keeps the slice from wandering into a later expression if the
opening literal ever grows an in-line comment trailing it.
"""
source = _load_model_source()
start = source.find("cmd = [")
assert start >= 0, "could not find the base cmd = [...] block"
# Find the first line containing only ``]`` (possibly indented).
# Works for any indentation style the formatter picks.
rest = source[start:]
end_rel = -1
for line_start, line in _iter_lines_with_offset(rest):
if line_start == 0:
# Skip the opening ``cmd = [`` line itself.
continue
if line.strip() == "]":
end_rel = line_start
break
assert end_rel > 0, "could not find end of cmd = [...] block"
block = rest[:end_rel]
assert '"--no-context-shift"' in block, (
"--no-context-shift must be in the base cmd list, not in a "
"conditional branch -- otherwise some code paths would still "
"run with silent context shift enabled."
)
# Also pin that it is next to -c / --ctx so the grouping makes sense.
assert '"-c"' in block
assert '"--flash-attn"' in block
def _iter_lines_with_offset(text: str):
"""Yield (offset, line) pairs over ``text`` without losing offsets."""
offset = 0
for line in text.splitlines(keepends = True):
yield offset, line
offset += len(line)

View file

@ -0,0 +1,81 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import asyncio
import sys
import types
# Keep this test runnable in lightweight environments where optional logging
# deps are not installed.
if "structlog" not in sys.modules:
class _DummyLogger:
def __getattr__(self, _name):
return lambda *args, **kwargs: None
sys.modules["structlog"] = types.SimpleNamespace(
BoundLogger = _DummyLogger,
get_logger = lambda *args, **kwargs: _DummyLogger(),
)
import routes.models as models_route
import utils.models.model_config as model_config_module
def test_get_model_config_resolves_cached_case_before_model_checks(monkeypatch):
calls: dict[str, str] = {}
class _DummyModelConfig:
is_lora = False
base_model = None
def _record_load(model_name):
calls["load_model_defaults"] = model_name
return {}
def _record_vision(model_name, hf_token = None):
calls["is_vision_model"] = model_name
return False
def _record_embedding(model_name, hf_token = None):
calls["is_embedding_model"] = model_name
return False
def _record_audio(model_name, hf_token = None):
calls["detect_audio_type"] = model_name
return None
def _record_from_identifier(cls, model_name):
calls["from_identifier"] = model_name
return _DummyModelConfig()
monkeypatch.setattr(models_route, "is_local_path", lambda _: False)
monkeypatch.setattr(
models_route, "resolve_cached_repo_id_case", lambda _: "Org/Model"
)
monkeypatch.setattr(models_route, "load_model_defaults", _record_load)
monkeypatch.setattr(models_route, "is_vision_model", _record_vision)
monkeypatch.setattr(models_route, "is_embedding_model", _record_embedding)
monkeypatch.setattr(model_config_module, "detect_audio_type", _record_audio)
monkeypatch.setattr(
models_route.ModelConfig,
"from_identifier",
classmethod(_record_from_identifier),
)
monkeypatch.setattr(models_route, "_get_max_position_embeddings", lambda _: 4096)
monkeypatch.setattr(models_route, "_get_model_size_bytes", lambda *_args, **_kw: 0)
result = asyncio.run(
models_route.get_model_config(
model_name = "org/model",
hf_token = None,
current_subject = "test-subject",
)
)
assert result.model_name == "Org/Model"
assert calls["load_model_defaults"] == "Org/Model"
assert calls["is_vision_model"] == "Org/Model"
assert calls["is_embedding_model"] == "Org/Model"
assert calls["detect_audio_type"] == "Org/Model"
assert calls["from_identifier"] == "Org/Model"

View file

@ -0,0 +1,518 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for the native_context_length feature (PR #4746).
Verifies that the new `native_context_length` property on LlamaCppBackend
and the corresponding Pydantic model fields work correctly. The raw GGUF
`_context_length` must never be overwritten by VRAM-capping logic.
Requires no GPU, network, or external libraries beyond pytest and pydantic.
"""
import io
import json
import struct
import sys
import types as _types
from pathlib import Path
from unittest.mock import patch
import pytest
# ---------------------------------------------------------------------------
# Stub heavy / unavailable external dependencies before importing the
# module under test. Same pattern as test_kv_cache_estimation.py.
# ---------------------------------------------------------------------------
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
sys.path.insert(0, _BACKEND_DIR)
# loggers
_loggers_stub = _types.ModuleType("loggers")
_loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name)
sys.modules.setdefault("loggers", _loggers_stub)
# structlog
_structlog_stub = _types.ModuleType("structlog")
sys.modules.setdefault("structlog", _structlog_stub)
# httpx -- stub only the names referenced at import / class-definition time
_httpx_stub = _types.ModuleType("httpx")
for _exc_name in (
"ConnectError",
"TimeoutException",
"ReadTimeout",
"ReadError",
"RemoteProtocolError",
"CloseError",
):
setattr(_httpx_stub, _exc_name, type(_exc_name, (Exception,), {}))
class _FakeTimeout:
def __init__(self, *a, **kw):
pass
_httpx_stub.Timeout = _FakeTimeout
_httpx_stub.Client = type(
"Client",
(),
{
"__init__": lambda self, **kw: None,
"__enter__": lambda self: self,
"__exit__": lambda self, *a: None,
},
)
sys.modules.setdefault("httpx", _httpx_stub)
from core.inference.llama_cpp import LlamaCppBackend
from models.inference import LoadResponse, InferenceStatusResponse
# ── Helpers ──────────────────────────────────────────────────────────
def _write_kv(buf: io.BytesIO, key: str, value, vtype: int) -> None:
"""Append a single GGUF KV pair to *buf*."""
key_bytes = key.encode("utf-8")
buf.write(struct.pack("<Q", len(key_bytes)))
buf.write(key_bytes)
buf.write(struct.pack("<I", vtype))
if vtype == 4: # UINT32
buf.write(struct.pack("<I", value))
elif vtype == 10: # UINT64
buf.write(struct.pack("<Q", value))
elif vtype == 8: # STRING
val_bytes = value.encode("utf-8")
buf.write(struct.pack("<Q", len(val_bytes)))
buf.write(val_bytes)
else:
raise ValueError(f"Unsupported vtype in test helper: {vtype}")
def make_gguf(
tmp_path: Path,
arch: str,
kvs: list,
*,
arch_first: bool = True,
filename: str = "test.gguf",
) -> str:
"""Create a minimal valid GGUF v3 binary in *tmp_path*."""
buf = io.BytesIO()
buf.write(struct.pack("<I", 0x46554747)) # GGUF magic
buf.write(struct.pack("<I", 3)) # version 3
buf.write(struct.pack("<Q", 0)) # tensor count = 0
ordered = []
arch_entry = ("general.architecture", arch, 8)
if arch_first:
ordered.append(arch_entry)
for suffix, val, vt in kvs:
ordered.append((f"{arch}.{suffix}", val, vt))
if not arch_first:
ordered.append(arch_entry)
buf.write(struct.pack("<Q", len(ordered)))
for key, val, vt in ordered:
_write_kv(buf, key, val, vt)
path = tmp_path / filename
path.write_bytes(buf.getvalue())
return str(path)
@pytest.fixture
def backend():
"""Create a fresh LlamaCppBackend with side effects disabled."""
with patch.object(LlamaCppBackend, "_kill_orphaned_servers"):
with patch("atexit.register"):
return LlamaCppBackend()
# =====================================================================
# A. TestNativeContextLengthProperty -- the new property
# =====================================================================
class TestNativeContextLengthProperty:
"""Tests the new `native_context_length` property on LlamaCppBackend."""
def test_none_on_fresh_backend(self, backend):
"""Returns None when no model loaded."""
assert backend.native_context_length is None
def test_returns_raw_gguf_value(self, backend):
"""Directly returns _context_length when set."""
backend._context_length = 131072
assert backend.native_context_length == 131072
def test_not_capped_by_effective(self, backend):
"""native_context_length ignores _effective_context_length."""
backend._context_length = 131072
backend._effective_context_length = 32768
assert backend.native_context_length == 131072
def test_not_capped_by_max(self, backend):
"""native_context_length ignores _max_context_length."""
backend._context_length = 131072
backend._max_context_length = 65536
assert backend.native_context_length == 131072
def test_none_after_unload(self, backend):
"""After unload_model(), returns None."""
backend._context_length = 131072
assert backend.native_context_length == 131072
backend.unload_model()
assert backend.native_context_length is None
def test_after_gguf_parse(self, tmp_path, backend):
"""Synthetic GGUF with context_length=16384 populates the property."""
path = make_gguf(
tmp_path,
"llama",
[("context_length", 16384, 4)],
)
backend._read_gguf_metadata(path)
assert backend.native_context_length == 16384
def test_resets_between_parses(self, tmp_path, backend):
"""Second GGUF without context_length resets native to None."""
path_a = make_gguf(
tmp_path,
"llama",
[("context_length", 16384, 4)],
filename = "a.gguf",
)
backend._read_gguf_metadata(path_a)
assert backend.native_context_length == 16384
path_b = make_gguf(
tmp_path,
"gpt2",
[("block_count", 12, 4)],
filename = "b.gguf",
)
backend._read_gguf_metadata(path_b)
assert backend.native_context_length is None
# =====================================================================
# B. TestContextValueSeparation -- core invariant
# =====================================================================
class TestContextValueSeparation:
"""_context_length is never overwritten by VRAM logic."""
def test_preserved_after_effective_set(self, backend):
"""Setting _effective_context_length does not change _context_length."""
backend._context_length = 131072
backend._effective_context_length = 32768
assert backend._context_length == 131072
assert backend.native_context_length == 131072
def test_ordering_when_capped(self, backend):
"""native >= max >= effective holds when VRAM-capped."""
backend._context_length = 131072
backend._max_context_length = 65536
backend._effective_context_length = 32768
assert backend.native_context_length >= backend.max_context_length
assert backend.max_context_length >= backend.context_length
def test_all_equal_when_uncapped(self, backend):
"""All three equal when no VRAM constraint."""
backend._context_length = 8192
# No effective or max set -- properties fall back to _context_length
assert backend.native_context_length == 8192
assert backend.max_context_length == 8192
assert backend.context_length == 8192
def test_fit_context_does_not_modify(self, backend):
"""_fit_context_to_vram() does not touch _context_length."""
backend._context_length = 131072
backend._n_layers = 32
backend._n_kv_heads = 8
backend._n_heads = 32
backend._embedding_length = 4096
original = backend._context_length
# Simulate a very small VRAM budget that forces capping
result = backend._fit_context_to_vram(
requested_ctx = 131072,
available_mib = 512, # very small
model_size_bytes = 0,
)
# _fit_context_to_vram returns the capped value, not modifying _context_length
assert backend._context_length == original
assert backend.native_context_length == original
# The returned capped value should be <= requested
assert result <= 131072
def test_native_gt_context_when_capped(self, backend):
"""native_context_length > context_length after VRAM capping."""
backend._context_length = 131072
backend._effective_context_length = 16384
assert backend.native_context_length > backend.context_length
# =====================================================================
# C. TestPydanticModels -- LoadResponse & InferenceStatusResponse
# =====================================================================
class TestPydanticModels:
"""Tests native_context_length field on Pydantic models."""
def test_load_response_has_field(self):
"""Field exists in LoadResponse.model_fields."""
assert "native_context_length" in LoadResponse.model_fields
def test_load_response_defaults_none(self):
"""Omitting native_context_length defaults to None."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
)
assert resp.native_context_length is None
def test_load_response_accepts_int(self):
"""native_context_length=131072 stores correctly."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
native_context_length = 131072,
)
assert resp.native_context_length == 131072
def test_load_response_json_null(self):
"""None serializes to JSON null."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
)
data = json.loads(resp.model_dump_json())
assert data["native_context_length"] is None
def test_load_response_json_int(self):
"""131072 serializes to JSON number."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
native_context_length = 131072,
)
data = json.loads(resp.model_dump_json())
assert data["native_context_length"] == 131072
def test_status_response_has_field(self):
"""Field exists in InferenceStatusResponse.model_fields."""
assert "native_context_length" in InferenceStatusResponse.model_fields
def test_status_response_defaults_none(self):
"""Omitting native_context_length defaults to None."""
resp = InferenceStatusResponse()
assert resp.native_context_length is None
def test_roundtrip_preserves_value(self):
"""model_validate_json(model_dump_json()) round-trips."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
native_context_length = 131072,
)
roundtripped = LoadResponse.model_validate_json(resp.model_dump_json())
assert roundtripped.native_context_length == 131072
# =====================================================================
# D. TestRouteCompleteness -- source-level verification
# =====================================================================
class TestRouteCompleteness:
"""All response construction sites in routes/inference.py include native_context_length."""
@pytest.fixture(autouse = True)
def _load_source(self):
"""Read routes/inference.py source once."""
routes_path = Path(__file__).resolve().parent.parent / "routes" / "inference.py"
self._source = routes_path.read_text()
def _find_construction_blocks(self, class_name: str) -> list[str]:
"""Extract all code blocks that construct a given response class."""
blocks = []
idx = 0
while True:
start = self._source.find(f"{class_name}(", idx)
if start == -1:
break
# Find matching closing paren (simple depth counter)
depth = 0
end = start
for i, ch in enumerate(self._source[start:], start):
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0:
end = i + 1
break
blocks.append(self._source[start:end])
idx = end
return blocks
def test_gguf_load_responses_have_field(self):
"""Every GGUF LoadResponse (is_gguf = True) includes native_context_length."""
blocks = self._find_construction_blocks("LoadResponse")
gguf_blocks = [
b for b in blocks if "is_gguf = True" in b or "is_gguf=True" in b
]
assert (
len(gguf_blocks) >= 2
), f"Expected at least 2 GGUF LoadResponse blocks, found {len(gguf_blocks)}"
for i, block in enumerate(gguf_blocks):
assert (
"native_context_length" in block
), f"GGUF LoadResponse block #{i} missing native_context_length:\n{block[:200]}"
def test_non_gguf_load_responses_omit_field(self):
"""Non-GGUF LoadResponse blocks do not set native_context_length (defaults to None)."""
blocks = self._find_construction_blocks("LoadResponse")
non_gguf = [
b for b in blocks if "is_gguf = True" not in b and "is_gguf=True" not in b
]
# Non-GGUF paths should not reference native_context_length
# (Pydantic defaults it to None, so not setting it is correct)
for block in non_gguf:
assert (
"native_context_length" not in block
), f"Non-GGUF LoadResponse should not set native_context_length:\n{block[:200]}"
def test_status_path(self):
"""InferenceStatusResponse construction with llama_backend has the field."""
blocks = self._find_construction_blocks("InferenceStatusResponse")
found = False
for block in blocks:
if "llama_backend" in block and "native_context_length" in block:
found = True
break
assert found, "No InferenceStatusResponse block with llama_backend has native_context_length"
# =====================================================================
# E. TestEdgeCases
# =====================================================================
class TestNativeContextEdgeCases:
"""Edge cases for native_context_length."""
def test_context_length_zero(self, tmp_path, backend):
"""GGUF context_length=0 returns 0, not None."""
path = make_gguf(tmp_path, "llama", [("context_length", 0, 4)])
backend._read_gguf_metadata(path)
assert backend.native_context_length == 0
def test_context_length_uint32_max(self, tmp_path, backend):
"""2^32 - 1 survives without truncation."""
val = 2**32 - 1
path = make_gguf(tmp_path, "llama", [("context_length", val, 4)])
backend._read_gguf_metadata(path)
assert backend.native_context_length == val
def test_context_length_uint64(self, tmp_path, backend):
"""UINT64 type context_length parsed correctly."""
val = 2**33 # exceeds UINT32 range
path = make_gguf(tmp_path, "llama", [("context_length", val, 10)])
backend._read_gguf_metadata(path)
assert backend.native_context_length == val
def test_no_context_length_in_gguf(self, tmp_path, backend):
"""GGUF without context_length key yields None."""
path = make_gguf(tmp_path, "llama", [("block_count", 32, 4)])
backend._read_gguf_metadata(path)
assert backend.native_context_length is None
def test_native_equals_context_when_uncapped(self, backend):
"""Both equal when no VRAM cap applied."""
backend._context_length = 8192
assert backend.native_context_length == backend.context_length
def test_native_survives_parse_then_cap(self, tmp_path, backend):
"""Parse then set effective cap: native unchanged."""
path = make_gguf(
tmp_path,
"llama",
[
("context_length", 131072, 4),
("block_count", 32, 4),
("attention.head_count", 32, 4),
("attention.head_count_kv", 8, 4),
("embedding_length", 4096, 4),
],
)
backend._read_gguf_metadata(path)
assert backend.native_context_length == 131072
# Simulate VRAM capping by setting effective and max
backend._effective_context_length = 16384
backend._max_context_length = 32768
assert backend.native_context_length == 131072
# =====================================================================
# F. TestCrossPlatform -- binary I/O and serialization
# =====================================================================
class TestCrossPlatform:
"""Binary I/O and serialization correctness across platforms."""
def test_le_uint32_context_length(self, tmp_path, backend):
"""Little-endian UINT32 parsed correctly."""
path = make_gguf(tmp_path, "llama", [("context_length", 16384, 4)])
backend._read_gguf_metadata(path)
assert backend.native_context_length == 16384
def test_le_uint64_context_length(self, tmp_path, backend):
"""Little-endian UINT64 parsed correctly."""
path = make_gguf(tmp_path, "llama", [("context_length", 16384, 10)])
backend._read_gguf_metadata(path)
assert backend.native_context_length == 16384
def test_gguf_magic_le_byte_order(self, tmp_path):
"""Magic 0x46554747 matches GGUF spec (little-endian 'GGUF')."""
path = tmp_path / "magic_check.gguf"
buf = io.BytesIO()
buf.write(struct.pack("<I", 0x46554747))
raw = buf.getvalue()
# 'G' = 0x47, 'G' = 0x47, 'U' = 0x55, 'F' = 0x46
assert raw == b"GGUF"
def test_json_serialization_deterministic(self):
"""model_dump_json() is consistent across calls."""
resp = LoadResponse(
status = "loaded",
model = "test",
display_name = "Test",
inference = {},
native_context_length = 131072,
)
json1 = resp.model_dump_json()
json2 = resp.model_dump_json()
assert json1 == json2
assert '"native_context_length":131072' in json1

View file

@ -0,0 +1,465 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""
Tests for the OpenAI /v1/chat/completions client-side tool pass-through.
Covers:
- ChatCompletionRequest accepts standard OpenAI `tools` / `tool_choice` / `stop`.
- ChatMessage accepts role="tool" with `tool_call_id` and role="assistant"
with `content: None` + `tool_calls`.
- ChatCompletionRequest carries unknown fields via `extra="allow"`.
- anthropic_tool_choice_to_openai() covers all four Anthropic shapes.
- _build_passthrough_payload() honors a caller-supplied tool_choice and
defaults to "auto" when unset.
- _friendly_error() maps httpx transport errors to a "Lost connection"
message so passthrough failures are legible instead of bare 500s.
No running server or GPU required.
"""
import os
import sys
_backend = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, _backend)
import httpx
import pytest
from pydantic import ValidationError
from models.inference import (
ChatCompletionRequest,
ChatMessage,
)
from core.inference.anthropic_compat import (
anthropic_tool_choice_to_openai,
)
from routes.inference import _build_passthrough_payload, _friendly_error
# =====================================================================
# ChatMessage — tool role, tool_calls, optional content
# =====================================================================
class TestChatMessageToolRoles:
def test_tool_role_with_tool_call_id(self):
msg = ChatMessage(
role = "tool",
tool_call_id = "call_abc123",
content = '{"temperature": 72}',
)
assert msg.role == "tool"
assert msg.tool_call_id == "call_abc123"
assert msg.content == '{"temperature": 72}'
def test_tool_role_with_name(self):
msg = ChatMessage(
role = "tool",
tool_call_id = "call_abc123",
name = "get_weather",
content = '{"temperature": 72}',
)
assert msg.name == "get_weather"
def test_assistant_with_tool_calls_no_content(self):
msg = ChatMessage(
role = "assistant",
content = None,
tool_calls = [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
)
assert msg.role == "assistant"
assert msg.content is None
assert msg.tool_calls is not None
assert len(msg.tool_calls) == 1
assert msg.tool_calls[0]["function"]["name"] == "get_weather"
def test_assistant_with_content_and_tool_calls(self):
msg = ChatMessage(
role = "assistant",
content = "Let me check the weather.",
tool_calls = [
{
"id": "call_1",
"type": "function",
"function": {"name": "get_weather", "arguments": "{}"},
}
],
)
assert msg.content == "Let me check the weather."
assert msg.tool_calls[0]["id"] == "call_1"
def test_plain_user_message_still_works(self):
msg = ChatMessage(role = "user", content = "Hello")
assert msg.role == "user"
assert msg.tool_call_id is None
assert msg.tool_calls is None
assert msg.name is None
def test_invalid_role_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(role = "function", content = "x")
def test_content_absent_on_assistant_tool_call_defaults_to_none(self):
# Assistant messages that carry only tool_calls are the one
# documented case where `content=None` is permitted.
msg = ChatMessage(
role = "assistant",
tool_calls = [
{
"id": "call_1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
)
assert msg.content is None
def test_tool_role_missing_tool_call_id_rejected(self):
# Per OpenAI spec, role="tool" messages must carry tool_call_id so
# upstream backends can associate the result with its prior call.
# Pin the boundary-level rejection so a malformed tool-result
# message never reaches the passthrough path.
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "tool", content = '{"temperature": 72}')
assert "tool_call_id" in str(exc_info.value)
def test_tool_role_empty_tool_call_id_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(
role = "tool",
tool_call_id = "",
content = '{"temperature": 72}',
)
# ── Role-aware content requirements ────────────────────────────
def test_user_empty_content_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(role = "user", content = "")
def test_system_empty_content_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(role = "system", content = "")
def test_user_empty_list_content_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(role = "user", content = [])
def test_tool_empty_content_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "tool", tool_call_id = "call_1", content = "")
assert "content" in str(exc_info.value)
def test_assistant_without_content_or_tool_calls_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "assistant")
assert "content" in str(exc_info.value) or "tool_calls" in str(exc_info.value)
# ── Role-constrained tool-call metadata ────────────────────────
def test_tool_calls_on_user_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(
role = "user",
content = "Hi",
tool_calls = [
{
"id": "c1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
)
assert "tool_calls" in str(exc_info.value)
def test_tool_call_id_on_user_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "user", content = "Hi", tool_call_id = "call_1")
assert "tool_call_id" in str(exc_info.value)
def test_name_on_user_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "user", content = "Hi", name = "get_weather")
assert "name" in str(exc_info.value)
# =====================================================================
# ChatCompletionRequest — standard OpenAI tool fields
# =====================================================================
class TestChatCompletionRequestToolFields:
def _make(self, **kwargs):
base = {"messages": [{"role": "user", "content": "Hi"}]}
base.update(kwargs)
return ChatCompletionRequest(**base)
def test_tools_parses(self):
req = self._make(
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Return the weather in a city",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
}
],
)
assert req.tools is not None
assert len(req.tools) == 1
assert req.tools[0]["function"]["name"] == "get_weather"
def test_tool_choice_string_auto(self):
assert self._make(tool_choice = "auto").tool_choice == "auto"
def test_tool_choice_string_required(self):
assert self._make(tool_choice = "required").tool_choice == "required"
def test_tool_choice_string_none(self):
assert self._make(tool_choice = "none").tool_choice == "none"
def test_tool_choice_named_function(self):
tc = {"type": "function", "function": {"name": "get_weather"}}
assert self._make(tool_choice = tc).tool_choice == tc
def test_stop_string(self):
assert self._make(stop = "\nUser:").stop == "\nUser:"
def test_stop_list(self):
assert self._make(stop = ["\nUser:", "\nAssistant:"]).stop == [
"\nUser:",
"\nAssistant:",
]
def test_tools_default_none(self):
req = self._make()
assert req.tools is None
assert req.tool_choice is None
assert req.stop is None
def test_extra_fields_accepted(self):
# `frequency_penalty`, `seed`, `response_format` are not yet
# explicitly declared but must survive Pydantic parsing now that
# extra="allow" is set.
req = self._make(
frequency_penalty = 0.5,
seed = 42,
response_format = {"type": "json_object"},
)
# Extras land in model_extra
assert req.model_extra is not None
assert req.model_extra.get("frequency_penalty") == 0.5
assert req.model_extra.get("seed") == 42
assert req.model_extra.get("response_format") == {"type": "json_object"}
def test_unsloth_extensions_still_work(self):
req = self._make(
enable_tools = True,
enabled_tools = ["web_search", "python"],
session_id = "abc",
)
assert req.enable_tools is True
assert req.enabled_tools == ["web_search", "python"]
assert req.session_id == "abc"
def test_stream_defaults_false_matching_openai_spec(self):
# OpenAI's /v1/chat/completions spec defaults `stream` to false.
# Studio previously defaulted to true, which broke naive curl
# clients that omit `stream` (they expect a JSON blob, got SSE).
# Pin the corrected default so it can't silently regress.
req = self._make()
assert req.stream is False
def test_multiturn_tool_loop_messages(self):
req = ChatCompletionRequest(
messages = [
{"role": "user", "content": "What's the weather in Paris?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": '{"temperature": 14, "unit": "celsius"}',
},
],
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {"type": "object"},
},
}
],
)
assert len(req.messages) == 3
assert req.messages[1].role == "assistant"
assert req.messages[1].content is None
assert req.messages[1].tool_calls[0]["id"] == "call_1"
assert req.messages[2].role == "tool"
assert req.messages[2].tool_call_id == "call_1"
# =====================================================================
# anthropic_tool_choice_to_openai — pure translation helper
# =====================================================================
class TestAnthropicToolChoiceToOpenAI:
def test_auto(self):
assert anthropic_tool_choice_to_openai({"type": "auto"}) == "auto"
def test_any_becomes_required(self):
assert anthropic_tool_choice_to_openai({"type": "any"}) == "required"
def test_none(self):
assert anthropic_tool_choice_to_openai({"type": "none"}) == "none"
def test_tool_named(self):
result = anthropic_tool_choice_to_openai(
{"type": "tool", "name": "get_weather"}
)
assert result == {
"type": "function",
"function": {"name": "get_weather"},
}
def test_tool_missing_name_returns_none(self):
assert anthropic_tool_choice_to_openai({"type": "tool"}) is None
def test_none_input_returns_none(self):
assert anthropic_tool_choice_to_openai(None) is None
def test_unrecognized_shape_returns_none(self):
assert anthropic_tool_choice_to_openai({"type": "wibble"}) is None
assert anthropic_tool_choice_to_openai("auto") is None
assert anthropic_tool_choice_to_openai(42) is None
# =====================================================================
# _build_passthrough_payload — tool_choice propagation
# =====================================================================
class TestBuildPassthroughPayloadToolChoice:
def _args(self):
return dict(
openai_messages = [{"role": "user", "content": "Hi"}],
openai_tools = [
{
"type": "function",
"function": {"name": "f", "parameters": {"type": "object"}},
}
],
temperature = 0.6,
top_p = 0.95,
top_k = 20,
max_tokens = 128,
stream = False,
)
def test_default_tool_choice_is_auto(self):
body = _build_passthrough_payload(**self._args())
assert body["tool_choice"] == "auto"
def test_override_tool_choice_required(self):
body = _build_passthrough_payload(**self._args(), tool_choice = "required")
assert body["tool_choice"] == "required"
def test_override_tool_choice_none(self):
body = _build_passthrough_payload(**self._args(), tool_choice = "none")
assert body["tool_choice"] == "none"
def test_override_tool_choice_named_function(self):
tc = {"type": "function", "function": {"name": "f"}}
body = _build_passthrough_payload(**self._args(), tool_choice = tc)
assert body["tool_choice"] == tc
def test_stream_adds_include_usage(self):
args = self._args()
args["stream"] = True
body = _build_passthrough_payload(**args)
assert body.get("stream_options") == {"include_usage": True}
def test_repetition_penalty_renamed(self):
body = _build_passthrough_payload(**self._args(), repetition_penalty = 1.1)
assert body.get("repeat_penalty") == 1.1
assert "repetition_penalty" not in body
# =====================================================================
# _friendly_error — httpx transport failures
# =====================================================================
class TestFriendlyErrorHttpx:
"""The async pass-through helpers talk to llama-server via httpx.
When the subprocess is down, httpx raises RequestError subclasses
whose string form (``"All connection attempts failed"``, ``"[Errno 111]
Connection refused"``, ...) does NOT contain the substring
``"Lost connection to llama-server"`` the sync path uses, so the
previous substring-only `_friendly_error` returned a useless generic
message. These tests pin the new isinstance-based mapping.
"""
def _req(self):
return httpx.Request("POST", "http://127.0.0.1:65535/v1/chat/completions")
def test_connect_error_mapped(self):
exc = httpx.ConnectError("All connection attempts failed", request = self._req())
assert "Lost connection" in _friendly_error(exc)
def test_read_error_mapped(self):
exc = httpx.ReadError("EOF", request = self._req())
assert "Lost connection" in _friendly_error(exc)
def test_remote_protocol_error_mapped(self):
exc = httpx.RemoteProtocolError("peer closed", request = self._req())
assert "Lost connection" in _friendly_error(exc)
def test_read_timeout_mapped(self):
exc = httpx.ReadTimeout("timed out", request = self._req())
assert "Lost connection" in _friendly_error(exc)
def test_non_httpx_unchanged(self):
# Non-httpx exceptions still fall through to the existing substring
# heuristics — a context-size message must still produce the
# "Message too long" path.
ctx_msg = (
"request (4096 tokens) exceeds the available context size (2048 tokens)"
)
assert "Message too long" in _friendly_error(ValueError(ctx_msg))
def test_generic_exception_returns_generic_message(self):
assert (
_friendly_error(RuntimeError("unrelated")) == "An internal error occurred"
)

Some files were not shown because too many files have changed in this diff Show more