diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index b7b64b7b4..8d9dc9830 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -94,6 +94,10 @@ class ValidateModelResponse(BaseModel): is_gguf: bool = Field(False, description = "Whether this is a GGUF model (llama.cpp)") is_lora: bool = Field(False, description = "Whether this is a LoRA adapter") is_vision: bool = Field(False, description = "Whether this is a vision-capable model") + requires_trust_remote_code: bool = Field( + False, + description = "Whether the model defaults require trust_remote_code to be enabled for loading.", + ) class GenerateRequest(BaseModel): @@ -137,6 +141,10 @@ class LoadResponse(BaseModel): inference: dict = Field( ..., description = "Inference parameters (temperature, top_p, top_k, min_p)" ) + requires_trust_remote_code: bool = Field( + False, + description = "Whether the model defaults require trust_remote_code to be enabled for loading.", + ) context_length: Optional[int] = Field( None, description = "Model's native context length (from GGUF metadata)" ) @@ -213,6 +221,10 @@ class InferenceStatusResponse(BaseModel): inference: Optional[Dict[str, Any]] = Field( None, description = "Recommended inference parameters for the active model" ) + requires_trust_remote_code: bool = Field( + False, + description = "Whether the active model requires trust_remote_code to be enabled for loading.", + ) supports_reasoning: bool = Field( False, description = "Whether the active model supports reasoning/thinking mode" ) diff --git a/studio/backend/requirements/extras-no-deps.txt b/studio/backend/requirements/extras-no-deps.txt index 9934bacd2..59a8e4543 100644 --- a/studio/backend/requirements/extras-no-deps.txt +++ b/studio/backend/requirements/extras-no-deps.txt @@ -13,4 +13,4 @@ torch-c-dlpack-ext sentence_transformers==5.2.0 transformers==4.57.6 pytorch_tokenizers -kernels +kernels==0.12.1 diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 24812199d..0717b3bc9 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -198,6 +198,9 @@ async def load_model( if _gguf_audio else False, inference = inference_config, + requires_trust_remote_code = bool( + inference_config.get("trust_remote_code", False) + ), context_length = llama_backend.context_length, max_context_length = llama_backend.max_context_length, native_context_length = llama_backend.native_context_length, @@ -235,6 +238,9 @@ async def load_model( audio_type = _model_info.get("audio_type"), has_audio_input = _model_info.get("has_audio_input", False), inference = inference_config, + requires_trust_remote_code = bool( + inference_config.get("trust_remote_code", False) + ), chat_template = _chat_template, ) @@ -341,6 +347,9 @@ async def load_model( audio_type = _gguf_audio, has_audio_input = is_audio_input_type(_gguf_audio), inference = inference_config, + requires_trust_remote_code = bool( + inference_config.get("trust_remote_code", False) + ), context_length = llama_backend.context_length, max_context_length = llama_backend.max_context_length, native_context_length = llama_backend.native_context_length, @@ -479,6 +488,9 @@ async def load_model( audio_type = config.audio_type, has_audio_input = config.has_audio_input, inference = inference_config, + requires_trust_remote_code = bool( + inference_config.get("trust_remote_code", False) + ), chat_template = _chat_template, ) @@ -534,6 +546,9 @@ async def validate_model( is_gguf = getattr(config, "is_gguf", False), is_lora = getattr(config, "is_lora", False), is_vision = getattr(config, "is_vision", False), + requires_trust_remote_code = bool( + load_inference_config(config.identifier).get("trust_remote_code", False) + ), ) except HTTPException: @@ -679,6 +694,9 @@ async def get_status( loading = [], loaded = [_model_id], inference = _inference_cfg, + requires_trust_remote_code = bool( + (_inference_cfg or {}).get("trust_remote_code", False) + ), supports_reasoning = llama_backend.supports_reasoning, reasoning_always_on = llama_backend.reasoning_always_on, supports_tools = llama_backend.supports_tools, @@ -706,6 +724,11 @@ async def get_status( supports_reasoning = False if backend.active_model_name and hasattr(backend, "_is_gpt_oss_model"): supports_reasoning = backend._is_gpt_oss_model() + inference_config = ( + load_inference_config(backend.active_model_name) + if backend.active_model_name + else None + ) return InferenceStatusResponse( active_model = backend.active_model_name, @@ -716,6 +739,10 @@ async def get_status( has_audio_input = has_audio_input, loading = list(getattr(backend, "loading_models", set())), loaded = list(backend.models.keys()), + inference = inference_config, + requires_trust_remote_code = bool( + (inference_config or {}).get("trust_remote_code", False) + ), supports_reasoning = supports_reasoning, ) diff --git a/studio/frontend/src/components/assistant-ui/markdown-text.tsx b/studio/frontend/src/components/assistant-ui/markdown-text.tsx index 91ef78fcf..c7974db36 100644 --- a/studio/frontend/src/components/assistant-ui/markdown-text.tsx +++ b/studio/frontend/src/components/assistant-ui/markdown-text.tsx @@ -41,7 +41,7 @@ const COPY_RESET_MS = 2000; const MERMAID_SOURCE_RE = /```mermaid\s*([\s\S]*?)```/i; const CODE_FENCE_RE = /^```([^\r\n`]*)\r?\n([\s\S]*?)\r?\n?```$/; const ACTION_PANEL_CLASS = - "pointer-events-auto flex shrink-0 items-center gap-2 rounded-md border border-sidebar bg-sidebar/80 px-1.5 py-1 supports-[backdrop-filter]:bg-sidebar/70 supports-[backdrop-filter]:backdrop-blur"; + "pointer-events-auto flex shrink-0 items-center gap-2 rounded-md border border-sidebar bg-sidebar/80 px-1.5 py-1 supports-[backdrop-filter]:bg-sidebar/70 supports-[backdrop-filter]:backdrop-blur dark:border-white/10 dark:bg-code-block dark:supports-[backdrop-filter]:bg-code-block"; const ACTION_BUTTON_CLASS = "cursor-pointer p-1 text-muted-foreground transition-all hover:text-foreground disabled:cursor-not-allowed disabled:opacity-50"; diff --git a/studio/frontend/src/components/assistant-ui/model-selector/pickers.tsx b/studio/frontend/src/components/assistant-ui/model-selector/pickers.tsx index 74ca2542d..87d59dfaa 100644 --- a/studio/frontend/src/components/assistant-ui/model-selector/pickers.tsx +++ b/studio/frontend/src/components/assistant-ui/model-selector/pickers.tsx @@ -841,7 +841,7 @@ export function HubModelPicker({ (cachedGguf.length > 0 || (!chatOnly && cachedModels.length > 0)) ? ( <> - {"\uD83E\uDDA5"} Downloaded + Downloaded {cachedGguf.map((c) => (
- {"\uD83E\uDDA5"} Recommended + Recommended {visibleRecommendedIds.length === 0 ? (
No default models. @@ -1128,7 +1128,7 @@ export function HubModelPicker({ {showHfSection && filteredRecommendedIds.length > 0 ? ( <> - {"\uD83E\uDDA5"} Recommended + Recommended {filteredRecommendedIds.map((id) => { const vram = recommendedVramMap.get(id); return ( diff --git a/studio/frontend/src/components/assistant-ui/thread.tsx b/studio/frontend/src/components/assistant-ui/thread.tsx index 0d99654b8..8f41987fb 100644 --- a/studio/frontend/src/components/assistant-ui/thread.tsx +++ b/studio/frontend/src/components/assistant-ui/thread.tsx @@ -56,9 +56,12 @@ import { RefreshCwIcon, SquareIcon, TerminalIcon, + Trash2Icon, XIcon, } from "lucide-react"; import { type FC, useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; +import { deleteThreadMessage } from "@/features/chat/utils/delete-thread-message"; import { useChatRuntimeStore } from "@/features/chat/stores/chat-runtime-store"; export const Thread: FC<{ hideComposer?: boolean; hideWelcome?: boolean }> = ({ @@ -635,6 +638,41 @@ const AssistantMessage: FC = () => { const COPY_RESET_MS = 2000; +const DeleteMessageButton: FC = () => { + const aui = useAui(); + const messageId = useAuiState(({ message }) => message.id); + const isRunning = useAuiState(({ thread }) => thread.isRunning); + + const handleDelete = async () => { + const remoteId = aui.threadListItem().getState().remoteId; + const thread = aui.thread(); + try { + await deleteThreadMessage({ + thread: { + export: () => thread.export(), + import: (data) => thread.import(data), + }, + messageId, + remoteId, + }); + } catch (error) { + console.error("Failed to delete message", error); + toast.error("Failed to delete message"); + } + }; + + return ( + + + + ); +}; + const CopyButton: FC = () => { const aui = useAui(); const [copied, setCopied] = useState(false); @@ -673,6 +711,7 @@ const AssistantActionBar: FC = () => { + @@ -748,6 +787,7 @@ const UserActionBar: FC = () => { + ); }; diff --git a/studio/frontend/src/features/chat/api/chat-adapter.ts b/studio/frontend/src/features/chat/api/chat-adapter.ts index 3d1bce290..76dbbf0bf 100644 --- a/studio/frontend/src/features/chat/api/chat-adapter.ts +++ b/studio/frontend/src/features/chat/api/chat-adapter.ts @@ -11,6 +11,7 @@ import { listGgufVariants, loadModel, streamChatCompletions, + validateModel, } from "./chat-api"; import { db } from "../db"; import { useChatRuntimeStore } from "../stores/chat-runtime-store"; @@ -252,13 +253,39 @@ function waitForModelReady(abortSignal?: AbortSignal): Promise { * without selecting one. Prefers GGUF (picks smallest cached variant), * falls back to smallest cached safetensors model. */ -async function autoLoadSmallestModel(): Promise { - const hfToken = useChatRuntimeStore.getState().hfToken || null; +async function autoLoadSmallestModel(): Promise<{ + loaded: boolean; + blockedByTrustRemoteCode: boolean; +}> { + const store = useChatRuntimeStore.getState(); + const hfToken = store.hfToken || null; + const trustRemoteCode = store.params.trustRemoteCode ?? false; const toastId = toast("Loading a model…", { description: "Auto-selecting the smallest downloaded model.", duration: 5000, closeButton: true, }); + let blockedByTrustRemoteCode = false; + let hadNonTrustFailure = false; + + async function canAutoLoad(payload: { + model_path: string; + max_seq_length: number; + is_lora: boolean; + gguf_variant?: string | null; + }): Promise { + const validation = await validateModel({ + ...payload, + hf_token: hfToken, + load_in_4bit: true, + trust_remote_code: trustRemoteCode, + }); + if (validation.requires_trust_remote_code && !trustRemoteCode) { + blockedByTrustRemoteCode = true; + return false; + } + return true; + } try { const [ggufRepos, modelRepos] = await Promise.all([ listCachedGguf().catch(() => []), @@ -277,6 +304,16 @@ async function autoLoadSmallestModel(): Promise { .sort((a, b) => a.size_bytes - b.size_bytes); if (downloaded.length > 0) { const variant = downloaded[0]; + if ( + !(await canAutoLoad({ + model_path: repo.repo_id, + max_seq_length: 0, + is_lora: false, + gguf_variant: variant.quant, + })) + ) { + continue; + } const loadResp = await loadModel({ model_path: repo.repo_id, hf_token: hfToken, @@ -284,10 +321,13 @@ async function autoLoadSmallestModel(): Promise { load_in_4bit: true, is_lora: false, gguf_variant: variant.quant, - trust_remote_code: false, + trust_remote_code: trustRemoteCode, }); useChatRuntimeStore.getState().setCheckpoint(repo.repo_id, variant.quant); const store = useChatRuntimeStore.getState(); + store.setModelRequiresTrustRemoteCode( + loadResp.requires_trust_remote_code ?? false, + ); store.setParams({ ...store.params, maxTokens: loadResp.context_length ?? 131072 }); // Add model to store so the selector shows the name const autoModel: ChatModelSummary = { @@ -319,9 +359,10 @@ async function autoLoadSmallestModel(): Promise { chatTemplateOverride: null, }); toast.success(`Loaded ${repo.repo_id} (${variant.quant})`, { id: toastId }); - return true; + return { loaded: true, blockedByTrustRemoteCode: false }; } } catch { + hadNonTrustFailure = true; continue; } } @@ -332,6 +373,16 @@ async function autoLoadSmallestModel(): Promise { const sorted = [...modelRepos].sort((a, b) => a.size_bytes - b.size_bytes); for (const repo of sorted) { try { + if ( + !(await canAutoLoad({ + model_path: repo.repo_id, + max_seq_length: 4096, + is_lora: false, + gguf_variant: null, + })) + ) { + continue; + } const sfLoadResp = await loadModel({ model_path: repo.repo_id, hf_token: hfToken, @@ -339,10 +390,13 @@ async function autoLoadSmallestModel(): Promise { load_in_4bit: true, is_lora: false, gguf_variant: null, - trust_remote_code: false, + trust_remote_code: trustRemoteCode, }); useChatRuntimeStore.getState().setCheckpoint(repo.repo_id); const store = useChatRuntimeStore.getState(); + store.setModelRequiresTrustRemoteCode( + sfLoadResp.requires_trust_remote_code ?? false, + ); store.setParams({ ...store.params, maxTokens: 4096 }); const sfModel: ChatModelSummary = { id: repo.repo_id, @@ -355,8 +409,9 @@ async function autoLoadSmallestModel(): Promise { store.setModels([...store.models, sfModel]); } toast.success(`Loaded ${repo.repo_id}`, { id: toastId }); - return true; + return { loaded: true, blockedByTrustRemoteCode: false }; } catch { + hadNonTrustFailure = true; continue; } } @@ -369,6 +424,17 @@ async function autoLoadSmallestModel(): Promise { duration: 30000, }); try { + if ( + !(await canAutoLoad({ + model_path: "unsloth/Qwen3.5-4B-GGUF", + max_seq_length: 0, + is_lora: false, + gguf_variant: "UD-Q4_K_XL", + })) + ) { + toast.dismiss(toastId); + return { loaded: false, blockedByTrustRemoteCode }; + } const loadResp = await loadModel({ model_path: "unsloth/Qwen3.5-4B-GGUF", hf_token: hfToken, @@ -376,10 +442,13 @@ async function autoLoadSmallestModel(): Promise { load_in_4bit: true, is_lora: false, gguf_variant: "UD-Q4_K_XL", - trust_remote_code: false, + trust_remote_code: trustRemoteCode, }); useChatRuntimeStore.getState().setCheckpoint("unsloth/Qwen3.5-4B-GGUF", "UD-Q4_K_XL"); const store = useChatRuntimeStore.getState(); + store.setModelRequiresTrustRemoteCode( + loadResp.requires_trust_remote_code ?? false, + ); store.setParams({ ...store.params, maxTokens: loadResp.context_length ?? 131072 }); const defaultModel: ChatModelSummary = { id: "unsloth/Qwen3.5-4B-GGUF", @@ -406,14 +475,24 @@ async function autoLoadSmallestModel(): Promise { chatTemplateOverride: null, }); toast.success("Loaded Qwen3.5-4B (UD-Q4_K_XL)", { id: toastId }); - return true; + return { loaded: true, blockedByTrustRemoteCode: false }; } catch { toast.dismiss(toastId); - return false; + hadNonTrustFailure = true; + return { + loaded: false, + blockedByTrustRemoteCode: + blockedByTrustRemoteCode && !hadNonTrustFailure, + }; } } catch { toast.dismiss(toastId); - return false; + hadNonTrustFailure = true; + return { + loaded: false, + blockedByTrustRemoteCode: + blockedByTrustRemoteCode && !hadNonTrustFailure, + }; } } @@ -434,11 +513,19 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { if (!useChatRuntimeStore.getState().params.checkpoint) { // Auto-load the smallest downloaded model - const loaded = await autoLoadSmallestModel(); + const { loaded, blockedByTrustRemoteCode } = + await autoLoadSmallestModel(); if (!loaded) { - toast.error("No model loaded", { - description: "Pick a model in the top bar, then retry.", - }); + toast.error( + blockedByTrustRemoteCode + ? "Enable custom code to auto-load this model" + : "No model loaded", + { + description: blockedByTrustRemoteCode + ? 'Turn on "Enable custom code" in Chat Settings, or pick another model in the top bar.' + : "Pick a model in the top bar, then retry.", + }, + ); throw new Error("Load a model first."); } } diff --git a/studio/frontend/src/features/chat/chat-settings-sheet.tsx b/studio/frontend/src/features/chat/chat-settings-sheet.tsx index 550df2bf7..bcfd58521 100644 --- a/studio/frontend/src/features/chat/chat-settings-sheet.tsx +++ b/studio/frontend/src/features/chat/chat-settings-sheet.tsx @@ -1,6 +1,11 @@ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 +import { + Alert, + AlertDescription, + AlertTitle, +} from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Dialog, @@ -10,7 +15,19 @@ import { DialogHeader, DialogTitle, } from "@/components/ui/dialog"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { Input } from "@/components/ui/input"; +import { + InputGroup, + InputGroupAddon, + InputGroupButton, + InputGroupInput, +} from "@/components/ui/input-group"; import { Select, SelectContent, @@ -29,6 +46,7 @@ import { Slider } from "@/components/ui/slider"; import { Switch } from "@/components/ui/switch"; import { Textarea } from "@/components/ui/textarea"; import { useIsMobile } from "@/hooks/use-mobile"; +import { cn } from "@/lib/utils"; import { ArrowDown01Icon, CodeIcon, @@ -43,7 +61,8 @@ import { import { HugeiconsIcon } from "@hugeicons/react"; import { AnimatePresence, motion } from "motion/react"; import type { ReactNode } from "react"; -import { useEffect, useMemo, useState } from "react"; +import { useEffect, useLayoutEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; import { useChatRuntimeStore } from "./stores/chat-runtime-store"; import { DEFAULT_INFERENCE_PARAMS, @@ -58,6 +77,11 @@ export interface Preset { params: InferenceParams; } +interface LegacySystemPromptTemplate { + name: string; + content: string; +} + const BUILTIN_PRESETS: Preset[] = [ { name: "Default", params: { ...defaultInferenceParams } }, { @@ -86,19 +110,134 @@ const BUILTIN_PRESETS: Preset[] = [ const CHAT_PRESETS_KEY = "unsloth_chat_custom_presets"; const CHAT_ACTIVE_PRESET_KEY = "unsloth_chat_active_preset"; +const LEGACY_CHAT_SYSTEM_PROMPTS_KEY = "unsloth_chat_system_prompts"; +const LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY = + "unsloth_chat_system_prompts_migrated"; function canUseStorage(): boolean { return typeof window !== "undefined"; } +function getUniquePresetName(baseName: string, usedNames: Set): string { + const normalizedBase = baseName.trim() || "Imported Prompt"; + let nextName = normalizedBase; + let suffix = 2; + while (usedNames.has(nextName)) { + nextName = `${normalizedBase} ${suffix}`; + suffix += 1; + } + usedNames.add(nextName); + return nextName; +} + +function migrateLegacySystemPromptTemplates(presets: Preset[]): Preset[] { + if (!canUseStorage()) return presets; + try { + const raw = localStorage.getItem(LEGACY_CHAT_SYSTEM_PROMPTS_KEY); + if (!raw) return presets; + if (localStorage.getItem(LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY) === raw) { + return presets; + } + let parsed: unknown; + try { + parsed = JSON.parse(raw) as unknown; + } catch { + localStorage.removeItem(LEGACY_CHAT_SYSTEM_PROMPTS_KEY); + localStorage.setItem(LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY, raw); + return presets; + } + if (!Array.isArray(parsed)) { + localStorage.removeItem(LEGACY_CHAT_SYSTEM_PROMPTS_KEY); + localStorage.setItem(LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY, raw); + return presets; + } + const usedNames = new Set([ + ...BUILTIN_PRESETS.map((preset) => preset.name), + ...presets.map((preset) => preset.name), + ]); + const seenImportedConfigKeys = new Set( + [...BUILTIN_PRESETS, ...presets].map((preset) => + JSON.stringify({ + temperature: preset.params.temperature, + topP: preset.params.topP, + topK: preset.params.topK, + minP: preset.params.minP, + repetitionPenalty: preset.params.repetitionPenalty, + presencePenalty: preset.params.presencePenalty, + maxSeqLength: preset.params.maxSeqLength, + maxTokens: preset.params.maxTokens, + systemPrompt: preset.params.systemPrompt, + trustRemoteCode: preset.params.trustRemoteCode ?? false, + }), + ), + ); + const importedPresets = parsed + .filter((item): item is LegacySystemPromptTemplate => { + if (!item || typeof item !== "object") return false; + const maybe = item as Partial; + return ( + typeof maybe.name === "string" && typeof maybe.content === "string" + ); + }) + .map((template) => ({ + template, + importedParams: { + ...defaultInferenceParams, + systemPrompt: template.content, + }, + })) + .filter(({ importedParams }) => { + const configKey = JSON.stringify({ + temperature: importedParams.temperature, + topP: importedParams.topP, + topK: importedParams.topK, + minP: importedParams.minP, + repetitionPenalty: importedParams.repetitionPenalty, + presencePenalty: importedParams.presencePenalty, + maxSeqLength: importedParams.maxSeqLength, + maxTokens: importedParams.maxTokens, + systemPrompt: importedParams.systemPrompt, + trustRemoteCode: importedParams.trustRemoteCode ?? false, + }); + if (seenImportedConfigKeys.has(configKey)) return false; + seenImportedConfigKeys.add(configKey); + return true; + }) + .map(({ template, importedParams }) => ({ + name: getUniquePresetName(`${template.name} Prompt`, usedNames), + params: importedParams, + })); + if (importedPresets.length === 0) { + localStorage.removeItem(LEGACY_CHAT_SYSTEM_PROMPTS_KEY); + localStorage.setItem(LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY, raw); + return presets; + } + const mergedPresets = [...presets, ...importedPresets]; + localStorage.setItem(CHAT_PRESETS_KEY, JSON.stringify(mergedPresets)); + try { + localStorage.setItem(LEGACY_CHAT_SYSTEM_PROMPTS_MIGRATED_KEY, raw); + localStorage.removeItem(LEGACY_CHAT_SYSTEM_PROMPTS_KEY); + } catch { + // ignore cleanup failure after successful import write + } + return mergedPresets; + } catch { + return presets; + } +} + function loadSavedCustomPresets(): Preset[] { if (!canUseStorage()) return []; try { const raw = localStorage.getItem(CHAT_PRESETS_KEY); - if (!raw) return []; + if (!raw) { + return migrateLegacySystemPromptTemplates([]); + } const parsed = JSON.parse(raw) as unknown; - if (!Array.isArray(parsed)) return []; - return parsed + if (!Array.isArray(parsed)) { + return migrateLegacySystemPromptTemplates([]); + } + const presets = parsed .filter((item): item is Preset => { if (!item || typeof item !== "object") return false; const maybe = item as Partial; @@ -111,13 +250,10 @@ function loadSavedCustomPresets(): Preset[] { ...preset.params, }, })) - .filter( - (preset) => - preset.name.length > 0 && - !BUILTIN_PRESETS.some((builtin) => builtin.name === preset.name), - ); + .filter((preset) => preset.name.length > 0); + return migrateLegacySystemPromptTemplates(presets); } catch { - return []; + return migrateLegacySystemPromptTemplates([]); } } @@ -130,6 +266,82 @@ function loadSavedActivePreset(): string { } } +type PresetSaveMode = + | "disabled" + | "overwrite-active" + | "overwrite-other" + | "create"; + +interface PresetSaveState { + mode: PresetSaveMode; + canSubmit: boolean; + isSaveReady: boolean; + buttonLabel: string; + title: string; +} + +function isSamePresetConfig(a: InferenceParams, b: InferenceParams): boolean { + return ( + a.temperature === b.temperature && + a.topP === b.topP && + a.topK === b.topK && + a.minP === b.minP && + a.repetitionPenalty === b.repetitionPenalty && + a.presencePenalty === b.presencePenalty && + a.maxSeqLength === b.maxSeqLength && + a.maxTokens === b.maxTokens && + a.systemPrompt === b.systemPrompt && + (a.trustRemoteCode ?? false) === (b.trustRemoteCode ?? false) + ); +} + +function getPresetSaveState({ + rawName, + activePreset, + presets, + activePresetDirty, +}: { + rawName: string; + activePreset: string; + presets: Preset[]; + activePresetDirty: boolean; +}): PresetSaveState { + const trimmedName = rawName.trim(); + if (!trimmedName) { + return { + mode: "disabled", + canSubmit: false, + isSaveReady: false, + buttonLabel: "Save", + title: "Enter a preset name", + }; + } + + const matchingPreset = presets.find((preset) => preset.name === trimmedName); + if (matchingPreset) { + const isActiveMatch = matchingPreset.name === activePreset; + return { + mode: isActiveMatch ? "overwrite-active" : "overwrite-other", + canSubmit: !isActiveMatch || activePresetDirty, + isSaveReady: !isActiveMatch || activePresetDirty, + buttonLabel: isActiveMatch && !activePresetDirty ? "Saved" : "Overwrite", + title: isActiveMatch + ? activePresetDirty + ? "Save current settings to this preset" + : "No unsaved changes" + : `Overwrite preset "${trimmedName}"`, + }; + } + + return { + mode: "create", + canSubmit: true, + isSaveReady: true, + buttonLabel: "Save as New", + title: `Save current settings as "${trimmedName}"`, + }; +} + function ParamSlider({ label, value, @@ -286,6 +498,9 @@ export function ChatSettingsPanel({ (s) => s.loadedSpeculativeType, ); const currentModels = useChatRuntimeStore((s) => s.models); + const modelRequiresTrustRemoteCode = useChatRuntimeStore( + (s) => s.modelRequiresTrustRemoteCode, + ); const currentCheckpoint = params.checkpoint; const currentModelIsVision = currentModels.find((m) => m.id === currentCheckpoint)?.isVision ?? false; @@ -316,13 +531,57 @@ export function ChatSettingsPanel({ const [activePreset, setActivePreset] = useState(() => loadSavedActivePreset(), ); - const [savePresetOpen, setSavePresetOpen] = useState(false); - const [presetNameDraft, setPresetNameDraft] = useState(""); - const presets = useMemo( - () => [...BUILTIN_PRESETS, ...customPresets], - [customPresets], + const [presetNameInput, setPresetNameInput] = useState(() => + loadSavedActivePreset(), ); - const isBuiltinPreset = BUILTIN_PRESETS.some((p) => p.name === activePreset); + const presetControlRowRef = useRef(null); + const [presetMenuWidthPx, setPresetMenuWidthPx] = useState< + number | undefined + >(undefined); + const [systemPromptEditorOpen, setSystemPromptEditorOpen] = useState(false); + const [systemPromptDraft, setSystemPromptDraft] = useState(""); + const presets = useMemo(() => { + const overrides = new Set(customPresets.map((preset) => preset.name)); + return [ + ...BUILTIN_PRESETS.filter((preset) => !overrides.has(preset.name)), + ...customPresets, + ]; + }, [customPresets]); + const activePresetDefinition = useMemo( + () => presets.find((preset) => preset.name === activePreset) ?? null, + [activePreset, presets], + ); + const activeCustomPreset = useMemo( + () => customPresets.find((preset) => preset.name === activePreset) ?? null, + [activePreset, customPresets], + ); + const activeBuiltinPreset = useMemo( + () => + BUILTIN_PRESETS.find((preset) => preset.name === activePreset) ?? null, + [activePreset], + ); + const activePresetDirty = useMemo( + () => + activePresetDefinition == null + ? false + : !isSamePresetConfig(activePresetDefinition.params, params), + [activePresetDefinition, params], + ); + const presetSaveState = useMemo( + () => + getPresetSaveState({ + rawName: presetNameInput, + activePreset, + presets, + activePresetDirty, + }), + [activePreset, activePresetDirty, presetNameInput, presets], + ); + const systemPromptEditorDirty = systemPromptDraft !== params.systemPrompt; + const trustRemoteCodeMissing = + Boolean(currentCheckpoint) && + modelRequiresTrustRemoteCode && + !(params.trustRemoteCode ?? false); function set(key: K) { return (v: InferenceParams[K]) => onParamsChange({ ...params, [key]: v }); @@ -331,11 +590,19 @@ export function ChatSettingsPanel({ function applyPreset(name: string) { const p = presets.find((pr) => pr.name === name); if (p) { + if ( + modelRequiresTrustRemoteCode && + !(p.params.trustRemoteCode ?? false) + ) { + toast.warning("This configuration turns custom code off", { + description: + "The current model needs custom code enabled to load. Keep it on for this model.", + }); + return; + } onParamsChange({ ...p.params, - systemPrompt: params.systemPrompt, checkpoint: params.checkpoint, - trustRemoteCode: params.trustRemoteCode, }); setActivePreset(name); if (canUseStorage()) { @@ -348,32 +615,23 @@ export function ChatSettingsPanel({ } } - function openSavePresetDialog() { - setPresetNameDraft(activePreset === "Default" ? "" : activePreset); - setSavePresetOpen(true); - } - function savePresetWithName(rawName: string) { const trimmed = rawName.trim(); if (!trimmed) { - return; - } - if (BUILTIN_PRESETS.some((preset) => preset.name === trimmed)) { + toast.error("Enter a preset name"); return; } setCustomPresets((prev) => { - const next = [ - ...prev.filter((preset) => preset.name !== trimmed), - { name: trimmed, params: { ...params } }, - ]; + const next = prev.filter((p) => p.name !== trimmed); + const merged = [...next, { name: trimmed, params: { ...params } }]; if (canUseStorage()) { try { - localStorage.setItem(CHAT_PRESETS_KEY, JSON.stringify(next)); + localStorage.setItem(CHAT_PRESETS_KEY, JSON.stringify(merged)); } catch { // ignore } } - return next; + return merged; }); if (canUseStorage()) { try { @@ -383,11 +641,31 @@ export function ChatSettingsPanel({ } } setActivePreset(trimmed); - setSavePresetOpen(false); + setPresetNameInput(trimmed); } function deletePreset(name: string) { - if (BUILTIN_PRESETS.some((p) => p.name === name)) { + const hasCustomPreset = customPresets.some( + (preset) => preset.name === name, + ); + if (!hasCustomPreset) { + return; + } + const builtinPreset = BUILTIN_PRESETS.find((preset) => preset.name === name); + const fallbackPreset = + builtinPreset ?? + BUILTIN_PRESETS.find((preset) => preset.name === "Default") ?? + null; + if ( + activePreset === name && + fallbackPreset && + modelRequiresTrustRemoteCode && + !(fallbackPreset.params.trustRemoteCode ?? false) + ) { + toast.warning("Reset would turn custom code off", { + description: + "The current model needs custom code enabled to load. Keep it on for this model.", + }); return; } setCustomPresets((prev) => { @@ -402,17 +680,33 @@ export function ChatSettingsPanel({ return next; }); if (activePreset === name) { - setActivePreset("Default"); - if (canUseStorage()) { - try { - localStorage.setItem(CHAT_ACTIVE_PRESET_KEY, "Default"); - } catch { - // ignore + if (fallbackPreset) { + onParamsChange({ + ...fallbackPreset.params, + checkpoint: params.checkpoint, + }); + setActivePreset(fallbackPreset.name); + if (canUseStorage()) { + try { + localStorage.setItem(CHAT_ACTIVE_PRESET_KEY, fallbackPreset.name); + } catch { + // ignore + } } } } } + function openSystemPromptEditor() { + setSystemPromptDraft(params.systemPrompt); + setSystemPromptEditorOpen(true); + } + + function saveSystemPromptEditor() { + set("systemPrompt")(systemPromptDraft); + setSystemPromptEditorOpen(false); + } + useEffect(() => { if (presets.some((preset) => preset.name === activePreset)) return; setActivePreset("Default"); @@ -425,6 +719,28 @@ export function ChatSettingsPanel({ } }, [activePreset, presets]); + useEffect(() => { + setPresetNameInput(activePreset); + }, [activePreset]); + + useEffect(() => { + if (!open) { + setSystemPromptEditorOpen(false); + } + }, [open]); + + useLayoutEffect(() => { + const el = presetControlRowRef.current; + if (!el || !open) return; + const measure = () => { + setPresetMenuWidthPx(el.getBoundingClientRect().width); + }; + measure(); + const ro = new ResizeObserver(measure); + ro.observe(el); + return () => ro.disconnect(); + }, [open]); + const settingsContent = ( <>
@@ -440,52 +756,138 @@ export function ChatSettingsPanel({
{/* mt-4 matches the Playground sidebar gap (SidebarHeader py-3 + SidebarGroup pt-1) */}
-
- - - +
+
+ + + setPresetNameInput(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter" && presetSaveState.canSubmit) { + e.preventDefault(); + savePresetWithName(presetNameInput); + } + }} + placeholder="Preset name" + maxLength={80} + autoComplete="off" + className={cn( + "!h-8 min-h-0 min-w-0 self-stretch !pl-2.5 !pr-2 pt-1 pb-1 text-sm leading-10 md:text-sm", + presetSaveState.isSaveReady && + "text-foreground placeholder:text-primary/45", + )} + aria-label="Inference preset name" + /> + + + + + + + + + + {presets.map((p) => ( + applyPreset(p.name)} + > + {p.name} + + ))} + + +
+
+ + +
- +
+ + +