mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
fix(studio): prevent small models from stalling on tool-calling tasks (#4769)
* fix(studio): prevent small models from stalling on tool-calling tasks
Small GGUF models (< 9B params) in "Think, Search, Code" mode would
often describe what they planned to do ("Let me create this dashboard")
and then stop generating without ever calling a tool.
Three changes:
1. Simplify web_tips for small models: remove the "fetch its full content
by calling web_search with the url parameter" guidance for models < 9B.
This multi-step instruction causes small models to plan elaborate
search-then-fetch-then-code sequences they cannot reliably execute.
2. Add "always call tools directly" imperative to the system prompt nudge
so models act immediately instead of narrating their intentions.
3. Add plan-without-action re-prompt in the agentic loop: when the model
emits planning text (matching patterns like "let me", "I'll", etc.)
without calling any tool, inject a nudge asking it to call the tool
and continue the loop. Capped at 2 re-prompts per request.
Benchmarked with Qwen3.5-4B-GGUF (N=5 trials per variant):
- Baseline: 40% of requests had any tool call
- Combined fix: 100% of requests had at least one tool call
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dc0729aadf
commit
e4d1499230
4 changed files with 189 additions and 53 deletions
|
|
@ -27,6 +27,52 @@ import httpx
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ── Pre-compiled patterns for plan-without-action re-prompt ──
|
||||
# Forward-looking intent signals that indicate the model is
|
||||
# describing what it *will* do rather than giving a final answer.
|
||||
_INTENT_SIGNAL = re.compile(
|
||||
r"(?i)("
|
||||
# Direct intent: "I'll ...", "I will ...", "Let me ...", "I am going to ..."
|
||||
# Handles both straight and curly apostrophes.
|
||||
# Excludes "I can", "I should", "I want to", "let's" which
|
||||
# appear frequently in direct answers / explanations.
|
||||
r"\b(i['\u2019](ll|m going to|m gonna)|i am (going to|gonna)|i will|i shall|let me|allow me)\b"
|
||||
r"|"
|
||||
# Step/plan framing: "First ...", "Step 1:", "Here's my plan"
|
||||
r"\b(?:first\b|step \d+:?|here['\u2019]?s (?:my |the |a )?(?:plan|approach))"
|
||||
r"|"
|
||||
# "Now I" / "Next I" patterns
|
||||
r"\b(?:now i|next i)\b"
|
||||
r")"
|
||||
)
|
||||
_MAX_REPROMPTS = 1
|
||||
_REPROMPT_MAX_CHARS = 500
|
||||
|
||||
# ── Pre-compiled patterns for GGUF shard detection ───────────
|
||||
_SHARD_FULL_RE = re.compile(r"^(.*)-(\d{5})-of-(\d{5})\.gguf$")
|
||||
_SHARD_RE = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
|
||||
|
||||
# Model size extraction (shared with routes/inference.py)
|
||||
from utils.models import extract_model_size_b as _extract_model_size_b
|
||||
|
||||
# ── Pre-compiled patterns for tool XML stripping ─────────────
|
||||
_TOOL_CLOSED_PATS = [
|
||||
re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL),
|
||||
re.compile(r"<function=\w+>.*?</function>", re.DOTALL),
|
||||
]
|
||||
_TOOL_ALL_PATS = _TOOL_CLOSED_PATS + [
|
||||
re.compile(r"<tool_call>.*$", re.DOTALL),
|
||||
re.compile(r"<function=\w+>.*$", re.DOTALL),
|
||||
]
|
||||
|
||||
# ── Pre-compiled patterns for tool-call XML parsing ──────────
|
||||
_TC_JSON_START_RE = re.compile(r"<tool_call>\s*\{")
|
||||
_TC_FUNC_START_RE = re.compile(r"<function=(\w+)>\s*")
|
||||
_TC_END_TAG_RE = re.compile(r"</tool_call>")
|
||||
_TC_FUNC_CLOSE_RE = re.compile(r"\s*</function>\s*$")
|
||||
_TC_PARAM_START_RE = re.compile(r"<parameter=(\w+)>\s*")
|
||||
_TC_PARAM_CLOSE_RE = re.compile(r"\s*</parameter>\s*$")
|
||||
|
||||
|
||||
class LlamaCppBackend:
|
||||
"""
|
||||
|
|
@ -242,14 +288,11 @@ class LlamaCppBackend:
|
|||
@staticmethod
|
||||
def _get_gguf_size_bytes(model_path: str) -> int:
|
||||
"""Get total GGUF size in bytes, including split shards."""
|
||||
import re
|
||||
|
||||
main = Path(model_path)
|
||||
total = main.stat().st_size
|
||||
|
||||
# Check for split shards (e.g., model-00001-of-00003.gguf)
|
||||
shard_pat = re.compile(r"^(.*)-(\d{5})-of-(\d{5})\.gguf$")
|
||||
m = shard_pat.match(main.name)
|
||||
m = _SHARD_FULL_RE.match(main.name)
|
||||
if m:
|
||||
prefix, _, num_total = m.group(1), m.group(2), m.group(3)
|
||||
sibling_pat = re.compile(
|
||||
|
|
@ -539,8 +582,6 @@ class LlamaCppBackend:
|
|||
|
||||
Returns (first_shard_filename, total_size_bytes) or None if nothing fits.
|
||||
"""
|
||||
import re
|
||||
|
||||
try:
|
||||
from huggingface_hub import get_paths_info, list_repo_files
|
||||
|
||||
|
|
@ -556,10 +597,9 @@ class LlamaCppBackend:
|
|||
size_map = {p.path: (p.size or 0) for p in path_infos}
|
||||
|
||||
# Group files by variant: shards share a prefix before -NNNNN-of-NNNNN
|
||||
shard_pat = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
|
||||
variants: dict[str, list[str]] = {}
|
||||
for f in gguf_files:
|
||||
m = shard_pat.match(f)
|
||||
m = _SHARD_RE.match(f)
|
||||
key = m.group(1) if m else f
|
||||
variants.setdefault(key, []).append(f)
|
||||
|
||||
|
|
@ -810,7 +850,6 @@ class LlamaCppBackend:
|
|||
gguf_extra_shards: list[str] = []
|
||||
if hf_variant:
|
||||
try:
|
||||
import re
|
||||
from huggingface_hub import list_repo_files
|
||||
|
||||
files = list_repo_files(hf_repo, token = hf_token)
|
||||
|
|
@ -825,11 +864,10 @@ class LlamaCppBackend:
|
|||
)
|
||||
if gguf_files:
|
||||
gguf_filename = gguf_files[0]
|
||||
shard_pat = re.compile(r"^(.*)-\d{5}-of-(\d{5})\.gguf$")
|
||||
m = shard_pat.match(gguf_filename)
|
||||
m = _SHARD_FULL_RE.match(gguf_filename)
|
||||
if m:
|
||||
prefix = m.group(1)
|
||||
total = m.group(2)
|
||||
total = m.group(3)
|
||||
sibling_pat = re.compile(
|
||||
r"^"
|
||||
+ re.escape(prefix)
|
||||
|
|
@ -886,10 +924,7 @@ class LlamaCppBackend:
|
|||
f"falling back to {fallback_file} ({fallback_size / (1024**3):.1f} GB)"
|
||||
)
|
||||
gguf_filename = fallback_file
|
||||
import re as _re
|
||||
|
||||
_shard_pat = _re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$")
|
||||
_m = _shard_pat.match(gguf_filename)
|
||||
_m = _SHARD_RE.match(gguf_filename)
|
||||
_prefix = _m.group(1) if _m else None
|
||||
if _prefix:
|
||||
gguf_extra_shards = sorted(
|
||||
|
|
@ -1292,17 +1327,12 @@ class LlamaCppBackend:
|
|||
# Qwen3.5 models below 9B (0.8B, 2B, 4B) disable thinking by default.
|
||||
# Only 9B and larger enable thinking.
|
||||
if self._supports_reasoning:
|
||||
import re
|
||||
|
||||
thinking_default = True
|
||||
mid = (model_identifier or "").lower()
|
||||
if "qwen3.5" in mid:
|
||||
# Extract size like "0.8b", "4b", "35b" etc.
|
||||
size_match = re.search(r"(\d+\.?\d*)\s*b", mid)
|
||||
if size_match:
|
||||
size_val = float(size_match.group(1))
|
||||
if size_val < 9:
|
||||
thinking_default = False
|
||||
size_val = _extract_model_size_b(mid)
|
||||
if size_val is not None and size_val < 9:
|
||||
thinking_default = False
|
||||
self._reasoning_default = thinking_default
|
||||
cmd.extend(
|
||||
[
|
||||
|
|
@ -1775,13 +1805,11 @@ class LlamaCppBackend:
|
|||
Closing tags (</tool_call>, </function>, </parameter>) are all optional
|
||||
since models frequently omit them.
|
||||
"""
|
||||
import re
|
||||
|
||||
tool_calls = []
|
||||
|
||||
# Pattern 1: JSON inside <tool_call> tags.
|
||||
# Use balanced-brace extraction that skips braces inside JSON strings.
|
||||
for m in re.finditer(r"<tool_call>\s*\{", content):
|
||||
for m in _TC_JSON_START_RE.finditer(content):
|
||||
brace_start = m.end() - 1 # position of the opening {
|
||||
depth, i = 0, brace_start
|
||||
in_string = False
|
||||
|
|
@ -1831,7 +1859,7 @@ class LlamaCppBackend:
|
|||
# boundaries. We avoid using </function> as a boundary because
|
||||
# code parameter values can contain that literal string.
|
||||
# After extracting, we trim a trailing </function> if present.
|
||||
func_starts = list(re.finditer(r"<function=(\w+)>\s*", content))
|
||||
func_starts = list(_TC_FUNC_START_RE.finditer(content))
|
||||
for idx, fm in enumerate(func_starts):
|
||||
func_name = fm.group(1)
|
||||
body_start = fm.end()
|
||||
|
|
@ -1841,7 +1869,7 @@ class LlamaCppBackend:
|
|||
if idx + 1 < len(func_starts)
|
||||
else len(content)
|
||||
)
|
||||
end_tag = re.search(r"</tool_call>", content[body_start:])
|
||||
end_tag = _TC_END_TAG_RE.search(content[body_start:])
|
||||
if end_tag:
|
||||
body_end = body_start + end_tag.start()
|
||||
else:
|
||||
|
|
@ -1849,20 +1877,20 @@ class LlamaCppBackend:
|
|||
body_end = min(body_end, next_func)
|
||||
body = content[body_start:body_end]
|
||||
# Trim trailing </function> if present (it's the real closing tag)
|
||||
body = re.sub(r"\s*</function>\s*$", "", body)
|
||||
body = _TC_FUNC_CLOSE_RE.sub("", body)
|
||||
|
||||
# Step 2: Extract parameters from body.
|
||||
# For single-parameter functions (the common case: code, command,
|
||||
# query), use body end as the only boundary to avoid false matches
|
||||
# on </parameter> inside code strings.
|
||||
arguments = {}
|
||||
param_starts = list(re.finditer(r"<parameter=(\w+)>\s*", body))
|
||||
param_starts = list(_TC_PARAM_START_RE.finditer(body))
|
||||
if len(param_starts) == 1:
|
||||
# Single parameter: value is everything from after the tag
|
||||
# to end of body, trimming any trailing </parameter>.
|
||||
pm = param_starts[0]
|
||||
val = body[pm.end() :]
|
||||
val = re.sub(r"\s*</parameter>\s*$", "", val)
|
||||
val = _TC_PARAM_CLOSE_RE.sub("", val)
|
||||
arguments[pm.group(1)] = val.strip()
|
||||
else:
|
||||
for pidx, pm in enumerate(param_starts):
|
||||
|
|
@ -1876,7 +1904,7 @@ class LlamaCppBackend:
|
|||
)
|
||||
val = body[val_start:next_param]
|
||||
# Trim trailing </parameter> if present
|
||||
val = re.sub(r"\s*</parameter>\s*$", "", val)
|
||||
val = _TC_PARAM_CLOSE_RE.sub("", val)
|
||||
arguments[param_name] = val.strip()
|
||||
|
||||
tc = {
|
||||
|
|
@ -2249,22 +2277,10 @@ class LlamaCppBackend:
|
|||
_accumulated_predicted_ms = 0.0
|
||||
_accumulated_predicted_n = 0
|
||||
|
||||
# ── Shared patterns for stripping tool XML from streamed content ──
|
||||
import re as _re_tool
|
||||
|
||||
_TOOL_CLOSED_PATTERNS = [
|
||||
_re_tool.compile(r"<tool_call>.*?</tool_call>", _re_tool.DOTALL),
|
||||
_re_tool.compile(r"<function=\w+>.*?</function>", _re_tool.DOTALL),
|
||||
]
|
||||
_TOOL_ALL_PATTERNS = _TOOL_CLOSED_PATTERNS + [
|
||||
_re_tool.compile(r"<tool_call>.*$", _re_tool.DOTALL),
|
||||
_re_tool.compile(r"<function=\w+>.*$", _re_tool.DOTALL),
|
||||
]
|
||||
|
||||
def _strip_tool_markup(text: str, *, final: bool = False) -> str:
|
||||
if not auto_heal_tool_calls:
|
||||
return text
|
||||
patterns = _TOOL_ALL_PATTERNS if final else _TOOL_CLOSED_PATTERNS
|
||||
patterns = _TOOL_ALL_PATS if final else _TOOL_CLOSED_PATS
|
||||
for pat in patterns:
|
||||
text = pat.sub("", text)
|
||||
return text.strip() if final else text
|
||||
|
|
@ -2284,7 +2300,19 @@ class LlamaCppBackend:
|
|||
# identical call succeeded).
|
||||
_tool_call_history: list[tuple[str, bool]] = [] # (key, failed)
|
||||
|
||||
for iteration in range(max_tool_iterations):
|
||||
# ── Re-prompt on plan-without-action ─────────────────
|
||||
# When the model describes what it intends to do (forward-looking
|
||||
# language) without actually calling a tool, re-prompt once.
|
||||
# Only triggers on responses that signal intent/planning -- a
|
||||
# direct answer like "4" or "Hello!" will not match.
|
||||
# Pattern is compiled once at module level (_INTENT_SIGNAL).
|
||||
_reprompt_count = 0
|
||||
|
||||
# Reserve extra iterations for re-prompts so they don't
|
||||
# consume the caller's tool-call budget. Only add the
|
||||
# extra slot when tool iterations are actually allowed.
|
||||
_extra = _MAX_REPROMPTS if max_tool_iterations > 0 else 0
|
||||
for iteration in range(max_tool_iterations + _extra):
|
||||
if cancel_event is not None and cancel_event.is_set():
|
||||
return
|
||||
|
||||
|
|
@ -2595,6 +2623,56 @@ class LlamaCppBackend:
|
|||
content_accum,
|
||||
)
|
||||
if not _safety_tc:
|
||||
# ── Re-prompt on plan-without-action ──
|
||||
# If the model described what it intends to do
|
||||
# (forward-looking language) without calling any
|
||||
# tool, nudge it to act. Only fires once per
|
||||
# request and only on short responses that
|
||||
# contain intent signals -- a direct answer
|
||||
# like "4" or "Hello!" won't trigger this.
|
||||
# Use content if available, otherwise fall back
|
||||
# to reasoning text (reasoning-only stalls).
|
||||
_stripped = content_accum.strip()
|
||||
if not _stripped:
|
||||
_stripped = reasoning_accum.strip()
|
||||
if (
|
||||
tools
|
||||
and _reprompt_count < _MAX_REPROMPTS
|
||||
and 0 < len(_stripped) < _REPROMPT_MAX_CHARS
|
||||
and _INTENT_SIGNAL.search(_stripped)
|
||||
):
|
||||
_reprompt_count += 1
|
||||
logger.info(
|
||||
f"Re-prompt {_reprompt_count}/{_MAX_REPROMPTS}: "
|
||||
f"model responded without calling tools "
|
||||
f"({len(_stripped)} chars)"
|
||||
)
|
||||
conversation.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": _stripped,
|
||||
}
|
||||
)
|
||||
conversation.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Please use the available tools to complete "
|
||||
"the task instead of describing what to do."
|
||||
),
|
||||
}
|
||||
)
|
||||
# Accumulate tokens and timing from this iteration
|
||||
_fu_r = _iter_usage or {}
|
||||
_accumulated_completion_tokens += _fu_r.get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
_it_r = _iter_timings or {}
|
||||
_accumulated_predicted_ms += _it_r.get("predicted_ms", 0)
|
||||
_accumulated_predicted_n += _it_r.get("predicted_n", 0)
|
||||
yield {"type": "status", "text": ""}
|
||||
continue
|
||||
|
||||
# Content was already streamed. Yield metadata.
|
||||
yield {"type": "status", "text": ""}
|
||||
_fu = _iter_usage or {}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ import threading
|
|||
|
||||
import re as _re
|
||||
|
||||
# Model size extraction (shared with core/inference/llama_cpp.py)
|
||||
from utils.models import extract_model_size_b as _extract_model_size_b
|
||||
|
||||
|
||||
def _friendly_error(exc: Exception) -> str:
|
||||
"""Extract a user-friendly message from known llama-server errors."""
|
||||
|
|
@ -90,6 +93,12 @@ from datetime import date as _date
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# Appended to tool-use nudge to discourage plan-without-action
|
||||
_TOOL_ACTION_NUDGE = (
|
||||
" Always call tools directly."
|
||||
" Never describe what you plan to do -- just call the tool immediately."
|
||||
)
|
||||
|
||||
# Regex for stripping leaked tool-call XML from assistant messages/stream
|
||||
_TOOL_XML_RE = _re.compile(
|
||||
r"<tool_call>.*?</tool_call>|<function=\w+>.*?</function>",
|
||||
|
|
@ -1095,12 +1104,20 @@ async def openai_chat_completions(
|
|||
|
||||
_date_line = f"The current date is {_date.today().isoformat()}."
|
||||
|
||||
_web_tips = (
|
||||
"When you search and find a relevant URL in the results, "
|
||||
"fetch its full content by calling web_search with the url parameter. "
|
||||
"Do not repeat the same search query. If a search returns "
|
||||
"no useful results, try rephrasing or fetching a result URL directly."
|
||||
)
|
||||
# Small models (<9B) struggle with multi-step search plans,
|
||||
# so simplify the web tips to avoid plan-then-stall behavior.
|
||||
_model_size_b = _extract_model_size_b(model_name)
|
||||
_is_small_model = _model_size_b is not None and _model_size_b < 9
|
||||
|
||||
if _is_small_model:
|
||||
_web_tips = "Do not repeat the same search query."
|
||||
else:
|
||||
_web_tips = (
|
||||
"When you search and find a relevant URL in the results, "
|
||||
"fetch its full content by calling web_search with the url parameter. "
|
||||
"Do not repeat the same search query. If a search returns "
|
||||
"no useful results, try rephrasing or fetching a result URL directly."
|
||||
)
|
||||
_code_tips = (
|
||||
"Use code execution for math, calculations, data processing, "
|
||||
"or to parse and analyze information from tool results."
|
||||
|
|
@ -1132,6 +1149,7 @@ async def openai_chat_completions(
|
|||
_nudge = ""
|
||||
|
||||
if _nudge:
|
||||
_nudge += _TOOL_ACTION_NUDGE
|
||||
# Append nudge to system prompt (preserve user's prompt)
|
||||
if system_prompt:
|
||||
system_prompt = system_prompt.rstrip() + "\n\n" + _nudge
|
||||
|
|
@ -1208,7 +1226,14 @@ async def openai_chat_completions(
|
|||
break
|
||||
|
||||
if event["type"] == "status":
|
||||
# Empty status marks an iteration boundary
|
||||
# in the GGUF tool loop (e.g. after a
|
||||
# re-prompt). Reset the cumulative cursor
|
||||
# so the next assistant turn streams cleanly.
|
||||
if not event["text"]:
|
||||
prev_text = ""
|
||||
# Emit tool status as a custom SSE event
|
||||
# (including empty ones to clear UI badges)
|
||||
status_data = json.dumps(
|
||||
{
|
||||
"type": "tool_status",
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from .model_config import (
|
|||
get_base_model_from_lora,
|
||||
load_model_config,
|
||||
list_gguf_variants,
|
||||
extract_model_size_b,
|
||||
MODEL_NAME_MAPPING,
|
||||
UI_STATUS_INDICATORS,
|
||||
)
|
||||
|
|
@ -38,6 +39,7 @@ __all__ = [
|
|||
"get_base_model_from_lora",
|
||||
"load_model_config",
|
||||
"list_gguf_variants",
|
||||
"extract_model_size_b",
|
||||
"MODEL_NAME_MAPPING",
|
||||
"UI_STATUS_INDICATORS",
|
||||
"scan_checkpoints",
|
||||
|
|
|
|||
|
|
@ -31,6 +31,37 @@ import yaml
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ── Model size extraction ────────────────────────────────────
|
||||
import re as _re
|
||||
|
||||
_MODEL_SIZE_RE = _re.compile(
|
||||
r"(?:^|[-_/])(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
|
||||
)
|
||||
# MoE active-parameter pattern: matches "A3B", "A3.5B", etc.
|
||||
_ACTIVE_SIZE_RE = _re.compile(
|
||||
r"(?:^|[-_/])a(\d+\.?\d*)\s*([bm])(?:$|[-_/])", _re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
def extract_model_size_b(model_id: str) -> float | None:
|
||||
"""Extract model size in billions from a model identifier.
|
||||
|
||||
Prefers MoE active-parameter notation (e.g. ``A3B`` in
|
||||
``Qwen3.5-35B-A3B``) over the total parameter count.
|
||||
Handles both ``B`` (billions) and ``M`` (millions) suffixes.
|
||||
"""
|
||||
mid = (model_id or "").lower()
|
||||
active = _ACTIVE_SIZE_RE.search(mid)
|
||||
if active:
|
||||
val = float(active.group(1))
|
||||
return val / 1000.0 if active.group(2).lower() == "m" else val
|
||||
size = _MODEL_SIZE_RE.search(mid)
|
||||
if not size:
|
||||
return None
|
||||
val = float(size.group(1))
|
||||
return val / 1000.0 if size.group(2).lower() == "m" else val
|
||||
|
||||
|
||||
# Model name mapping: maps all equivalent model names to their canonical YAML config file
|
||||
# Format: "canonical_model_name.yaml": [list of all equivalent model names]
|
||||
# Based on the model mapper provided - canonical filename is based on the first model name in the mapper
|
||||
|
|
|
|||
Loading…
Reference in a new issue