mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
studio: trim comments on stop-button review changes
Collapse multi-paragraph rationale blocks on the cancel registry, _openai_passthrough_stream, and the frontend onAbortCancel handler into one-line explanations of why the non-obvious behaviour exists. Drop authFetch import that became unused when the cancel POST switched to plain fetch.
This commit is contained in:
parent
9f60dfedd9
commit
e770e76e9f
3 changed files with 39 additions and 72 deletions
|
|
@ -48,15 +48,11 @@ _INTENT_SIGNAL = re.compile(
|
|||
)
|
||||
_MAX_REPROMPTS = 3
|
||||
|
||||
# Default generation caps sent to llama-server. Prevents the runaway where
|
||||
# a user-unset max_tokens makes llama-server default to n_predict = n_ctx
|
||||
# (up to 262144 tokens for Qwen3.5), producing many-minute "zombie" decodes
|
||||
# that ignore stop-button requests. Override max_tokens per-call with an
|
||||
# explicit kwarg; t_max_predict_ms is applied unconditionally as a
|
||||
# wall-clock backstop so callers that set a large max_tokens are still
|
||||
# bounded if cancel signaling fails.
|
||||
# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to
|
||||
# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel
|
||||
# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally.
|
||||
_DEFAULT_MAX_TOKENS = 4096
|
||||
_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 minutes wall-clock per request
|
||||
_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 min
|
||||
_REPROMPT_MAX_CHARS = 2000
|
||||
|
||||
# ── Pre-compiled patterns for GGUF shard detection ───────────
|
||||
|
|
|
|||
|
|
@ -139,18 +139,15 @@ router = APIRouter()
|
|||
studio_router = APIRouter()
|
||||
|
||||
|
||||
# ── Cancel registry ──────────────────────────────────────────
|
||||
# Tracks every in-flight cancel_event so POST /inference/cancel
|
||||
# can reach them. Keyed by cancel_id (preferred, exclusive per-run),
|
||||
# session_id (thread fallback), and completion_id (internal fallback).
|
||||
# Routed separately from request.is_disconnected() polling because
|
||||
# some proxies (e.g. Colab's) do not propagate client-side fetch
|
||||
# aborts, so the backend never observes disconnect.
|
||||
# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts
|
||||
# so is_disconnected() never fires. POST /inference/cancel looks up
|
||||
# in-flight cancel_events here by cancel_id (per-run) or session_id /
|
||||
# completion_id (fallbacks).
|
||||
_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
|
||||
_CANCEL_LOCK = threading.Lock()
|
||||
|
||||
# Stash for cancel POSTs that arrive before their stream registers;
|
||||
# the next matching __enter__ replays set() within the TTL.
|
||||
# Cancel POSTs that arrive before registration are stashed; the next
|
||||
# matching __enter__ replays set() within the TTL.
|
||||
_PENDING_CANCELS: dict[str, float] = {}
|
||||
_PENDING_CANCEL_TTL_S = 30.0
|
||||
|
||||
|
|
@ -163,16 +160,15 @@ def _prune_pending(now: float) -> None:
|
|||
|
||||
|
||||
class _TrackedCancel:
|
||||
"""Context manager: register cancel_event in _CANCEL_REGISTRY for the
|
||||
duration of the block so external POST /inference/cancel can reach it."""
|
||||
"""Register cancel_event in _CANCEL_REGISTRY for the block's duration."""
|
||||
|
||||
def __init__(self, event: threading.Event, *keys):
|
||||
self.event = event
|
||||
self.keys = tuple(k for k in keys if k)
|
||||
|
||||
def __enter__(self):
|
||||
# Register + consume-pending must be one critical section to
|
||||
# close the TOCTOU race against a concurrent cancel POST.
|
||||
# Register + consume-pending must be one critical section to close
|
||||
# the TOCTOU race against a concurrent cancel POST.
|
||||
should_cancel = False
|
||||
with _CANCEL_LOCK:
|
||||
for k in self.keys:
|
||||
|
|
@ -199,11 +195,10 @@ class _TrackedCancel:
|
|||
|
||||
|
||||
def _cancel_by_keys(keys) -> int:
|
||||
"""Set cancel_event for matching registry entries. Returns count set.
|
||||
Does NOT stash unmatched keys: session_id and completion_id are shared
|
||||
across runs on the same thread, so stashing them would cancel the
|
||||
user's next unrelated request. cancel_id is the only per-run unique
|
||||
key and is handled by _cancel_by_cancel_id_or_stash."""
|
||||
"""Set cancel_event for matching registry entries; no stash.
|
||||
session_id/completion_id are shared across runs on the same thread,
|
||||
so stashing them would ghost-cancel the user's next request. Only
|
||||
cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash)."""
|
||||
if not keys:
|
||||
return 0
|
||||
events: set[threading.Event] = set()
|
||||
|
|
@ -717,20 +712,15 @@ async def cancel_inference(
|
|||
request: Request,
|
||||
current_subject: str = Depends(get_current_subject),
|
||||
):
|
||||
"""
|
||||
Cancel in-flight inference requests.
|
||||
"""Cancel in-flight inference requests.
|
||||
|
||||
Body (JSON, at least one key required):
|
||||
cancel_id - preferred: per-run UUID, matched exclusively so a
|
||||
stale POST cannot hit a later run on the same thread.
|
||||
cancel_id - preferred: per-run UUID, matched exclusively.
|
||||
session_id - fallback when cancel_id is absent.
|
||||
completion_id - fallback when cancel_id is absent.
|
||||
|
||||
Returns {"cancelled": N}. A cancel_id arriving before its stream has
|
||||
registered is stashed for a short TTL and replayed on registration.
|
||||
|
||||
Exists because some proxies (Colab, etc.) do not propagate client
|
||||
aborts, so request.is_disconnected() is not sufficient.
|
||||
A cancel_id arriving before its stream registers is stashed briefly
|
||||
and replayed on registration. Returns {"cancelled": N}.
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
|
|
@ -3357,17 +3347,14 @@ async def _openai_passthrough_stream(
|
|||
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
|
||||
_tracker.__enter__()
|
||||
|
||||
# Outer guard: `await client.send(...)` below is an await point,
|
||||
# so asyncio.CancelledError (BaseException, not Exception) can strike
|
||||
# there and bypass `except httpx.RequestError`, leaking the registry
|
||||
# entry. The generator's own `finally` only runs once iteration starts;
|
||||
# this outer except covers the gap before StreamingResponse returns.
|
||||
# Outer guard: asyncio.CancelledError at `await client.send(...)` is
|
||||
# a BaseException that bypasses `except httpx.RequestError`; without
|
||||
# this the tracker leaks. The generator's finally only runs once
|
||||
# iteration starts.
|
||||
try:
|
||||
# Dispatch the upstream request BEFORE returning StreamingResponse so
|
||||
# transport errors and non-200 upstream statuses surface as real HTTP
|
||||
# errors to the client. OpenAI SDKs rely on status codes to raise
|
||||
# ``APIError``/``BadRequestError``/...; burying the failure inside a
|
||||
# 200 SSE ``error`` frame silently breaks their error handling.
|
||||
# Dispatch BEFORE returning StreamingResponse so transport errors
|
||||
# and non-200 upstream statuses surface as real HTTP errors --
|
||||
# OpenAI SDKs rely on status codes to raise APIError/BadRequestError.
|
||||
client = httpx.AsyncClient(
|
||||
timeout = 600,
|
||||
limits = httpx.Limits(max_keepalive_connections = 0),
|
||||
|
|
@ -3417,15 +3404,8 @@ async def _openai_passthrough_stream(
|
|||
|
||||
async def _stream():
|
||||
# Same httpx lifecycle pattern as _anthropic_passthrough_stream:
|
||||
# avoid `async with` on the client/response AND explicitly save
|
||||
# resp.aiter_lines() so we can close it ourselves in the finally
|
||||
# block. See the long comment there for the full rationale on
|
||||
# why the anonymous `async for raw_line in resp.aiter_lines():`
|
||||
# pattern leaks an unclosed async generator that Python's
|
||||
# asyncgen GC hook then finalizes in a different asyncio task,
|
||||
# producing "Exception ignored in:" / "async generator ignored
|
||||
# GeneratorExit" / anyio cancel-scope traces on Python 3.13 +
|
||||
# httpcore 1.0.x.
|
||||
# save resp.aiter_lines() so the finally block can aclose() it
|
||||
# on our task. See that function for full rationale.
|
||||
lines_iter = None
|
||||
try:
|
||||
lines_iter = resp.aiter_lines()
|
||||
|
|
@ -3439,16 +3419,13 @@ async def _openai_passthrough_stream(
|
|||
continue
|
||||
if not raw_line.startswith("data: "):
|
||||
continue
|
||||
# Relay the llama-server SSE chunk verbatim so the client
|
||||
# sees its native `id`, `finish_reason`, `delta.tool_calls`,
|
||||
# and final `usage` unchanged.
|
||||
# Relay verbatim to preserve llama-server's native id,
|
||||
# finish_reason, delta.tool_calls, and usage chunks.
|
||||
yield raw_line + "\n\n"
|
||||
if raw_line[6:].strip() == "[DONE]":
|
||||
break
|
||||
except Exception as e:
|
||||
# Mid-stream failures still have to be reported inside the SSE
|
||||
# body because the 200 response headers have already been
|
||||
# committed by the time the first chunk flushes.
|
||||
# 200 headers are already flushed; errors must be in the SSE body.
|
||||
logger.error("openai passthrough stream error: %s", e)
|
||||
err = {
|
||||
"error": {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
import type { ChatModelAdapter } from "@assistant-ui/react";
|
||||
import type { MessageTiming, ToolCallMessagePart } from "@assistant-ui/core";
|
||||
import { toast } from "sonner";
|
||||
import { authFetch } from "@/features/auth";
|
||||
import { getAuthToken } from "@/features/auth/session";
|
||||
import {
|
||||
generateAudio,
|
||||
|
|
@ -697,25 +696,20 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
|
|||
const toolCallParts: ToolCallMessagePart[] = [];
|
||||
let serverMetadata: { usage?: ServerUsage; timings?: ServerTimings } | null = null;
|
||||
|
||||
// Per-run cancellation token. Scoped to this single generation so a
|
||||
// delayed stop POST cannot match the next run on the same thread.
|
||||
// Per-run cancellation token so a delayed stop POST cannot match
|
||||
// the next run on the same thread.
|
||||
const cancelId =
|
||||
typeof crypto !== "undefined" && "randomUUID" in crypto
|
||||
? crypto.randomUUID()
|
||||
: `${Date.now()}-${Math.random().toString(36).slice(2)}`;
|
||||
|
||||
// Some proxies (e.g. Colab) do not propagate fetch aborts to the
|
||||
// backend, so request.is_disconnected() never fires server-side
|
||||
// and the tool-loop keeps running. Explicitly POST /inference/cancel
|
||||
// on abort so the backend can signal its cancel_event.
|
||||
// Colab-style proxies can swallow fetch aborts, so also POST
|
||||
// /inference/cancel explicitly on abort.
|
||||
const onAbortCancel = () => {
|
||||
const body: Record<string, string> = { cancel_id: cancelId };
|
||||
if (resolvedThreadId) body.session_id = resolvedThreadId;
|
||||
// Use plain fetch (+ manual Authorization) instead of authFetch.
|
||||
// authFetch redirects to login on 401; that would kick the user
|
||||
// out mid-stop if the access token expired during a long stream.
|
||||
// The cancel POST is best-effort and the server accepts a bare
|
||||
// Authorization header with keepalive: true.
|
||||
// Plain fetch, not authFetch: authFetch redirects to login on
|
||||
// 401, which would kick the user out mid-stop.
|
||||
const token = getAuthToken();
|
||||
void fetch("/api/inference/cancel", {
|
||||
method: "POST",
|
||||
|
|
|
|||
Loading…
Reference in a new issue