fix(studio): prevent small models from stalling on tool-calling tasks (#4769)

* fix(studio): prevent small models from stalling on tool-calling tasks

Small GGUF models (< 9B params) in "Think, Search, Code" mode would
often describe what they planned to do ("Let me create this dashboard")
and then stop generating without ever calling a tool.

Three changes:

1. Simplify web_tips for small models: remove the "fetch its full content
   by calling web_search with the url parameter" guidance for models < 9B.
   This multi-step instruction causes small models to plan elaborate
   search-then-fetch-then-code sequences they cannot reliably execute.

2. Add "always call tools directly" imperative to the system prompt nudge
   so models act immediately instead of narrating their intentions.

3. Add plan-without-action re-prompt in the agentic loop: when the model
   emits planning text (matching patterns like "let me", "I'll", etc.)
   without calling any tool, inject a nudge asking it to call the tool
   and continue the loop. Capped at 1 re-prompt per request
   (see _MAX_REPROMPTS).

Benchmarked with Qwen3.5-4B-GGUF (N=5 trials per variant):
- Baseline: 40% of requests had any tool call
- Combined fix: 100% of requests had at least one tool call

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-02 02:11:07 -07:00 committed by GitHub
parent dc0729aadf
commit e4d1499230
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 189 additions and 53 deletions

View file

@ -27,6 +27,52 @@ import httpx
logger = get_logger(__name__)
# ── Pre-compiled patterns for plan-without-action re-prompt ──
# Forward-looking intent signals that indicate the model is
# describing what it *will* do rather than giving a final answer.
_INTENT_SIGNAL = re.compile(
    r"(?i)("
    # Direct intent: "I'll ...", "I will ...", "Let me ...", "I am going to ..."
    # Handles both straight and curly apostrophes.
    # Excludes "I can", "I should", "I want to", "let's" which
    # appear frequently in direct answers / explanations.
    r"\b(i['\u2019](ll|m going to|m gonna)|i am (going to|gonna)|i will|i shall|let me|allow me)\b"
    r"|"
    # Step/plan framing: "First ...", "Step 1:", "Here's my plan"
    r"\b(?:first\b|step \d+:?|here['\u2019]?s (?:my |the |a )?(?:plan|approach))"
    r"|"
    # "Now I" / "Next I" patterns
    r"\b(?:now i|next i)\b"
    r")"
)
# Hard cap on plan-without-action re-prompts injected per request.
_MAX_REPROMPTS = 1
# Responses at or above this length are treated as substantive
# answers rather than stalled planning, so no re-prompt is injected.
_REPROMPT_MAX_CHARS = 500
# ── Pre-compiled patterns for GGUF shard detection ───────────
# Full form captures (prefix, shard_index, shard_total); the short
# form only captures the shared prefix before -NNNNN-of-NNNNN.
_SHARD_FULL_RE = re.compile(r"^(.*)-(\d{5})-of-(\d{5})\.gguf$")
_SHARD_RE = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
# Model size extraction (shared with routes/inference.py)
from utils.models import extract_model_size_b as _extract_model_size_b
# ── Pre-compiled patterns for tool XML stripping ─────────────
# Closed-tag patterns only remove fully terminated tool XML; the
# "ALL" variants additionally strip an unterminated trailing tag
# (<tool_call>/<function=...> with no closer) and are applied only
# to the final accumulated text.
_TOOL_CLOSED_PATS = [
    re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL),
    re.compile(r"<function=\w+>.*?</function>", re.DOTALL),
]
_TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [
    re.compile(r"<tool_call>.*$", re.DOTALL),
    re.compile(r"<function=\w+>.*$", re.DOTALL),
]
# ── Pre-compiled patterns for tool-call XML parsing ──────────
# Closing tags (</tool_call>, </function>, </parameter>) are
# optional in the parser since models frequently omit them; the
# *_CLOSE_RE patterns trim a trailing closer when it is present.
_TC_JSON_START_RE = re.compile(r"<tool_call>\s*\{")
_TC_FUNC_START_RE = re.compile(r"<function=(\w+)>\s*")
_TC_END_TAG_RE = re.compile(r"</tool_call>")
_TC_FUNC_CLOSE_RE = re.compile(r"\s*</function>\s*$")
_TC_PARAM_START_RE = re.compile(r"<parameter=(\w+)>\s*")
_TC_PARAM_CLOSE_RE = re.compile(r"\s*</parameter>\s*$")
class LlamaCppBackend:
"""
@ -242,14 +288,11 @@ class LlamaCppBackend:
@staticmethod
def _get_gguf_size_bytes(model_path: str) -> int:
"""Get total GGUF size in bytes, including split shards."""
import re
main = Path(model_path)
total = main.stat().st_size
# Check for split shards (e.g., model-00001-of-00003.gguf)
shard_pat = re.compile(r"^(.*)-(\d{5})-of-(\d{5})\.gguf$")
m = shard_pat.match(main.name)
m = _SHARD_FULL_RE.match(main.name)
if m:
prefix, _, num_total = m.group(1), m.group(2), m.group(3)
sibling_pat = re.compile(
@ -539,8 +582,6 @@ class LlamaCppBackend:
Returns (first_shard_filename, total_size_bytes) or None if nothing fits.
"""
import re
try:
from huggingface_hub import get_paths_info, list_repo_files
@ -556,10 +597,9 @@ class LlamaCppBackend:
size_map = {p.path: (p.size or 0) for p in path_infos}
# Group files by variant: shards share a prefix before -NNNNN-of-NNNNN
shard_pat = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
variants: dict[str, list[str]] = {}
for f in gguf_files:
m = shard_pat.match(f)
m = _SHARD_RE.match(f)
key = m.group(1) if m else f
variants.setdefault(key, []).append(f)
@ -810,7 +850,6 @@ class LlamaCppBackend:
gguf_extra_shards: list[str] = []
if hf_variant:
try:
import re
from huggingface_hub import list_repo_files
files = list_repo_files(hf_repo, token = hf_token)
@ -825,11 +864,10 @@ class LlamaCppBackend:
)
if gguf_files:
gguf_filename = gguf_files[0]
shard_pat = re.compile(r"^(.*)-\d{5}-of-(\d{5})\.gguf$")
m = shard_pat.match(gguf_filename)
m = _SHARD_FULL_RE.match(gguf_filename)
if m:
prefix = m.group(1)
total = m.group(2)
total = m.group(3)
sibling_pat = re.compile(
r"^"
+ re.escape(prefix)
@ -886,10 +924,7 @@ class LlamaCppBackend:
f"falling back to {fallback_file} ({fallback_size / (1024**3):.1f} GB)"
)
gguf_filename = fallback_file
import re as _re
_shard_pat = _re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
_m = _shard_pat.match(gguf_filename)
_m = _SHARD_RE.match(gguf_filename)
_prefix = _m.group(1) if _m else None
if _prefix:
gguf_extra_shards = sorted(
@ -1292,17 +1327,12 @@ class LlamaCppBackend:
# Qwen3.5 models below 9B (0.8B, 2B, 4B) disable thinking by default.
# Only 9B and larger enable thinking.
if self._supports_reasoning:
import re
thinking_default = True
mid = (model_identifier or "").lower()
if "qwen3.5" in mid:
# Extract size like "0.8b", "4b", "35b" etc.
size_match = re.search(r"(\d+\.?\d*)\s*b", mid)
if size_match:
size_val = float(size_match.group(1))
if size_val < 9:
thinking_default = False
size_val = _extract_model_size_b(mid)
if size_val is not None and size_val < 9:
thinking_default = False
self._reasoning_default = thinking_default
cmd.extend(
[
@ -1775,13 +1805,11 @@ class LlamaCppBackend:
Closing tags (</tool_call>, </function>, </parameter>) are all optional
since models frequently omit them.
"""
import re
tool_calls = []
# Pattern 1: JSON inside <tool_call> tags.
# Use balanced-brace extraction that skips braces inside JSON strings.
for m in re.finditer(r"<tool_call>\s*\{", content):
for m in _TC_JSON_START_RE.finditer(content):
brace_start = m.end() - 1 # position of the opening {
depth, i = 0, brace_start
in_string = False
@ -1831,7 +1859,7 @@ class LlamaCppBackend:
# boundaries. We avoid using </function> as a boundary because
# code parameter values can contain that literal string.
# After extracting, we trim a trailing </function> if present.
func_starts = list(re.finditer(r"<function=(\w+)>\s*", content))
func_starts = list(_TC_FUNC_START_RE.finditer(content))
for idx, fm in enumerate(func_starts):
func_name = fm.group(1)
body_start = fm.end()
@ -1841,7 +1869,7 @@ class LlamaCppBackend:
if idx + 1 < len(func_starts)
else len(content)
)
end_tag = re.search(r"</tool_call>", content[body_start:])
end_tag = _TC_END_TAG_RE.search(content[body_start:])
if end_tag:
body_end = body_start + end_tag.start()
else:
@ -1849,20 +1877,20 @@ class LlamaCppBackend:
body_end = min(body_end, next_func)
body = content[body_start:body_end]
# Trim trailing </function> if present (it's the real closing tag)
body = re.sub(r"\s*</function>\s*$", "", body)
body = _TC_FUNC_CLOSE_RE.sub("", body)
# Step 2: Extract parameters from body.
# For single-parameter functions (the common case: code, command,
# query), use body end as the only boundary to avoid false matches
# on </parameter> inside code strings.
arguments = {}
param_starts = list(re.finditer(r"<parameter=(\w+)>\s*", body))
param_starts = list(_TC_PARAM_START_RE.finditer(body))
if len(param_starts) == 1:
# Single parameter: value is everything from after the tag
# to end of body, trimming any trailing </parameter>.
pm = param_starts[0]
val = body[pm.end() :]
val = re.sub(r"\s*</parameter>\s*$", "", val)
val = _TC_PARAM_CLOSE_RE.sub("", val)
arguments[pm.group(1)] = val.strip()
else:
for pidx, pm in enumerate(param_starts):
@ -1876,7 +1904,7 @@ class LlamaCppBackend:
)
val = body[val_start:next_param]
# Trim trailing </parameter> if present
val = re.sub(r"\s*</parameter>\s*$", "", val)
val = _TC_PARAM_CLOSE_RE.sub("", val)
arguments[param_name] = val.strip()
tc = {
@ -2249,22 +2277,10 @@ class LlamaCppBackend:
_accumulated_predicted_ms = 0.0
_accumulated_predicted_n = 0
# ── Shared patterns for stripping tool XML from streamed content ──
import re as _re_tool
_TOOL_CLOSED_PATTERNS = [
_re_tool.compile(r"<tool_call>.*?</tool_call>", _re_tool.DOTALL),
_re_tool.compile(r"<function=\w+>.*?</function>", _re_tool.DOTALL),
]
_TOOL_ALL_PATTERNS = _TOOL_CLOSED_PATTERNS + [
_re_tool.compile(r"<tool_call>.*$", _re_tool.DOTALL),
_re_tool.compile(r"<function=\w+>.*$", _re_tool.DOTALL),
]
def _strip_tool_markup(text: str, *, final: bool = False) -> str:
if not auto_heal_tool_calls:
return text
patterns = _TOOL_ALL_PATTERNS if final else _TOOL_CLOSED_PATTERNS
patterns = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS
for pat in patterns:
text = pat.sub("", text)
return text.strip() if final else text
@ -2284,7 +2300,19 @@ class LlamaCppBackend:
# identical call succeeded).
_tool_call_history: list[tuple[str, bool]] = [] # (key, failed)
for iteration in range(max_tool_iterations):
# ── Re-prompt on plan-without-action ─────────────────
# When the model describes what it intends to do (forward-looking
# language) without actually calling a tool, re-prompt once.
# Only triggers on responses that signal intent/planning -- a
# direct answer like "4" or "Hello!" will not match.
# Pattern is compiled once at module level (_INTENT_SIGNAL).
_reprompt_count = 0
# Reserve extra iterations for re-prompts so they don't
# consume the caller's tool-call budget. Only add the
# extra slot when tool iterations are actually allowed.
_extra = _MAX_REPROMPTS if max_tool_iterations > 0 else 0
for iteration in range(max_tool_iterations + _extra):
if cancel_event is not None and cancel_event.is_set():
return
@ -2595,6 +2623,56 @@ class LlamaCppBackend:
content_accum,
)
if not _safety_tc:
# ── Re-prompt on plan-without-action ──
# If the model described what it intends to do
# (forward-looking language) without calling any
# tool, nudge it to act. Only fires once per
# request and only on short responses that
# contain intent signals -- a direct answer
# like "4" or "Hello!" won't trigger this.
# Use content if available, otherwise fall back
# to reasoning text (reasoning-only stalls).
_stripped = content_accum.strip()
if not _stripped:
_stripped = reasoning_accum.strip()
if (
tools
and _reprompt_count < _MAX_REPROMPTS
and 0 < len(_stripped) < _REPROMPT_MAX_CHARS
and _INTENT_SIGNAL.search(_stripped)
):
_reprompt_count += 1
logger.info(
f"Re-prompt {_reprompt_count}/{_MAX_REPROMPTS}: "
f"model responded without calling tools "
f"({len(_stripped)} chars)"
)
conversation.append(
{
"role": "assistant",
"content": _stripped,
}
)
conversation.append(
{
"role": "user",
"content": (
"Please use the available tools to complete "
"the task instead of describing what to do."
),
}
)
# Accumulate tokens and timing from this iteration
_fu_r = _iter_usage or {}
_accumulated_completion_tokens += _fu_r.get(
"completion_tokens", 0
)
_it_r = _iter_timings or {}
_accumulated_predicted_ms += _it_r.get("predicted_ms", 0)
_accumulated_predicted_n += _it_r.get("predicted_n", 0)
yield {"type": "status", "text": ""}
continue
# Content was already streamed. Yield metadata.
yield {"type": "status", "text": ""}
_fu = _iter_usage or {}

View file

@ -21,6 +21,9 @@ import threading
import re as _re
# Model size extraction (shared with core/inference/llama_cpp.py)
from utils.models import extract_model_size_b as _extract_model_size_b
def _friendly_error(exc: Exception) -> str:
"""Extract a user-friendly message from known llama-server errors."""
@ -90,6 +93,12 @@ from datetime import date as _date
router = APIRouter()
# Appended to tool-use nudge to discourage plan-without-action
_TOOL_ACTION_NUDGE = (
" Always call tools directly."
" Never describe what you plan to do -- just call the tool immediately."
)
# Regex for stripping leaked tool-call XML from assistant messages/stream
_TOOL_XML_RE = _re.compile(
r"<tool_call>.*?</tool_call>|<function=\w+>.*?</function>",
@ -1095,12 +1104,20 @@ async def openai_chat_completions(
_date_line = f"The current date is {_date.today().isoformat()}."
_web_tips = (
"When you search and find a relevant URL in the results, "
"fetch its full content by calling web_search with the url parameter. "
"Do not repeat the same search query. If a search returns "
"no useful results, try rephrasing or fetching a result URL directly."
)
# Small models (<9B) struggle with multi-step search plans,
# so simplify the web tips to avoid plan-then-stall behavior.
_model_size_b = _extract_model_size_b(model_name)
_is_small_model = _model_size_b is not None and _model_size_b < 9
if _is_small_model:
_web_tips = "Do not repeat the same search query."
else:
_web_tips = (
"When you search and find a relevant URL in the results, "
"fetch its full content by calling web_search with the url parameter. "
"Do not repeat the same search query. If a search returns "
"no useful results, try rephrasing or fetching a result URL directly."
)
_code_tips = (
"Use code execution for math, calculations, data processing, "
"or to parse and analyze information from tool results."
@ -1132,6 +1149,7 @@ async def openai_chat_completions(
_nudge = ""
if _nudge:
_nudge += _TOOL_ACTION_NUDGE
# Append nudge to system prompt (preserve user's prompt)
if system_prompt:
system_prompt = system_prompt.rstrip() + "\n\n" + _nudge
@ -1208,7 +1226,14 @@ async def openai_chat_completions(
break
if event["type"] == "status":
# Empty status marks an iteration boundary
# in the GGUF tool loop (e.g. after a
# re-prompt). Reset the cumulative cursor
# so the next assistant turn streams cleanly.
if not event["text"]:
prev_text = ""
# Emit tool status as a custom SSE event
# (including empty ones to clear UI badges)
status_data = json.dumps(
{
"type": "tool_status",

View file

@ -19,6 +19,7 @@ from .model_config import (
get_base_model_from_lora,
load_model_config,
list_gguf_variants,
extract_model_size_b,
MODEL_NAME_MAPPING,
UI_STATUS_INDICATORS,
)
@ -38,6 +39,7 @@ __all__ = [
"get_base_model_from_lora",
"load_model_config",
"list_gguf_variants",
"extract_model_size_b",
"MODEL_NAME_MAPPING",
"UI_STATUS_INDICATORS",
"scan_checkpoints",

View file

@ -31,6 +31,37 @@ import yaml
logger = get_logger(__name__)
# ── Model size extraction ────────────────────────────────────
import re as _re
_MODEL_SIZE_RE = _re.compile(
r"(?:^|[-_/])(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
# MoE active-parameter pattern: matches "A3B", "A3.5B", etc.
_ACTIVE_SIZE_RE = _re.compile(
r"(?:^|[-_/])a(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
)
def extract_model_size_b(model_id: str) -> float | None:
"""Extract model size in billions from a model identifier.
Prefers MoE active-parameter notation (e.g. ``A3B`` in
``Qwen3.5-35B-A3B``) over the total parameter count.
Handles both ``B`` (billions) and ``M`` (millions) suffixes.
"""
mid = (model_id or "").lower()
active = _ACTIVE_SIZE_RE.search(mid)
if active:
val = float(active.group(1))
return val / 1000.0 if active.group(2).lower() == "m" else val
size = _MODEL_SIZE_RE.search(mid)
if not size:
return None
val = float(size.group(1))
return val / 1000.0 if size.group(2).lower() == "m" else val
# Model name mapping: maps all equivalent model names to their canonical YAML config file
# Format: "canonical_model_name.yaml": [list of all equivalent model names]
# Based on the model mapper provided - canonical filename is based on the first model name in the mapper