DataDesigner/packages/data-designer-engine/tests/engine/models/test_facade.py
Eric W. Tramel 1cf84d53a7 fix: show token rates in progress demo
Emit synthetic token usage in the credential-free progress panel demo so the live token-rate columns are visible, and accept output-only provider usage as a progress event.

Signed-off-by: Eric W. Tramel <eric.tramel@gmail.com>
2026-05-21 13:41:10 -04:00

1721 lines
64 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from data_designer.engine.context import current_generation_column
from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError
from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind
from data_designer.engine.models.clients.types import (
AssistantMessage,
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingResponse,
ImageGenerationResponse,
ImagePayload,
ToolCall,
Usage,
)
from data_designer.engine.models.errors import (
ImageGenerationError,
ModelGenerationValidationFailureError,
ModelTimeoutError,
)
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.parsers.errors import ParserException
from data_designer.engine.models.usage import TokenCountSource
from data_designer.engine.models.usage_events import TokenUsageEvent, subscribe_token_usage
from data_designer.engine.models.utils import ChatMessage
from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response
def _make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionResponse:
    """Build a stub ``ChatCompletionResponse`` for use in tests.

    Thin wrapper: ``content`` and every extra keyword argument are forwarded
    unchanged to ``make_stub_completion_response``.
    """
    stub_response = make_stub_completion_response(content=content, **kwargs)
    return stub_response
def _assert_no_multi_choice_request(
    request: Any,
    expected_extra_body: dict[str, Any] | None = None,
) -> None:
    """Assert that ``request`` is a chat completion request without ``n``.

    Args:
        request: Object captured from the client mock's call arguments.
        expected_extra_body: Value ``request.extra_body`` must equal; the
            default ``None`` requires that no extra body is set.
    """
    # Type check first so the attribute reads below are meaningful.
    assert isinstance(request, ChatCompletionRequest)
    # The multi-choice count must not be set on the request.
    assert request.n is None
    # Whatever remains in the extra body must match exactly.
    assert request.extra_body == expected_extra_body
@pytest.fixture
def stub_model_facade(
    stub_model_configs: list[Any],
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> ModelFacade:
    """Provide a ``ModelFacade`` wired to the first stub config and the stub client."""
    primary_config = stub_model_configs[0]
    facade = ModelFacade(
        model_config=primary_config,
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
    )
    return facade
@pytest.fixture
def stub_completion_messages() -> list[ChatMessage]:
    """Provide a minimal conversation: a single user message with the text ``"test"``."""
    user_message = ChatMessage.as_user("test")
    return [user_message]
@pytest.mark.parametrize(
    "max_correction_steps,max_conversation_restarts,total_calls",
    [
        (0, 0, 1),
        (1, 1, 4),
        (1, 2, 6),
        (5, 0, 6),
        (0, 5, 6),
        (3, 3, 16),
    ],
)
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    max_correction_steps: int,
    max_conversation_restarts: int,
    total_calls: int,
) -> None:
    """``generate`` exhausts its retry budget when the parser always fails.

    Every expected total in the parameter table equals
    ``(max_correction_steps + 1) * (max_conversation_restarts + 1)``.
    The call is issued twice to check the count accumulates per invocation.
    """
    always_bad = _make_response("bad response")

    def _return_bad_response(*args: Any, **kwargs: Any) -> ChatCompletionResponse:
        return always_bad

    mock_completion.side_effect = _return_bad_response

    def _rejecting_parser(response: str) -> str:
        raise ParserException("parser exception")

    retry_budget = {
        "max_correction_steps": max_correction_steps,
        "max_conversation_restarts": max_conversation_restarts,
    }

    for completed_rounds in (1, 2):
        with pytest.raises(ModelGenerationValidationFailureError):
            stub_model_facade.generate(
                prompt="foo",
                system_prompt="bar",
                parser=_rejecting_parser,
                **retry_budget,
            )
        assert mock_completion.call_count == completed_rounds * total_calls
@pytest.mark.parametrize(
    "system_prompt,expected_messages",
    [
        ("", [ChatMessage.as_user("does not matter")]),
        ("hello!", [ChatMessage.as_system("hello!"), ChatMessage.as_user("does not matter")]),
    ],
)
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate_with_system_prompt(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    system_prompt: str,
    expected_messages: list[ChatMessage],
) -> None:
    """An empty system prompt adds no system message; a non-empty one is prepended."""
    seen_message_lists: list[list[ChatMessage]] = []

    def _record_and_reply(*args: Any, **kwargs: Any) -> ChatCompletionResponse:
        # The message list is read from the second positional argument;
        # snapshot it so later mutation cannot affect the comparison.
        seen_message_lists.append(list(args[1]))
        return _make_response("Hello!")

    mock_completion.side_effect = _record_and_reply

    stub_model_facade.generate(prompt="does not matter", system_prompt=system_prompt, parser=lambda x: x)

    assert mock_completion.call_count == 1
    first_call_messages = seen_message_lists[0]
    assert first_call_messages == expected_messages
def test_generate_drops_n_from_single_result_request(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """A caller-supplied ``n`` is not present on the request sent by ``generate``."""
    stub_model_client.completion.return_value = _make_response("Hello!")

    stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, n=4)

    sent_request = stub_model_client.completion.call_args.args[0]
    _assert_no_multi_choice_request(sent_request)
def test_generate_drops_extra_body_n_from_single_result_request(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """An ``n`` key inside a caller ``extra_body`` is removed; other keys survive."""
    stub_model_client.completion.return_value = _make_response("Hello!")
    caller_extra_body = {"n": 4, "seed": 42}

    stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, extra_body=caller_extra_body)

    sent_request = stub_model_client.completion.call_args.args[0]
    _assert_no_multi_choice_request(sent_request, expected_extra_body={"seed": 42})
def test_generate_drops_configured_extra_body_n_from_single_result_request(
    stub_model_configs: list[Any],
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """``n`` from config-level and provider-level ``extra_body`` is dropped.

    The remaining keys from both sources are expected on the outgoing request.
    """
    # Both configuration layers carry an "n" plus one unrelated key.
    stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42}
    stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"}
    stub_model_client.completion.return_value = _make_response("Hello!")

    stub_model_facade.generate(prompt="does not matter", parser=lambda x: x)

    sent_request = stub_model_client.completion.call_args.args[0]
    surviving_extra_body = {"seed": 42, "provider": "kept"}
    _assert_no_multi_choice_request(sent_request, expected_extra_body=surviving_extra_body)
@pytest.mark.asyncio
async def test_agenerate_drops_n_from_single_result_request(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Async variant: a caller-supplied ``n`` is not present on the ``agenerate`` request."""
    async_completion = AsyncMock(return_value=_make_response("Hello!"))
    stub_model_client.acompletion = async_completion

    await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x, n=4)

    sent_request = stub_model_client.acompletion.call_args.args[0]
    _assert_no_multi_choice_request(sent_request)
@pytest.mark.asyncio
async def test_agenerate_drops_configured_extra_body_n_from_single_result_request(
    stub_model_configs: list[Any],
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Async variant: ``n`` from config and provider ``extra_body`` is dropped.

    The remaining keys from both sources are expected on the outgoing request.
    """
    # Both configuration layers carry an "n" plus one unrelated key.
    stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42}
    stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"}
    stub_model_client.acompletion = AsyncMock(return_value=_make_response("Hello!"))

    await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x)

    sent_request = stub_model_client.acompletion.call_args.args[0]
    surviving_extra_body = {"seed": 42, "provider": "kept"}
    _assert_no_multi_choice_request(sent_request, expected_extra_body=surviving_extra_body)
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate_includes_parser_validation_detail_in_user_facing_error(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """The parser's message is surfaced, on one line, in the raised error.

    The parser raises a two-line message; the assertions expect the newline
    to appear as a single space in both the error text and ``detail``.
    """
    mock_completion.return_value = _make_response("bad response")
    expected_detail = "Response doesn't match requested <response_schema> 'name' is a required property"

    def _schema_rejecting_parser(response: str) -> str:
        raise ParserException("Response doesn't match requested <response_schema>\n'name' is a required property")

    # Retries are disabled so the first parser failure is reported directly.
    with pytest.raises(
        ModelGenerationValidationFailureError,
        match="Validation detail: Response doesn't match requested <response_schema> 'name' is a required property.",
    ) as exc_info:
        stub_model_facade.generate(
            prompt="foo",
            parser=_schema_rejecting_parser,
            max_correction_steps=0,
            max_conversation_restarts=0,
        )

    raised = exc_info.value
    assert raised.detail == expected_detail
    assert raised.failure_kind == "schema_validation"
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_includes_parser_validation_detail_in_user_facing_error(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Async variant: the parser's message is surfaced, on one line, in the raised error.

    The parser raises a two-line message; the assertions expect the newline
    to appear as a single space in both the error text and ``detail``.
    """
    mock_acompletion.return_value = _make_response("bad response")
    expected_detail = "Response doesn't match requested <response_schema> 'name' is a required property"

    def _schema_rejecting_parser(response: str) -> str:
        raise ParserException("Response doesn't match requested <response_schema>\n'name' is a required property")

    # Retries are disabled so the first parser failure is reported directly.
    with pytest.raises(
        ModelGenerationValidationFailureError,
        match="Validation detail: Response doesn't match requested <response_schema> 'name' is a required property.",
    ) as exc_info:
        await stub_model_facade.agenerate(
            prompt="foo",
            parser=_schema_rejecting_parser,
            max_correction_steps=0,
            max_conversation_restarts=0,
        )

    raised = exc_info.value
    assert raised.detail == expected_detail
    assert raised.failure_kind == "schema_validation"
@pytest.mark.parametrize(
    "raw_content,expected",
    [
        ("\nHello world", "Hello world"),
        (" Hello world ", "Hello world"),
        ("\n\n Hello world\n", "Hello world"),
        ("Hello world", "Hello world"),
    ],
)
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_strips_response_content(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    raw_content: str,
    expected: str,
) -> None:
    """Leading and trailing whitespace is removed from the generated result."""

    def _reply_with_raw_content(*args: Any, **kwargs: Any) -> ChatCompletionResponse:
        # A fresh response object is built on every call.
        return _make_response(raw_content)

    mock_completion.side_effect = _reply_with_raw_content

    generated, _ = stub_model_facade.generate(prompt="test", parser=lambda x: x)

    assert generated == expected
def test_generate_maps_statusless_provider_timeout_to_model_timeout(stub_model_facade: ModelFacade) -> None:
    """A provider TIMEOUT error with no HTTP status surfaces as ``ModelTimeoutError``."""
    statusless_timeout = ProviderError(
        kind=ProviderErrorKind.TIMEOUT,
        message="request timed out",
        status_code=None,
        provider_name="stub",
        model_name=stub_model_facade.model_name,
    )
    stub_model_facade._client.completion.side_effect = statusless_timeout

    with pytest.raises(ModelTimeoutError, match="timed out"):
        stub_model_facade.generate(prompt="test", parser=lambda value: value)
@pytest.mark.asyncio
async def test_agenerate_maps_statusless_provider_timeout_to_model_timeout(stub_model_facade: ModelFacade) -> None:
    """Async variant: a status-less provider TIMEOUT surfaces as ``ModelTimeoutError``."""
    statusless_timeout = ProviderError(
        kind=ProviderErrorKind.TIMEOUT,
        message="request timed out",
        status_code=None,
        provider_name="stub",
        model_name=stub_model_facade.model_name,
    )
    stub_model_facade._client.acompletion = AsyncMock(side_effect=statusless_timeout)

    with pytest.raises(ModelTimeoutError, match="timed out"):
        await stub_model_facade.agenerate(prompt="test", parser=lambda value: value)
def test_model_alias_property(stub_model_facade: ModelFacade, stub_model_configs: list[Any]) -> None:
    """``model_alias`` mirrors the alias of the config the facade was built from."""
    configured_alias = stub_model_configs[0].alias
    assert stub_model_facade.model_alias == configured_alias
def test_usage_stats_property(stub_model_facade: ModelFacade) -> None:
    """``usage_stats`` is populated and exposes a ``model_dump`` attribute."""
    assert stub_model_facade.usage_stats is not None
    has_model_dump = hasattr(stub_model_facade.usage_stats, "model_dump")
    assert has_model_dump
def test_completion_tracks_reasoning_tokens_without_changing_output_tokens(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Reasoning tokens are recorded separately from input and output counts.

    With 10 input, 8 output and 3 reasoning tokens reported, the expected
    total is 18, i.e. the reasoning count is not added on top.
    """
    reported_usage = Usage(
        input_tokens=10,
        output_tokens=8,
        reasoning_tokens=3,
        reasoning_token_count_source=TokenCountSource.PROVIDER,
    )
    stub_model_client.completion.return_value = ChatCompletionResponse(
        message=AssistantMessage(content="ok"),
        usage=reported_usage,
    )

    stub_model_facade.completion([ChatMessage.as_user("hi")])

    tracked = stub_model_facade.usage_stats.token_usage
    assert tracked.input_tokens == 10
    assert tracked.output_tokens == 8
    assert tracked.reasoning_tokens == 3
    assert tracked.reasoning_token_count_source == TokenCountSource.PROVIDER
    assert tracked.total_tokens == 18
def test_completion_emits_token_usage_event(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """A completion with usage data publishes exactly one ``TokenUsageEvent``.

    The event is expected to carry the facade's alias and model name, the
    reported token counts, and the column name held in
    ``current_generation_column`` during the call.
    """
    # Finish all fallible setup before subscribing, so a setup failure
    # cannot leave the subscriber registered for later tests.
    stub_model_client.completion.return_value = ChatCompletionResponse(
        message=AssistantMessage(content="ok"),
        usage=Usage(input_tokens=10, output_tokens=8),
    )
    events: list[TokenUsageEvent] = []

    unsubscribe = subscribe_token_usage(events.append)
    try:
        token = current_generation_column.set("intent_label")
        try:
            stub_model_facade.completion([ChatMessage.as_user("hi")])
        finally:
            # Restore the context variable even when the completion raises.
            current_generation_column.reset(token)
    finally:
        # Runs regardless of how the inner block exits.
        unsubscribe()

    assert len(events) == 1
    event = events[0]
    assert event.model_alias == stub_model_facade.model_alias
    assert event.model_name == stub_model_facade.model_name
    assert event.input_tokens == 10
    assert event.output_tokens == 8
    assert event.column == "intent_label"
def test_completion_emits_token_usage_event_when_only_output_tokens_are_reported(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Usage that reports only output tokens still publishes one event.

    The event's ``input_tokens`` is expected to be ``0`` in that case.
    """
    # Finish all fallible setup before subscribing, so a setup failure
    # cannot leave the subscriber registered for later tests.
    stub_model_client.completion.return_value = ChatCompletionResponse(
        message=AssistantMessage(content="ok"),
        usage=Usage(output_tokens=8),
    )
    events: list[TokenUsageEvent] = []

    unsubscribe = subscribe_token_usage(events.append)
    try:
        stub_model_facade.completion([ChatMessage.as_user("hi")])
    finally:
        unsubscribe()

    assert len(events) == 1
    event = events[0]
    assert event.input_tokens == 0
    assert event.output_tokens == 8
def test_consolidate_kwargs(stub_model_configs: list[Any], stub_model_facade: ModelFacade) -> None:
    """``consolidate_kwargs`` layers config, caller and provider settings.

    Four scenarios are checked in sequence against the same facade.
    """
    attribution_headers = {"X-Title": "NeMo Data Designer"}

    def _config_kwargs_plus(**overrides: Any) -> dict[str, Any]:
        # Expected result: the config's generate kwargs with overrides applied.
        return {**stub_model_configs[0].inference_parameters.generate_kwargs, **overrides}

    # 1. Config kwargs form the base, "purpose" is not in the result, and
    #    X-Title is present (telemetry is not patched off in this test).
    consolidated = stub_model_facade.consolidate_kwargs(purpose="test")
    assert consolidated == _config_kwargs_plus(extra_headers=attribution_headers)

    # 2. A caller keyword argument replaces the config value of the same name.
    consolidated = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test")
    assert consolidated == _config_kwargs_plus(temperature=0.01, extra_headers=attribution_headers)

    # 3. Provider extra_body and caller extra_body end up in one dict.
    stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"}
    consolidated = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test")
    assert consolidated == _config_kwargs_plus(
        extra_body={"foo_provider": "bar_provider", "foo": "bar"},
        extra_headers=attribution_headers,
    )

    # 4. Provider headers merge with caller headers; on the shared key
    #    "hello" the provider's value is the one expected.
    stub_model_facade.model_provider.extra_body = None
    stub_model_facade.model_provider.extra_headers = {"hello": "world", "hola": "mundo"}
    consolidated = stub_model_facade.consolidate_kwargs(extra_headers={"hello": "caller", "X-Trace-ID": "abc"})
    assert consolidated == _config_kwargs_plus(
        extra_headers={"X-Title": "NeMo Data Designer", "hello": "world", "hola": "mundo", "X-Trace-ID": "abc"},
    )
@patch("data_designer.engine.models.facade.TELEMETRY_ENABLED", False)
def test_consolidate_kwargs_telemetry_disabled(stub_model_configs: list[Any], stub_model_facade: ModelFacade) -> None:
    """With telemetry patched off, no attribution headers are added."""
    without_provider_headers = stub_model_facade.consolidate_kwargs()
    assert "extra_headers" not in without_provider_headers

    # Headers configured on the provider are still passed through.
    stub_model_facade.model_provider.extra_headers = {"Custom": "header"}
    with_provider_headers = stub_model_facade.consolidate_kwargs()
    assert with_provider_headers["extra_headers"] == {"Custom": "header"}
def test_consolidate_kwargs_user_x_title_override(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """An ``X-Title`` from the provider or the caller replaces the default one."""
    # Provider-level override.
    stub_model_facade.model_provider.extra_headers = {"X-Title": "My Custom App"}
    provider_headers = stub_model_facade.consolidate_kwargs()["extra_headers"]
    assert provider_headers["X-Title"] == "My Custom App"

    # Caller-level override, with no provider headers configured.
    stub_model_facade.model_provider.extra_headers = None
    caller_headers = stub_model_facade.consolidate_kwargs(extra_headers={"X-Title": "Caller App"})["extra_headers"]
    assert caller_headers["X-Title"] == "Caller App"
def test_consolidate_kwargs_with_explicit_none_extra_headers(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """Passing ``extra_headers=None`` still yields provider and attribution headers."""
    stub_model_facade.model_provider.extra_headers = {"hello": "world"}

    consolidated = stub_model_facade.consolidate_kwargs(extra_headers=None)

    expected_headers = {"X-Title": "NeMo Data Designer", "hello": "world"}
    assert consolidated["extra_headers"] == expected_headers
def test_consolidate_kwargs_openrouter_attribution(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """A provider named ``openrouter`` gets the OpenRouter attribution headers."""
    provider = stub_model_facade.model_provider
    provider.name = "openrouter"
    provider.extra_headers = None

    headers = stub_model_facade.consolidate_kwargs()["extra_headers"]

    expected_headers = {
        "X-Title": "NeMo Data Designer",
        "HTTP-Referer": "https://github.com/NVIDIA-NeMo/DataDesigner",
        "X-OpenRouter-Title": "NeMo Data Designer",
        "X-OpenRouter-Categories": "programming-app",
    }
    assert headers == expected_headers
def test_consolidate_kwargs_openrouter_user_override_preserved(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """Caller-supplied OpenRouter headers are kept; missing ones get defaults."""
    provider = stub_model_facade.model_provider
    provider.name = "openrouter"
    provider.extra_headers = None
    caller_headers = {"X-OpenRouter-Title": "Custom App", "X-Custom": "value"}

    headers = stub_model_facade.consolidate_kwargs(extra_headers=caller_headers)["extra_headers"]

    # The caller's own values come through untouched.
    assert headers["X-OpenRouter-Title"] == "Custom App"
    assert headers["X-Custom"] == "value"
    # Keys the caller did not set carry the default attribution values.
    assert headers["HTTP-Referer"] == "https://github.com/NVIDIA-NeMo/DataDesigner"
    assert headers["X-OpenRouter-Categories"] == "programming-app"
    assert headers["X-Title"] == "NeMo Data Designer"
def test_consolidate_kwargs_openrouter_provider_headers_preserved(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """Provider-configured OpenRouter headers are kept; missing ones get defaults."""
    provider = stub_model_facade.model_provider
    provider.name = "openrouter"
    provider.extra_headers = {
        "HTTP-Referer": "https://custom-site.example.com",
        "X-OpenRouter-Title": "Provider Title",
    }

    headers = stub_model_facade.consolidate_kwargs()["extra_headers"]

    # Values configured on the provider come through untouched.
    assert headers["HTTP-Referer"] == "https://custom-site.example.com"
    assert headers["X-OpenRouter-Title"] == "Provider Title"
    # Keys the provider did not set carry the default attribution values.
    assert headers["X-OpenRouter-Categories"] == "programming-app"
    assert headers["X-Title"] == "NeMo Data Designer"
@patch("data_designer.engine.models.facade.TELEMETRY_ENABLED", False)
def test_consolidate_kwargs_openrouter_no_attribution_when_telemetry_off(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """With telemetry patched off, an ``openrouter`` provider gets no headers at all."""
    provider = stub_model_facade.model_provider
    provider.name = "openrouter"
    provider.extra_headers = None

    consolidated = stub_model_facade.consolidate_kwargs()

    assert "extra_headers" not in consolidated
def test_consolidate_kwargs_non_openrouter_no_openrouter_headers(
    stub_model_configs: list[Any], stub_model_facade: ModelFacade
) -> None:
    """A provider named ``nvidia`` receives only ``X-Title``."""
    provider = stub_model_facade.model_provider
    provider.name = "nvidia"
    provider.extra_headers = None

    headers = stub_model_facade.consolidate_kwargs()["extra_headers"]

    assert headers == {"X-Title": "NeMo Data Designer"}
    # Spell out the absence of each OpenRouter-specific key explicitly.
    for openrouter_only_key in ("HTTP-Referer", "X-OpenRouter-Title", "X-OpenRouter-Categories"):
        assert openrouter_only_key not in headers
@pytest.mark.parametrize(
    "skip_usage_tracking",
    [
        False,
        True,
    ],
)
def test_completion_success(
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
    skip_usage_tracking: bool,
) -> None:
    """``completion`` returns the client's response after exactly one client call.

    Exercised with usage tracking both enabled and skipped.
    """
    client_response = _make_response("Test response")
    stub_model_client.completion.return_value = client_response

    returned = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)

    assert returned == client_response
    assert stub_model_client.completion.call_count == 1
def test_completion_with_exception(
    stub_completion_messages: list[ChatMessage],
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """An exception raised by the client is raised out of ``completion``."""
    client_failure = Exception("Router error")
    stub_model_client.completion.side_effect = client_failure

    with pytest.raises(Exception, match="Router error"):
        stub_model_facade.completion(stub_completion_messages)
def test_completion_with_kwargs(
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """``completion`` accepts extra keyword arguments and still calls the client once."""
    client_response = _make_response("Test response")
    stub_model_client.completion.return_value = client_response

    returned = stub_model_facade.completion(stub_completion_messages, temperature=0.7, max_tokens=100)

    assert returned == client_response
    assert stub_model_client.completion.call_count == 1
def test_completion_forwards_n_to_request(
    stub_completion_messages: list[ChatMessage],
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Calling ``completion`` directly with ``n`` puts that value on the request."""
    stub_model_client.completion.return_value = _make_response("Test response")

    stub_model_facade.completion(stub_completion_messages, n=4)

    sent_request = stub_model_client.completion.call_args.args[0]
    assert isinstance(sent_request, ChatCompletionRequest)
    assert sent_request.n == 4
def test_generate_text_embeddings_success(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """``generate_text_embeddings`` returns the vectors from the client's response."""
    client_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=client_vectors)

    returned_vectors = stub_model_facade.generate_text_embeddings(["test1", "test2"])

    assert returned_vectors == client_vectors
def test_generate_text_embeddings_with_exception(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """An exception raised by the embeddings client is raised out of the facade."""
    client_failure = Exception("Router error")
    stub_model_client.embeddings.side_effect = client_failure

    with pytest.raises(Exception, match="Router error"):
        stub_model_facade.generate_text_embeddings(["test1", "test2"])
def test_generate_text_embeddings_with_kwargs(
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """Extra keyword arguments are accepted and the client is called exactly once."""
    client_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=client_vectors)

    stub_model_facade.generate_text_embeddings(
        ["test1", "test2"],
        temperature=0.7,
        max_tokens=100,
        input_type="query",
    )

    assert stub_model_client.embeddings.call_count == 1
def test_generate_with_mcp_tools(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """One tool-call turn followed by a final answer.

    The first completion is expected to receive the tool schemas, and the
    second to see a ``tool`` role message produced by the stub MCP facade.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}')
    scripted_responses = [
        _make_response(content=None, tool_calls=[lookup_call]),
        _make_response("final result"),
    ]
    completion_log: list[tuple[list[ChatMessage], dict[str, Any]]] = []
    registry_log: list[tuple[str, str, dict[str, str], None]] = []

    def _process_and_record(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        assistant_message = completion_response.message
        if assistant_message.tool_calls:
            # Record that the tool path was taken, then emit the assistant
            # tool-call message followed by the tool's output message.
            registry_log.append(("tools", "lookup", {"query": "foo"}, None))
            tool_call_payload = {
                "id": "call-1",
                "type": "function",
                "function": {"name": "lookup", "arguments": '{"query": "foo"}'},
            }
            return [
                ChatMessage.as_assistant(content="", tool_calls=[tool_call_payload]),
                ChatMessage.as_tool(content="tool-output", tool_call_id="call-1"),
            ]
        return [ChatMessage.as_assistant(content=assistant_message.content or "")]

    lookup_schema = {
        "type": "function",
        "function": {"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}},
    }
    mcp_facade = StubMCPFacade(tool_schemas=[lookup_schema], process_fn=_process_and_record)
    mcp_registry = StubMCPRegistry(mcp_facade)

    def _scripted_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        completion_log.append((messages, kwargs))
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=mcp_registry,
    )

    with patch.object(ModelFacade, "completion", new=_scripted_completion):
        generated, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert generated == "final result"
    assert len(completion_log) == 2

    first_call_kwargs = completion_log[0][1]
    assert "tools" in first_call_kwargs
    assert first_call_kwargs["tools"][0]["function"]["name"] == "lookup"

    second_call_messages = completion_log[1][0]
    assert any(message.role == "tool" for message in second_call_messages)
    assert registry_log == [("tools", "lookup", {"query": "foo"}, None)]
def test_generate_preserves_multimodal_mcp_tool_results_between_turns(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """A list-valued (text + image) tool result reaches the next completion unchanged."""
    render_call = ToolCall(id="call-1", name="render_chart", arguments_json="{}")
    scripted_responses = [
        _make_response(content=None, tool_calls=[render_call]),
        _make_response("final result"),
    ]
    multimodal_result = [
        {"type": "text", "text": "Rendered chart:"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
    ]
    completion_log: list[tuple[list[ChatMessage], dict[str, Any]]] = []

    def _process_with_multimodal_result(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        assistant_message = completion_response.message
        if assistant_message.tool_calls:
            tool_call_payload = {
                "id": "call-1",
                "type": "function",
                "function": {"name": "render_chart", "arguments": "{}"},
            }
            return [
                ChatMessage.as_assistant(content="", tool_calls=[tool_call_payload]),
                # The tool message content is the multimodal list itself.
                ChatMessage.as_tool(content=multimodal_result, tool_call_id="call-1"),
            ]
        return [ChatMessage.as_assistant(content=assistant_message.content or "")]

    render_schema = {
        "type": "function",
        "function": {"name": "render_chart", "description": "Render", "parameters": {"type": "object"}},
    }
    mcp_facade = StubMCPFacade(tool_schemas=[render_schema], process_fn=_process_with_multimodal_result)
    mcp_registry = StubMCPRegistry(mcp_facade)

    def _scripted_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        completion_log.append((messages, kwargs))
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=mcp_registry,
    )

    with patch.object(ModelFacade, "completion", new=_scripted_completion):
        generated, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert generated == "final result"
    assert len(completion_log) == 2

    second_call_messages = completion_log[1][0]
    tool_messages = [message for message in second_call_messages if message.role == "tool"]
    assert len(tool_messages) == 1
    assert tool_messages[0].content == multimodal_result
def test_generate_with_tools_missing_registry(
    stub_model_configs: Any, stub_model_client: MagicMock, stub_model_provider_registry: Any
) -> None:
    """Requesting a ``tool_alias`` without an MCP registry raises ``MCPConfigurationError``."""
    facade_without_registry = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=None,
    )

    with pytest.raises(MCPConfigurationError):
        facade_without_registry.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
# =============================================================================
# Tool calling integration tests
# =============================================================================
def test_generate_with_tool_alias_multiple_turns(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """Two consecutive tool-call responses are followed by a final answer.

    The completion stub must be invoked three times in total.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}')
    search_call = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}')
    scripted_responses = [
        _make_response("First lookup", tool_calls=[lookup_call]),
        _make_response("Second search", tool_calls=[search_call]),
        _make_response("final result after two tool turns"),
    ]
    # One marker is appended per completion invocation.
    invocations: list[None] = []

    mcp_registry = StubMCPRegistry(StubMCPFacade(max_tool_call_turns=5))

    def _scripted_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        invocations.append(None)
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=mcp_registry,
    )

    with patch.object(ModelFacade, "completion", new=_scripted_completion):
        generated, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert generated == "final result after two tool turns"
    assert len(invocations) == 3  # 2 tool turns + 1 final
def test_generate_with_tools_tracks_usage_stats(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """Tool-usage counters start at zero and reflect one two-turn generation."""
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}')
    search_call = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}')
    scripted_responses = [
        _make_response("First lookup", tool_calls=[lookup_call]),
        _make_response("Second search", tool_calls=[search_call]),
        _make_response("final result"),
    ]

    mcp_registry = StubMCPRegistry(StubMCPFacade(max_tool_call_turns=5))

    def _scripted_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=mcp_registry,
    )

    # Before any generation every counter is zero.
    usage_before = model.usage_stats.tool_usage
    assert usage_before.total_tool_calls == 0
    assert usage_before.total_tool_call_turns == 0
    assert usage_before.total_generations == 0
    assert usage_before.generations_with_tools == 0

    with patch.object(ModelFacade, "completion", new=_scripted_completion):
        generated, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert generated == "final result"

    # Stats are read again after the run rather than reusing the earlier reference.
    usage_after = model.usage_stats.tool_usage
    assert usage_after.total_tool_calls == 2
    assert usage_after.total_tool_call_turns == 2
    assert usage_after.total_generations == 1
    assert usage_after.generations_with_tools == 1
def test_generate_with_tools_tracks_multiple_generations(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """Tool-usage counters accumulate across three generations on one facade."""
    mcp_registry = StubMCPRegistry(StubMCPFacade(max_tool_call_turns=10))
    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=mcp_registry,
    )

    def _run_generation(prompt: str, scripted: list[ChatCompletionResponse]) -> None:
        # Patch ``completion`` to replay ``scripted`` in order for one generate call.
        def _scripted_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
            return scripted.pop(0)

        with patch.object(ModelFacade, "completion", new=_scripted_completion):
            model.generate(prompt=prompt, parser=lambda x: x, tool_alias="tools")

    call_a = ToolCall(id="call-a", name="lookup", arguments_json='{"q": "1"}')
    call_b = ToolCall(id="call-b", name="lookup", arguments_json='{"q": "2"}')

    # Generation 1: a single turn carrying two tool calls.
    _run_generation(
        "q1",
        [
            _make_response("", tool_calls=[call_a, call_b]),
            _make_response("result 1"),
        ],
    )

    call_c = ToolCall(id="call-c", name="search", arguments_json='{"q": "3"}')
    call_d = ToolCall(id="call-d", name="search", arguments_json='{"q": "4"}')

    # Generation 2: two turns with two tool calls each.
    _run_generation(
        "q2",
        [
            _make_response("", tool_calls=[call_a, call_b]),
            _make_response("", tool_calls=[call_c, call_d]),
            _make_response("result 2"),
        ],
    )

    # Generation 3: answers immediately, without tool calls.
    _run_generation("q3", [_make_response("result 3")])

    # Totals: 2 + 4 + 0 = 6 calls, 1 + 2 + 0 = 3 turns, 3 generations, 2 of them with tools.
    tool_usage = model.usage_stats.tool_usage
    assert tool_usage.total_tool_calls == 6
    assert tool_usage.total_tool_call_turns == 3
    assert tool_usage.total_generations == 3
    assert tool_usage.generations_with_tools == 2
def test_generate_tool_turn_limit_triggers_refusal(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """A tool request made after the two-turn budget is spent goes to the refuse hook.

    The scripted model asks for a tool three times and then answers in plain
    text. The test expects the process hook to run twice, the refuse hook once,
    and ``generate`` to return the final text.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    lookup_call_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}

    # Turn 1, turn 2 (the limit) and turn 3 (over the limit) all request the tool;
    # the fourth reply is the plain-text answer given after the refusal.
    scripted = [_make_response("", tool_calls=[lookup_call]) for _ in range(3)]
    scripted.append(_make_response("final answer after refusal"))

    hook_calls = {"process": 0, "refuse": 0}

    def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        hook_calls["process"] += 1
        return [
            ChatMessage.as_assistant(content="", tool_calls=[lookup_call_dict]),
            ChatMessage.as_tool(content="tool-result", tool_call_id="call-1"),
        ]

    def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        hook_calls["refuse"] += 1
        return [
            ChatMessage.as_assistant(content="", tool_calls=[lookup_call_dict]),
            ChatMessage.as_tool(content="REFUSED: Budget exceeded", tool_call_id="call-1"),
        ]

    registry = StubMCPRegistry(
        StubMCPFacade(max_tool_call_turns=2, process_fn=custom_process_fn, refuse_fn=custom_refuse_fn)
    )

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        # Hand out the scripted replies one per completion call, in order.
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=registry,
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "final answer after refusal"
    assert hook_calls["process"] == 2  # turns 1 and 2 were executed
    assert hook_calls["refuse"] == 1  # turn 3 was refused
def test_generate_tool_turn_limit_model_responds_after_refusal(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """With a zero-turn budget the first tool request is refused and the next reply is returned.

    Also checks that the returned trace carries a tool-role message whose text
    mentions the refusal.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    lookup_call_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    scripted = [
        _make_response("", tool_calls=[lookup_call]),  # already over budget on the first turn
        _make_response("I understand, here is the answer without tools"),
    ]

    def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        refusal = ChatMessage.as_tool(
            content="Tool call refused: You have reached the maximum number of tool-calling turns.",
            tool_call_id="call-1",
        )
        return [ChatMessage.as_assistant(content="", tool_calls=[lookup_call_dict]), refusal]

    stub_facade = StubMCPFacade(
        max_tool_call_turns=0,
        process_fn=lambda _: [],  # not expected to run in this scenario
        refuse_fn=custom_refuse_fn,
    )

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(stub_facade),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        result, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "I understand, here is the answer without tools"
    # The refusal text must show up among the tool-role messages of the trace.
    tool_texts = [msg.content for msg in trace if msg.role == "tool"]
    assert any(text and "refused" in text.lower() for text in tool_texts)
def test_generate_tool_alias_not_in_registry(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """generate() raises MCPConfigurationError when the registry lookup fails with ValueError."""

    class _MissingAliasRegistry:
        """Registry stand-in whose lookup always fails."""

        def get_mcp(self, *, tool_alias: str) -> Any:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")

    facade_kwargs = {
        "model_config": stub_model_configs[0],
        "model_provider_registry": stub_model_provider_registry,
        "client": stub_model_client,
        "mcp_registry": _MissingAliasRegistry(),
    }
    model = ModelFacade(**facade_kwargs)

    with pytest.raises(MCPConfigurationError, match="not registered"):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="nonexistent")
def test_generate_no_tool_alias_ignores_mcp(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """With ``tool_alias=None`` the registry is never consulted and no ``tools`` kwarg is sent."""
    lookups: list[str] = []  # aliases passed to get_mcp; must stay empty

    class _RecordingRegistry:
        """Registry stand-in that records every lookup and then fails loudly."""

        def get_mcp(self, *, tool_alias: str) -> Any:
            lookups.append(tool_alias)
            raise RuntimeError("Should not be called")

    def _plain_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        # The facade must not forward a tool schema when no alias was given.
        assert "tools" not in kwargs
        return _make_response("response without tools")

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=_RecordingRegistry(),
    )

    with patch.object(ModelFacade, "completion", new=_plain_completion):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias=None)

    assert result == "response without tools"
    assert not lookups
def test_generate_tool_calls_with_parser_corrections(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """A tool turn followed by one rejected reply still yields the corrected reply.

    With ``max_correction_steps=1`` the parser is expected to run exactly twice:
    once on the rejected text and once on the accepted text.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    scripted = [
        _make_response("", tool_calls=[lookup_call]),  # tool request
        _make_response("bad format"),  # rejected by the parser
        _make_response("correct format"),  # accepted by the parser
    ]
    parsed_inputs: list[str] = []  # every text handed to the parser, in order

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    def _parser(text: str) -> str:
        parsed_inputs.append(text)
        if text == "bad format":
            raise ParserException("Invalid format")
        return text

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        result, _ = model.generate(prompt="question", parser=_parser, tool_alias="tools", max_correction_steps=1)

    assert result == "correct format"
    assert len(parsed_inputs) == 2  # one failure followed by one success
def test_generate_tool_calls_with_conversation_restarts(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """A parser failure with no correction budget triggers one restart that then succeeds.

    The test records how many messages each completion call receives and
    requires the first call after the restart to see as many messages as the
    call just before it.
    """
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    scripted = [
        _make_response("", tool_calls=[lookup_call]),
        _make_response("still bad"),  # rejected by the parser, which forces the restart
        _make_response("", tool_calls=[lookup_call]),  # first reply after the restart
        _make_response("good result"),
    ]
    message_counts: list[int] = []  # len(messages) for each completion call

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        message_counts.append(len(messages))
        # The number of recorded calls doubles as the cursor into the script.
        return scripted[len(message_counts) - 1]

    def _parser(text: str) -> str:
        if text == "still bad":
            raise ParserException("Bad format")
        return text

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        result, _ = model.generate(
            prompt="question",
            parser=_parser,
            tool_alias="tools",
            max_correction_steps=0,
            max_conversation_restarts=1,
        )

    assert result == "good result"
    # After restart, message count should preserve tool call history (restart from checkpoint)
    assert message_counts[2] == message_counts[1]
# =============================================================================
# Message trace tests
# =============================================================================
def test_generate_trace_includes_tool_calls(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """The trace returned by generate() contains an assistant message carrying the tool call."""
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json='{"q": "test"}')
    scripted = [
        _make_response("Let me look that up", tool_calls=[lookup_call]),
        _make_response("Here is the answer"),
    ]

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    # Keep only assistant messages that actually carry tool calls.
    tool_requesting = [msg for msg in trace if msg.role == "assistant" and msg.tool_calls]
    assert len(tool_requesting) >= 1
    first_call = tool_requesting[0].tool_calls[0]
    assert first_call["function"]["name"] == "lookup"
def test_generate_trace_includes_tool_responses(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """The tool-role message produced by the process hook appears in the returned trace."""
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    lookup_call_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    scripted = [
        _make_response("", tool_calls=[lookup_call]),
        _make_response("final"),
    ]

    def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        assistant_turn = ChatMessage.as_assistant(content="", tool_calls=[lookup_call_dict])
        tool_turn = ChatMessage.as_tool(content="THE TOOL RESPONSE CONTENT", tool_call_id="call-1")
        return [assistant_turn, tool_turn]

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=custom_process_fn)),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    tool_messages = [msg for msg in trace if msg.role == "tool"]
    assert len(tool_messages) >= 1
    first_tool_message = tool_messages[0]
    assert first_tool_message.content == "THE TOOL RESPONSE CONTENT"
    assert first_tool_message.tool_call_id == "call-1"
def test_generate_trace_includes_refusal_messages(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """The refusal text emitted by the refuse hook is present among the trace's tool messages."""
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    lookup_call_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    scripted = [
        _make_response("", tool_calls=[lookup_call]),  # refused: the stub allows zero tool turns
        _make_response("answer without tools"),
    ]

    def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        assistant_turn = ChatMessage.as_assistant(content="", tool_calls=[lookup_call_dict])
        refusal_turn = ChatMessage.as_tool(content="BUDGET_EXCEEDED_REFUSAL", tool_call_id="call-1")
        return [assistant_turn, refusal_turn]

    stub_facade = StubMCPFacade(
        max_tool_call_turns=0,
        process_fn=lambda _: [],
        refuse_fn=custom_refuse_fn,
    )

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(stub_facade),
    )

    with patch.object(ModelFacade, "completion", new=_replay):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    # Look for the refusal marker in the content of every tool-role message.
    tool_contents = [msg.content for msg in trace if msg.role == "tool"]
    assert any("BUDGET_EXCEEDED_REFUSAL" in content for content in tool_contents)
def test_generate_trace_preserves_reasoning_content(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """The last assistant message in the trace keeps the response's ``reasoning_content``."""
    reasoning_text = "Let me think about this carefully..."
    canned_response = _make_response("The answer is 42", reasoning_content=reasoning_text)

    def _fixed_completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return canned_response

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
    )

    with patch.object(ModelFacade, "completion", new=_fixed_completion):
        _, trace = model.generate(prompt="question", parser=lambda x: x)

    # The final assistant entry is the one built from the canned response.
    assistant_messages = [msg for msg in trace if msg.role == "assistant"]
    assert len(assistant_messages) >= 1
    assert assistant_messages[-1].reasoning_content == reasoning_text
# =============================================================================
# Error handling tests
# =============================================================================
def test_generate_tool_execution_error(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """An MCPToolError raised by the process hook propagates out of generate()."""
    lookup_call = ToolCall(id="call-1", name="lookup", arguments_json="{}")
    scripted = [_make_response("", tool_calls=[lookup_call])]

    def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        raise MCPToolError("Tool execution failed: Connection refused")

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=error_process_fn)),
    )

    with (
        patch.object(ModelFacade, "completion", new=_replay),
        pytest.raises(MCPToolError, match="Connection refused"),
    ):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
def test_generate_tool_invalid_arguments(
    stub_model_configs: Any,
    stub_model_client: MagicMock,
    stub_model_provider_registry: Any,
) -> None:
    """An MCPToolError about invalid tool arguments propagates out of generate().

    The scripted tool call carries a non-JSON argument string and the stub's
    process hook raises the error unconditionally.
    """
    malformed_call = ToolCall(id="call-1", name="lookup", arguments_json="not valid json")
    scripted = [_make_response("", tool_calls=[malformed_call])]

    def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]:
        raise MCPToolError("Invalid tool arguments for 'lookup': not valid json")

    def _replay(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse:
        return scripted.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        model_provider_registry=stub_model_provider_registry,
        client=stub_model_client,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=error_process_fn)),
    )

    with (
        patch.object(ModelFacade, "completion", new=_replay),
        pytest.raises(MCPToolError, match="Invalid tool arguments"),
    ):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
# =============================================================================
# Image generation tests
# =============================================================================
def test_generate_image_diffusion_tracks_image_usage(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """generate_image on the diffusion path returns all payloads and counts them as usage."""
    expected_images = ["image1_base64", "image2_base64", "image3_base64"]
    stub_model_client.generate_image.return_value = ImageGenerationResponse(
        images=[ImagePayload(b64_data=encoded) for encoded in expected_images]
    )

    # Nothing has been generated yet.
    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = stub_model_facade.generate_image(prompt="test prompt", extra_body={"n": 3})

    assert len(images) == 3
    assert images == expected_images

    image_usage = stub_model_facade.usage_stats.image_usage
    assert image_usage.total_images == 3
    assert image_usage.has_usage is True
def test_generate_image_chat_completion_tracks_image_usage(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """generate_image on the non-diffusion path returns all payloads and counts them as usage."""
    expected_images = ["image1", "image2"]
    stub_model_client.generate_image.return_value = ImageGenerationResponse(
        images=[ImagePayload(b64_data=encoded) for encoded in expected_images]
    )

    # Nothing has been generated yet.
    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = stub_model_facade.generate_image(prompt="test prompt")

    assert len(images) == 2
    assert images == expected_images

    image_usage = stub_model_facade.usage_stats.image_usage
    assert image_usage.total_images == 2
    assert image_usage.has_usage is True
def test_generate_image_skip_usage_tracking(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """With ``skip_usage_tracking=True`` images are returned but image usage stays at zero."""
    two_images = ImageGenerationResponse(
        images=[
            ImagePayload(b64_data="image1_base64"),
            ImagePayload(b64_data="image2_base64"),
        ]
    )
    stub_model_client.generate_image.return_value = two_images

    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True)

    assert len(images) == 2

    # The call succeeded, yet the image counters must be untouched.
    image_usage = stub_model_facade.usage_stats.image_usage
    assert image_usage.total_images == 0
    assert image_usage.has_usage is False
def test_generate_image_no_image_data(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """An empty image response raises ImageGenerationError and is counted as a failed request."""
    stub_model_client.generate_image.return_value = ImageGenerationResponse(images=[])

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        with pytest.raises(ImageGenerationError, match="No image data found"):
            stub_model_facade.generate_image(prompt="test prompt")

        # The failed call is recorded as one failure and no success.
        request_usage = stub_model_facade.usage_stats.request_usage
        assert request_usage.failed_requests == 1
        assert request_usage.successful_requests == 0
def test_generate_image_accumulates_usage(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """The image counter keeps a running total over successive generate_image calls."""
    first_batch = ImageGenerationResponse(
        images=[ImagePayload(b64_data=encoded) for encoded in ("image1", "image2")]
    )
    second_batch = ImageGenerationResponse(
        images=[ImagePayload(b64_data=encoded) for encoded in ("image3", "image4", "image5")]
    )
    # Each call to the client hands out the next batch.
    stub_model_client.generate_image.side_effect = [first_batch, second_batch]

    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        first_images = stub_model_facade.generate_image(prompt="test1")
        assert len(first_images) == 2
        assert stub_model_facade.usage_stats.image_usage.total_images == 2

        second_images = stub_model_facade.generate_image(prompt="test2")
        assert len(second_images) == 3
        assert stub_model_facade.usage_stats.image_usage.total_images == 5  # 2 + 3
# =============================================================================
# Async behavior tests
# =============================================================================
@pytest.mark.parametrize("skip_usage_tracking", [False, True])
@pytest.mark.asyncio
async def test_acompletion_success(
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
    skip_usage_tracking: bool,
) -> None:
    """acompletion returns the client's response after exactly one client call.

    Runs once with usage tracking enabled and once with it skipped.
    """
    canned_response = _make_response("Test response")
    async_client_call = AsyncMock(return_value=canned_response)
    stub_model_client.acompletion = async_client_call

    result = await stub_model_facade.acompletion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)

    assert result == canned_response
    assert stub_model_client.acompletion.call_count == 1
@pytest.mark.asyncio
async def test_acompletion_with_exception(
    stub_completion_messages: list[ChatMessage],
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """An exception raised by the async client surfaces from acompletion with its message."""
    failing_call = AsyncMock(side_effect=Exception("Router error"))
    stub_model_client.acompletion = failing_call

    with pytest.raises(Exception, match="Router error"):
        await stub_model_facade.acompletion(stub_completion_messages)
@pytest.mark.asyncio
async def test_agenerate_text_embeddings_success(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """agenerate_text_embeddings returns the vectors carried by the client's EmbeddingResponse."""
    expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    embedding_response = EmbeddingResponse(vectors=expected_vectors)
    stub_model_client.aembeddings = AsyncMock(return_value=embedding_response)

    result = await stub_model_facade.agenerate_text_embeddings(["test1", "test2"])

    assert result == expected_vectors
@pytest.mark.parametrize(
    "max_correction_steps,max_conversation_restarts,total_calls",
    [
        (0, 0, 1),
        (1, 1, 4),
        (1, 2, 6),
        (5, 0, 6),
        (0, 5, 6),
        (3, 3, 16),
    ],
)
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_correction_retries(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
    max_correction_steps: int,
    max_conversation_restarts: int,
    total_calls: int,
) -> None:
    """agenerate issues the expected number of completion calls before giving up.

    The parser always fails, so every run ends in
    ModelGenerationValidationFailureError. The same call is made twice to check
    that the second run adds another ``total_calls`` completions to the mock.
    """
    mock_acompletion.return_value = _make_response("bad response")

    def _failing_parser(response: str) -> str:
        raise ParserException("parser exception")

    for run_number in (1, 2):
        with pytest.raises(ModelGenerationValidationFailureError):
            await stub_model_facade.agenerate(
                prompt="foo",
                system_prompt="bar",
                parser=_failing_parser,
                max_correction_steps=max_correction_steps,
                max_conversation_restarts=max_conversation_restarts,
            )
        # The mock's call count is cumulative across both runs.
        assert mock_acompletion.call_count == run_number * total_calls
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_success(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """agenerate returns the parsed text after one completion call and a user/assistant trace."""
    mock_acompletion.return_value = _make_response("parsed output")

    result, trace = await stub_model_facade.agenerate(prompt="test", parser=lambda x: x)

    assert result == "parsed output"
    assert mock_acompletion.call_count == 1

    # The trace must hold the user prompt and the assistant reply.
    roles = [msg.role for msg in trace]
    assert "user" in roles
    assistant_contents = [msg.content for msg in trace if msg.role == "assistant"]
    assert "parsed output" in assistant_contents
# =============================================================================
# Async image generation tests
# =============================================================================
@pytest.mark.asyncio
async def test_agenerate_image_diffusion_success(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """agenerate_image on the diffusion path returns both payloads and counts two images."""
    expected_images = ["image1_base64", "image2_base64"]
    image_response = ImageGenerationResponse(
        images=[ImagePayload(b64_data=encoded) for encoded in expected_images]
    )
    stub_model_client.agenerate_image = AsyncMock(return_value=image_response)

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert len(images) == 2
    assert images == expected_images
    assert stub_model_facade.usage_stats.image_usage.total_images == 2
@pytest.mark.asyncio
async def test_agenerate_image_chat_completion_success(
    stub_model_facade: ModelFacade,
    stub_model_client: MagicMock,
) -> None:
    """agenerate_image on the non-diffusion path returns the single payload and counts one image."""
    single_image = ImageGenerationResponse(images=[ImagePayload(b64_data="image1")])
    stub_model_client.agenerate_image = AsyncMock(return_value=single_image)

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert len(images) == 1
    assert images == ["image1"]
    assert stub_model_facade.usage_stats.image_usage.total_images == 1
@pytest.mark.asyncio
async def test_agenerate_image_no_data(
stub_model_facade: ModelFacade,
stub_model_client: MagicMock,
) -> None:
"""Test async image generation raises error when no data."""
stub_model_client.agenerate_image = AsyncMock(return_value=ImageGenerationResponse(images=[]))
with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
with pytest.raises(ImageGenerationError, match="No image data found"):
await stub_model_facade.agenerate_image(prompt="test prompt")