mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
86 lines
2.6 KiB
Python
86 lines
2.6 KiB
Python
|
|
"""Common types + parser registry for tool-call extraction."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class ToolCall:
|
||
|
|
"""One extracted tool call — maps 1:1 to backend_pb2.ToolCallDelta."""
|
||
|
|
index: int
|
||
|
|
name: str
|
||
|
|
arguments: str # JSON string
|
||
|
|
id: str = ""
|
||
|
|
|
||
|
|
|
||
|
|
class ToolParser:
|
||
|
|
"""Parser interface.
|
||
|
|
|
||
|
|
Subclasses implement `parse` (full non-streaming pass) and optionally
|
||
|
|
`parse_stream` (incremental). The default `parse_stream` buffers until a
|
||
|
|
full response is available and then delegates to `parse`.
|
||
|
|
"""
|
||
|
|
|
||
|
|
name: str = "base"
|
||
|
|
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self._stream_buffer = ""
|
||
|
|
self._stream_index = 0
|
||
|
|
|
||
|
|
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||
|
|
"""Return (content_for_user, tool_calls)."""
|
||
|
|
raise NotImplementedError
|
||
|
|
|
||
|
|
def parse_stream(self, delta: str, finished: bool = False) -> tuple[str, list[ToolCall]]:
|
||
|
|
"""Accumulate a streaming delta. Emits any tool calls that have closed.
|
||
|
|
|
||
|
|
Default behavior: buffer until `finished=True`, then parse once.
|
||
|
|
Subclasses can override to emit mid-stream.
|
||
|
|
"""
|
||
|
|
self._stream_buffer += delta
|
||
|
|
if not finished:
|
||
|
|
return "", []
|
||
|
|
content, calls = self.parse(self._stream_buffer)
|
||
|
|
# Re-index starting from whatever we've already emitted in this stream.
|
||
|
|
reindexed: list[ToolCall] = []
|
||
|
|
for i, c in enumerate(calls):
|
||
|
|
reindexed.append(ToolCall(
|
||
|
|
index=self._stream_index + i,
|
||
|
|
name=c.name,
|
||
|
|
arguments=c.arguments,
|
||
|
|
id=c.id,
|
||
|
|
))
|
||
|
|
self._stream_index += len(reindexed)
|
||
|
|
return content, reindexed
|
||
|
|
|
||
|
|
def reset(self) -> None:
|
||
|
|
self._stream_buffer = ""
|
||
|
|
self._stream_index = 0
|
||
|
|
|
||
|
|
|
||
|
|
_REGISTRY: dict[str, type[ToolParser]] = {}
|
||
|
|
|
||
|
|
|
||
|
|
def register(cls: type[ToolParser]) -> type[ToolParser]:
|
||
|
|
_REGISTRY[cls.name] = cls
|
||
|
|
return cls
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_parser(name: Optional[str]) -> ToolParser:
|
||
|
|
"""Return a parser instance by name, falling back to a no-op passthrough."""
|
||
|
|
# Import for side effects — each module registers itself.
|
||
|
|
from . import hermes, llama3_json, mistral, qwen3_xml # noqa: F401
|
||
|
|
|
||
|
|
if name and name in _REGISTRY:
|
||
|
|
return _REGISTRY[name]()
|
||
|
|
return PassthroughToolParser()
|
||
|
|
|
||
|
|
|
||
|
|
class PassthroughToolParser(ToolParser):
|
||
|
|
"""No-op parser — used when no tool_parser is configured."""
|
||
|
|
name = "passthrough"
|
||
|
|
|
||
|
|
def parse(self, text: str) -> tuple[str, list[ToolCall]]:
|
||
|
|
return text, []
|