unsloth/studio/backend/models/inference.py
2026-04-13 20:43:05 +04:00

696 lines
24 KiB
Python

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""
Pydantic schemas for Inference API
"""
from __future__ import annotations
import time
import uuid
from typing import Annotated, Any, Dict, Literal, Optional, List, Union
from pydantic import BaseModel, Discriminator, Field, Tag
class LoadRequest(BaseModel):
    """Request to load a model for inference.

    Covers the transformer, LoRA-adapter and GGUF (llama.cpp) loading
    paths; several fields only apply to a subset of those paths (see the
    individual field descriptions).
    """
    model_path: str = Field(..., description = "Model identifier or local path")
    hf_token: Optional[str] = Field(
        None, description = "HuggingFace token for gated models"
    )
    # 0 is a sentinel meaning "keep the model's own default length" (GGUF).
    max_seq_length: int = Field(
        0,
        ge = 0,
        le = 1048576,
        description = "Maximum sequence length (0 = model default for GGUF)",
    )
    load_in_4bit: bool = Field(True, description = "Load model in 4-bit quantization")
    is_lora: bool = Field(False, description = "Whether this is a LoRA adapter")
    gguf_variant: Optional[str] = Field(
        None, description = "GGUF quantization variant (e.g. 'Q4_K_M')"
    )
    # Security-sensitive: enables execution of repo-provided code on load,
    # hence defaults to False.
    trust_remote_code: bool = Field(
        False,
        description = "Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust.",
    )
    chat_template_override: Optional[str] = Field(
        None,
        description = "Custom Jinja2 chat template to use instead of the model's default",
    )
    cache_type_kv: Optional[str] = Field(
        None,
        description = "KV cache data type for both K and V (e.g. 'f16', 'bf16', 'q8_0', 'q4_1', 'q5_1')",
    )
    gpu_ids: Optional[List[int]] = Field(
        None,
        description = "Physical GPU indices to use, for example [0, 1]. Omit or pass [] to use automatic selection. Explicit gpu_ids are unsupported when the parent CUDA_VISIBLE_DEVICES uses UUID/MIG entries. Not supported for GGUF models.",
    )
    speculative_type: Optional[str] = Field(
        None,
        description = "Speculative decoding mode for GGUF models (e.g. 'ngram-simple', 'ngram-mod'). Ignored for non-GGUF and vision models.",
    )
class UnloadRequest(BaseModel):
    """Request to unload a previously loaded model."""
    model_path: str = Field(..., description = "Model identifier to unload")
class ValidateModelRequest(BaseModel):
    """
    Lightweight validation request to check whether a model identifier
    *can be resolved* into a ModelConfig.
    This does NOT actually load weights into GPU memory.
    """
    model_path: str = Field(..., description = "Model identifier or local path")
    hf_token: Optional[str] = Field(
        None, description = "HuggingFace token for gated models"
    )
    gguf_variant: Optional[str] = Field(
        None, description = "GGUF quantization variant (e.g. 'Q4_K_M')"
    )
class ValidateModelResponse(BaseModel):
    """
    Result of model validation.
    valid == True means ModelConfig.from_identifier() succeeded and basic
    introspection (GGUF / LoRA / vision flags) is available.

    When valid == False the introspection flags below keep their defaults
    and only ``message`` is meaningful.
    """
    valid: bool = Field(..., description = "Whether the model identifier looks valid")
    message: str = Field(..., description = "Human-readable validation message")
    identifier: Optional[str] = Field(None, description = "Resolved model identifier")
    display_name: Optional[str] = Field(
        None, description = "Display name derived from identifier"
    )
    is_gguf: bool = Field(False, description = "Whether this is a GGUF model (llama.cpp)")
    is_lora: bool = Field(False, description = "Whether this is a LoRA adapter")
    is_vision: bool = Field(False, description = "Whether this is a vision-capable model")
    requires_trust_remote_code: bool = Field(
        False,
        description = "Whether the model defaults require trust_remote_code to be enabled for loading.",
    )
class GenerateRequest(BaseModel):
    """Request for text generation (legacy /generate/stream endpoint).

    Newer clients should prefer the OpenAI-compatible
    ChatCompletionRequest defined further below.
    """
    messages: List[dict] = Field(..., description = "Chat messages in OpenAI format")
    system_prompt: str = Field("", description = "System prompt")
    temperature: float = Field(0.6, ge = 0.0, le = 2.0, description = "Sampling temperature")
    top_p: float = Field(0.95, ge = 0.0, le = 1.0, description = "Top-p sampling")
    # -1 is a conventional "disabled" value for top-k, hence ge = -1.
    top_k: int = Field(20, ge = -1, le = 100, description = "Top-k sampling")
    max_new_tokens: int = Field(
        2048, ge = 1, le = 4096, description = "Maximum tokens to generate"
    )
    repetition_penalty: float = Field(
        1.0, ge = 1.0, le = 2.0, description = "Repetition penalty"
    )
    presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty")
    image_base64: Optional[str] = Field(
        None, description = "Base64 encoded image for vision models"
    )
class LoadResponse(BaseModel):
    """Response after loading a model.

    Echoes the resolved model identity plus the capability flags
    (vision / audio / reasoning / tools) the frontend needs to configure
    its UI, and the effective context / KV-cache settings.
    """
    status: str = Field(..., description = "Load status")
    model: str = Field(..., description = "Model identifier")
    display_name: str = Field(..., description = "Display name of the model")
    is_vision: bool = Field(False, description = "Whether model is a vision model")
    is_lora: bool = Field(False, description = "Whether model is a LoRA adapter")
    is_gguf: bool = Field(
        False, description = "Whether model is a GGUF model (llama.cpp)"
    )
    is_audio: bool = Field(False, description = "Whether model is a TTS audio model")
    audio_type: Optional[str] = Field(
        None, description = "Audio codec type: snac, csm, bicodec, dac"
    )
    has_audio_input: bool = Field(
        False, description = "Whether model accepts audio input (ASR)"
    )
    inference: dict = Field(
        ..., description = "Inference parameters (temperature, top_p, top_k, min_p)"
    )
    requires_trust_remote_code: bool = Field(
        False,
        description = "Whether the model defaults require trust_remote_code to be enabled for loading.",
    )
    # Three related but distinct context lengths: the effective one,
    # the hardware-limited maximum, and the model's uncapped native value.
    context_length: Optional[int] = Field(
        None, description = "Model's native context length (from GGUF metadata)"
    )
    max_context_length: Optional[int] = Field(
        None, description = "Maximum context length currently available on this hardware"
    )
    native_context_length: Optional[int] = Field(
        None,
        description = "Model's native context length from GGUF metadata (not capped by VRAM)",
    )
    supports_reasoning: bool = Field(
        False,
        description = "Whether model supports thinking/reasoning mode (enable_thinking)",
    )
    reasoning_always_on: bool = Field(
        False,
        description = "Whether reasoning is always on (hardcoded <think> tags, not toggleable)",
    )
    supports_tools: bool = Field(
        False,
        description = "Whether model supports tool calling (web search, etc.)",
    )
    cache_type_kv: Optional[str] = Field(
        None,
        description = "KV cache data type for K and V (e.g. 'f16', 'bf16', 'q8_0')",
    )
    chat_template: Optional[str] = Field(
        None,
        description = "Jinja2 chat template string (from GGUF metadata or tokenizer)",
    )
    speculative_type: Optional[str] = Field(
        None,
        description = "Active speculative decoding mode (e.g. 'ngram-simple', 'ngram-mod'), or None if disabled",
    )
class UnloadResponse(BaseModel):
    """Response after unloading a model."""
    status: str = Field(..., description = "Unload status")
    model: str = Field(..., description = "Model identifier that was unloaded")
class InferenceStatusResponse(BaseModel):
    """Current inference backend status.

    ``active_model`` is None when nothing is loaded; the capability flags
    then keep their defaults. ``loading`` / ``loaded`` track all models
    known to the backend, not just the active one.
    """
    active_model: Optional[str] = Field(
        None, description = "Currently active model identifier"
    )
    is_vision: bool = Field(
        False, description = "Whether the active model is a vision model"
    )
    is_gguf: bool = Field(
        False, description = "Whether the active model is a GGUF model (llama.cpp)"
    )
    gguf_variant: Optional[str] = Field(
        None, description = "GGUF quantization variant (e.g. Q4_K_M)"
    )
    is_audio: bool = Field(
        False, description = "Whether the active model is a TTS audio model"
    )
    audio_type: Optional[str] = Field(
        None, description = "Audio codec type: snac, csm, bicodec, dac"
    )
    has_audio_input: bool = Field(
        False, description = "Whether model accepts audio input (ASR)"
    )
    loading: List[str] = Field(
        default_factory = list, description = "Models currently being loaded"
    )
    loaded: List[str] = Field(
        default_factory = list, description = "Models currently loaded"
    )
    inference: Optional[Dict[str, Any]] = Field(
        None, description = "Recommended inference parameters for the active model"
    )
    requires_trust_remote_code: bool = Field(
        False,
        description = "Whether the active model requires trust_remote_code to be enabled for loading.",
    )
    supports_reasoning: bool = Field(
        False, description = "Whether the active model supports reasoning/thinking mode"
    )
    reasoning_always_on: bool = Field(
        False, description = "Whether reasoning is always on (not toggleable)"
    )
    supports_tools: bool = Field(
        False, description = "Whether the active model supports tool calling"
    )
    context_length: Optional[int] = Field(
        None, description = "Context length of the active model"
    )
    max_context_length: Optional[int] = Field(
        None,
        description = "Maximum context length currently available for the active model",
    )
    native_context_length: Optional[int] = Field(
        None,
        description = "Model's native context length from GGUF metadata (not capped by VRAM)",
    )
    speculative_type: Optional[str] = Field(
        None,
        description = "Active speculative decoding mode (e.g. 'ngram-simple', 'ngram-mod'), or None if disabled",
    )
# =====================================================================
# OpenAI-Compatible Chat Completions Models
# =====================================================================
# ── Multimodal content parts (OpenAI vision format) ──────────────
class TextContentPart(BaseModel):
    """Text content part in a multimodal message."""
    # Literal tag doubles as the discriminator value (see ContentPart below).
    type: Literal["text"]
    text: str


class ImageUrl(BaseModel):
    """Image URL object — supports data URIs and remote URLs."""
    url: str = Field(..., description = "data:image/png;base64,... or https://...")
    # OpenAI vision "detail" hint; optional in incoming payloads.
    detail: Optional[Literal["auto", "low", "high"]] = "auto"


class ImageContentPart(BaseModel):
    """Image content part in a multimodal message."""
    type: Literal["image_url"]
    image_url: ImageUrl
def _content_part_discriminator(v):
if isinstance(v, dict):
return v.get("type")
return getattr(v, "type", None)
# Discriminated union: Pydantic dispatches on the tag returned by
# _content_part_discriminator instead of trying each member in turn.
ContentPart = Annotated[
    Union[
        Annotated[TextContentPart, Tag("text")],
        Annotated[ImageContentPart, Tag("image_url")],
    ],
    Discriminator(_content_part_discriminator),
]
"""Union type for multimodal content parts, discriminated by the 'type' field."""
# ── Messages ─────────────────────────────────────────────────────
class ChatMessage(BaseModel):
    """
    A single message in the conversation.
    ``content`` may be a plain string (text-only) or a list of
    content parts for multimodal messages (OpenAI vision format).
    """
    role: Literal["system", "user", "assistant"] = Field(
        ..., description = "Message role"
    )
    content: Union[str, list[ContentPart]] = Field(
        ..., description = "Message content (string or multimodal parts)"
    )
class ChatCompletionRequest(BaseModel):
    """
    OpenAI-compatible chat completion request.
    Extensions (non-OpenAI fields) are marked with 'x-unsloth'.
    """
    model: str = Field(
        "default",
        description = "Model identifier (informational; the active model is used)",
    )
    messages: list[ChatMessage] = Field(..., description = "Conversation messages")
    # Note: streaming defaults to True here, unlike the OpenAI API.
    stream: bool = Field(True, description = "Whether to stream the response via SSE")
    temperature: float = Field(0.6, ge = 0.0, le = 2.0)
    top_p: float = Field(0.95, ge = 0.0, le = 1.0)
    max_tokens: Optional[int] = Field(
        None, ge = 1, description = "Maximum tokens to generate (None = until EOS)"
    )
    presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty")
    # ── Unsloth extensions (ignored by standard OpenAI clients) ──
    top_k: int = Field(20, ge = -1, le = 100, description = "[x-unsloth] Top-k sampling")
    min_p: float = Field(
        0.01, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold"
    )
    repetition_penalty: float = Field(
        1.0, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty"
    )
    image_base64: Optional[str] = Field(
        None, description = "[x-unsloth] Base64-encoded image for vision models"
    )
    audio_base64: Optional[str] = Field(
        None, description = "[x-unsloth] Base64-encoded WAV for audio-input models (ASR)"
    )
    # Tri-state plus string: the description below enumerates the semantics.
    use_adapter: Optional[Union[bool, str]] = Field(
        None,
        description = (
            "[x-unsloth] Adapter control for compare mode. "
            "null = no change (default), "
            "false = disable adapters (base model), "
            "true = enable the current adapter, "
            "string = enable a specific adapter by name."
        ),
    )
    enable_thinking: Optional[bool] = Field(
        None,
        description = "[x-unsloth] Enable/disable thinking/reasoning mode for supported models",
    )
    enable_tools: Optional[bool] = Field(
        None,
        description = "[x-unsloth] Enable tool calling for supported models",
    )
    enabled_tools: Optional[list[str]] = Field(
        None,
        description = "[x-unsloth] List of enabled tool names (e.g. ['web_search', 'python', 'terminal']). If None, all tools are enabled.",
    )
    auto_heal_tool_calls: Optional[bool] = Field(
        True,
        description = "[x-unsloth] Auto-detect and fix malformed tool calls from model output.",
    )
    max_tool_calls_per_message: Optional[int] = Field(
        25,
        ge = 0,
        description = "[x-unsloth] Maximum number of tool call iterations per message (0 = disabled, 9999 = unlimited).",
    )
    tool_call_timeout: Optional[int] = Field(
        300,
        ge = 1,
        description = "[x-unsloth] Timeout in seconds for each tool call execution (9999 = no limit).",
    )
    session_id: Optional[str] = Field(
        None,
        description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.",
    )
# ── Streaming response chunks ────────────────────────────────────
class ChoiceDelta(BaseModel):
    """Delta content for a streaming chunk."""
    role: Optional[str] = None
    content: Optional[str] = None


class ChunkChoice(BaseModel):
    """A single choice in a streaming chunk."""
    index: int = 0
    delta: ChoiceDelta
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionChunk(BaseModel):
    """A single SSE chunk in OpenAI streaming format."""
    id: str = Field(default_factory = lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory = lambda: int(time.time()))
    model: str = "default"
    choices: list[ChunkChoice]
    # CompletionUsage is defined later in this module; the PEP 563
    # postponed-annotations import at the top of the file makes this
    # forward reference legal.
    usage: Optional[CompletionUsage] = None
    timings: Optional[dict] = None
# ── Non-streaming response ───────────────────────────────────────
class CompletionMessage(BaseModel):
    """The assistant's complete response message."""
    role: Literal["assistant"] = "assistant"
    content: str


class CompletionChoice(BaseModel):
    """A single choice in a non-streaming response."""
    index: int = 0
    message: CompletionMessage
    finish_reason: Literal["stop", "length"] = "stop"


class CompletionUsage(BaseModel):
    """Token usage statistics (approximate)."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatCompletion(BaseModel):
    """Non-streaming chat completion response (OpenAI format)."""
    id: str = Field(default_factory = lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory = lambda: int(time.time()))
    model: str = "default"
    choices: list[CompletionChoice]
    usage: CompletionUsage = Field(default_factory = CompletionUsage)
# =====================================================================
# OpenAI Responses API Models (/v1/responses)
# =====================================================================
# ── Request models ──────────────────────────────────────────────
class ResponsesInputTextPart(BaseModel):
    """Text content part in a Responses API message (type=input_text)."""
    type: Literal["input_text"]
    text: str


class ResponsesInputImagePart(BaseModel):
    """Image content part in a Responses API message (type=input_image)."""
    # Unlike the Chat Completions format, image_url is a plain string here.
    type: Literal["input_image"]
    image_url: str = Field(..., description = "data:image/png;base64,... or https://...")
    detail: Optional[Literal["auto", "low", "high"]] = "auto"


# Plain (non-discriminated) union; the Literal "type" fields disambiguate.
ResponsesContentPart = Union[ResponsesInputTextPart, ResponsesInputImagePart]


class ResponsesInputMessage(BaseModel):
    """A single message in the Responses API input array."""
    role: Literal["system", "user", "assistant", "developer"]
    content: Union[str, list[ResponsesContentPart]]
class ResponsesRequest(BaseModel):
    """OpenAI Responses API request.

    Unsupported fields are accepted (and ignored) so SDK clients do not
    fail validation; ``model_config`` additionally tolerates unknown
    extra fields.
    """
    model: str = Field("default", description = "Model identifier")
    # default_factory is the canonical way to declare a list-valued
    # default and matches every other list field in this module
    # (previously a shared literal `default = []`).
    input: Union[str, list[ResponsesInputMessage]] = Field(
        default_factory = list,
        description = "Input text or message list",
    )
    instructions: Optional[str] = Field(
        None, description = "System / developer instructions"
    )
    temperature: Optional[float] = Field(None, ge = 0.0, le = 2.0)
    top_p: Optional[float] = Field(None, ge = 0.0, le = 1.0)
    max_output_tokens: Optional[int] = Field(None, ge = 1)
    stream: bool = Field(False, description = "Whether to stream the response via SSE")
    # Accepted but ignored -- keeps SDK clients from failing on unsupported fields
    tools: Optional[list] = None
    tool_choice: Optional[Any] = None
    previous_response_id: Optional[str] = None
    store: Optional[bool] = None
    metadata: Optional[dict] = None
    truncation: Optional[Any] = None
    user: Optional[str] = None
    text: Optional[Any] = None
    reasoning: Optional[Any] = None
    model_config = {"extra": "allow"}
# ── Response models ─────────────────────────────────────────────
class ResponsesOutputTextContent(BaseModel):
    """A text content block inside an output message."""
    type: Literal["output_text"] = "output_text"
    # Citation annotations; always emitted (empty) for schema compatibility.
    annotations: list = Field(default_factory = list)
    text: str


class ResponsesOutputMessage(BaseModel):
    """An output message in the Responses API response."""
    type: Literal["message"] = "message"
    id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:12]}")
    status: Literal["completed", "in_progress"] = "completed"
    role: Literal["assistant"] = "assistant"
    content: list[ResponsesOutputTextContent] = Field(default_factory = list)


class ResponsesUsage(BaseModel):
    """Token usage for a Responses API response (input_tokens, not prompt_tokens)."""
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
class ResponsesResponse(BaseModel):
    """Top-level Responses API response object.

    Echo fields (temperature, tools, metadata, ...) exist so the payload
    matches the OpenAI Responses schema; many are static placeholders.
    """
    id: str = Field(default_factory = lambda: f"resp_{uuid.uuid4().hex[:12]}")
    object: Literal["response"] = "response"
    created_at: int = Field(default_factory = lambda: int(time.time()))
    status: Literal["completed", "in_progress", "failed"] = "completed"
    model: str = "default"
    output: list[ResponsesOutputMessage] = Field(default_factory = list)
    usage: ResponsesUsage = Field(default_factory = ResponsesUsage)
    error: Optional[Any] = None
    incomplete_details: Optional[Any] = None
    instructions: Optional[str] = None
    metadata: dict = Field(default_factory = dict)
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_output_tokens: Optional[int] = None
    previous_response_id: Optional[str] = None
    text: Optional[Any] = None
    tool_choice: Optional[Any] = None
    tools: list = Field(default_factory = list)
    truncation: Optional[Any] = None
# =====================================================================
# Anthropic Messages API Models (/v1/messages)
# =====================================================================
# ── Request models ─────────────────────────────────────────────
class AnthropicTextBlock(BaseModel):
    """Anthropic text content block (type=text)."""
    type: Literal["text"]
    text: str


class AnthropicImageSource(BaseModel):
    """Image payload: either base64 (media_type + data) or a url."""
    type: Literal["base64", "url"]
    media_type: Optional[str] = None
    data: Optional[str] = None
    url: Optional[str] = None


class AnthropicImageBlock(BaseModel):
    """Anthropic image content block (type=image)."""
    type: Literal["image"]
    source: AnthropicImageSource


class AnthropicToolUseBlock(BaseModel):
    """Assistant-issued tool invocation (type=tool_use)."""
    type: Literal["tool_use"]
    id: str
    name: str
    input: dict


class AnthropicToolResultBlock(BaseModel):
    """User-supplied result for a prior tool_use block (type=tool_result)."""
    type: Literal["tool_result"]
    tool_use_id: str
    content: Union[str, list] = ""


# Plain union of all request-side content blocks; the Literal "type"
# fields disambiguate during validation.
AnthropicContentBlock = Union[
    AnthropicTextBlock,
    AnthropicImageBlock,
    AnthropicToolUseBlock,
    AnthropicToolResultBlock,
]


class AnthropicMessage(BaseModel):
    """A single Anthropic Messages API conversation turn."""
    role: Literal["user", "assistant"]
    content: Union[str, list[AnthropicContentBlock]]


class AnthropicTool(BaseModel):
    """Tool definition (JSON-schema based) in an Anthropic request."""
    name: str
    description: Optional[str] = None
    input_schema: dict
class AnthropicMessagesRequest(BaseModel):
    """Anthropic Messages API request (/v1/messages).

    Standard Anthropic fields plus the same [x-unsloth] convenience
    extensions exposed on the OpenAI endpoint; unknown extra fields are
    tolerated via ``model_config``.
    """
    model: str = "default"
    max_tokens: Optional[int] = None
    messages: list[AnthropicMessage]
    system: Optional[Union[str, list]] = None
    tools: Optional[list[AnthropicTool]] = None
    tool_choice: Optional[Any] = None
    stream: bool = False
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[list[str]] = None
    metadata: Optional[dict] = None
    # [x-unsloth] extensions — mirror the OpenAI endpoint convenience fields
    min_p: Optional[float] = Field(
        None, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold"
    )
    repetition_penalty: Optional[float] = Field(
        None, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty"
    )
    presence_penalty: Optional[float] = Field(
        None, ge = 0.0, le = 2.0, description = "[x-unsloth] Presence penalty"
    )
    enable_tools: Optional[bool] = None
    enabled_tools: Optional[list[str]] = None
    session_id: Optional[str] = None
    model_config = {"extra": "allow"}
# ── Response models ────────────────────────────────────────────
class AnthropicUsage(BaseModel):
    """Token usage in Anthropic naming (input_tokens / output_tokens)."""
    input_tokens: int = 0
    output_tokens: int = 0


class AnthropicResponseTextBlock(BaseModel):
    """Response-side text content block."""
    type: Literal["text"] = "text"
    text: str


class AnthropicResponseToolUseBlock(BaseModel):
    """Response-side tool invocation emitted by the assistant."""
    type: Literal["tool_use"] = "tool_use"
    id: str
    name: str
    input: dict


# Union of all content blocks that may appear in a response.
AnthropicResponseBlock = Union[
    AnthropicResponseTextBlock, AnthropicResponseToolUseBlock
]
class AnthropicMessagesResponse(BaseModel):
    """Top-level Anthropic Messages API response object."""
    id: str = Field(default_factory = lambda: f"msg_{uuid.uuid4().hex[:24]}")
    type: Literal["message"] = "message"
    role: Literal["assistant"] = "assistant"
    content: list[AnthropicResponseBlock] = Field(default_factory = list)
    model: str = "default"
    # e.g. "end_turn" / "max_tokens" style values; kept as a free string.
    stop_reason: Optional[str] = None
    stop_sequence: Optional[str] = None
    usage: AnthropicUsage = Field(default_factory = AnthropicUsage)