mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
* perf: defer heavy imports to improve CLI startup time Move expensive imports (engine, models, controllers) out of the module-level import path so that data-designer --help and other non-generation commands no longer pay the full startup cost. Key changes: - Defer controller imports to inside command functions - Remove eager re-export chains from CLI package __init__ files - Move default-settings bootstrap into load_config_builder() and DataDesigner.__init__() instead of running at import time - Add lazy __getattr__ exports in interface/__init__.py - Replace module-level tokenizer init with cached lazy getter - Fix ModelProvider import to use config layer instead of engine - Update test mock paths to match new import locations Reduces CLI import-time from ~1.67s to ~0.46s. * perf: defer pandas/numpy in io_helpers and add config_list benchmark - Replace eager `from lazy_heavy_imports import pd, np` in io_helpers with module-level __getattr__ (for backwards-compatible external access / test mocks) and function-level imports in the 3 functions that actually use them (read_parquet_dataset, smart_load_dataframe, _convert_to_serializable). Importing io_helpers no longer triggers pandas/numpy loading. - Defer heavy imports in list and reset CLI commands into function bodies to avoid loading repositories, Rich, and prompt_toolkit at module import time. - Add `config_list` (data-designer config list) measurement to the CLI startup benchmark with isolated cold measurement in a separate venv and a --skip-config-list-check flag. - Update test mock paths to match new import locations. 
* Refine lazy import usage and TYPE_CHECKING cleanup * Run license header updater on PR-touched files * fix: update sqlfluff mock target for lazy imports in test_sql * perf: cache globals() in lazy __getattr__ to avoid repeated lookups Add globals() caching and explanatory comment to all three lazy __getattr__ implementations (lazy_heavy_imports, config/__init__, interface/__init__) so subsequent attribute accesses bypass __getattr__. * perf: lazy CLI command loading and deferred heavy import evaluations - Add LazyTyperGroup to defer command module loading until invocation, allowing module-level imports in all CLI command files - Split DataFrameSeedSource into seed_source_dataframe.py to isolate pandas dependency from other seed source classes - Move TypeVar/TypeAlias definitions (DataT, NumpyArray1dT, RadomStateT, EngineT) to TYPE_CHECKING blocks with runtime fallbacks - Wrap module-level constants in lru_cache (phone_number parquet data, jsonschema validator) to defer I/O and heavy imports to first use - Update test mock targets to patch at usage-site for module-level imports * refactor: use direct pandas import in seed_source_dataframe Drop lazy-loading for pandas in DataFrameSeedSource; use direct import for simplicity. * update lazy import pattern * update tests to use lazy import namespace Switch test modules to import data_designer.lazy_heavy_imports as lazy and reference heavy libraries through that namespace. This keeps heavy imports deferred during module import and aligns tests with the new lazy-import usage pattern. * tighten import perf test thresholds Document recent baseline timings and lower the allowed average import time and timeout so regressions are detected sooner. * document pandas import requirement Clarify that Pydantic needs DataFrame resolved at module load and that keeping the direct import preserves IDE typing support. 
* increase timeout * use lazy pandas imports in visualization tests - replace direct pandas usage with lazy.pd in visualization tests to avoid eager imports - add TYPE_CHECKING pandas import and keep CLI controller imports sorted * fix lazy pandas runtime usage and preview mocks Switch sample-record handling to lazy pandas types so runtime paths no longer depend on TYPE_CHECKING imports. Align preview controller tests to patch the module-local DataDesigner symbol, preventing real engine invocation in save-results scenarios.
1516 lines
57 KiB
Python
1516 lines
57 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import data_designer.lazy_heavy_imports as lazy
|
|
from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError
|
|
from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError
|
|
from data_designer.engine.models.facade import CustomRouter, ModelFacade
|
|
from data_designer.engine.models.parsers.errors import ParserException
|
|
from data_designer.engine.models.utils import ChatMessage
|
|
from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
|
|
|
|
|
def mock_oai_response_object(response_text: str) -> StubResponse:
    """Build a stub chat-completion response whose single message carries ``response_text``."""
    message = StubMessage(content=response_text)
    return StubResponse(message)
|
|
|
|
|
|
@pytest.fixture
def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry):
    """ModelFacade built from the first stub model config plus the shared stub resolver and provider registry."""
    first_config = stub_model_configs[0]
    return ModelFacade(
        model_config=first_config,
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
    )
|
|
|
|
|
|
@pytest.fixture
def stub_completion_messages() -> list[ChatMessage]:
    """A one-message conversation consisting of a single user turn."""
    user_message = ChatMessage.as_user("test")
    return [user_message]
|
|
|
|
|
|
@pytest.fixture
def stub_expected_completion_response():
    """A litellm ``ModelResponse`` whose choice message content is ``"Test response"``."""
    litellm_types = lazy.litellm.types.utils
    reply = litellm_types.Message(content="Test response")
    return litellm_types.ModelResponse(choices=litellm_types.Choices(message=reply))
|
|
|
|
|
|
@pytest.fixture
def stub_expected_embedding_response():
    """A litellm ``EmbeddingResponse`` holding two identical three-element embeddings."""
    embedding_item = {"embedding": [0.1, 0.2, 0.3]}
    return lazy.litellm.types.utils.EmbeddingResponse(data=[embedding_item] * 2)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "max_correction_steps,max_conversation_restarts,total_calls",
    [
        (0, 0, 1),
        (1, 1, 4),
        (1, 2, 6),
        (5, 0, 6),
        (0, 5, 6),
        (3, 3, 16),
    ],
)
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    max_correction_steps: int,
    max_conversation_restarts: int,
    total_calls: int,
) -> None:
    """A parser that always fails exhausts every correction step and restart before generate() raises.

    In every parametrized case ``total_calls`` equals
    ``(max_correction_steps + 1) * (max_conversation_restarts + 1)``.
    """
    bad_response = mock_oai_response_object("bad response")
    mock_completion.side_effect = lambda *args, **kwargs: bad_response

    def _failing_parser(response: str) -> str:
        raise ParserException("parser exception")

    retry_kwargs = {
        "max_correction_steps": max_correction_steps,
        "max_conversation_restarts": max_conversation_restarts,
    }

    # Repeat the same failing generation; the cumulative completion count must grow by total_calls each time.
    for attempt in (1, 2):
        with pytest.raises(ModelGenerationValidationFailureError):
            stub_model_facade.generate(prompt="foo", system_prompt="bar", parser=_failing_parser, **retry_kwargs)
        assert mock_completion.call_count == attempt * total_calls
|
|
|
|
|
|
@pytest.mark.parametrize(
    "system_prompt,expected_messages",
    [
        ("", [ChatMessage.as_user("does not matter")]),
        ("hello!", [ChatMessage.as_system("hello!"), ChatMessage.as_user("does not matter")]),
    ],
)
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate_with_system_prompt(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    system_prompt: str,
    expected_messages: list[ChatMessage],
) -> None:
    """An empty system prompt is omitted; a non-empty one is sent ahead of the user message."""
    # Snapshot the messages at call time: the list passed to completion() is mutated afterwards.
    captured_messages: list[list[ChatMessage]] = []

    def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse:
        messages_arg = args[1]
        captured_messages.append(list(messages_arg))
        litellm_types = lazy.litellm.types.utils
        reply = litellm_types.Message(content="Hello!")
        return litellm_types.ModelResponse(choices=litellm_types.Choices(message=reply))

    mock_completion.side_effect = capture_and_return

    stub_model_facade.generate(prompt="does not matter", system_prompt=system_prompt, parser=lambda x: x)

    assert mock_completion.call_count == 1
    assert captured_messages[0] == expected_messages
|
|
|
|
|
|
@pytest.mark.parametrize(
    "raw_content,expected",
    [
        ("\nHello world", "Hello world"),
        (" Hello world ", "Hello world"),
        ("\n\n Hello world\n", "Hello world"),
        ("Hello world", "Hello world"),
    ],
)
# Patch via patch.object on the imported class, matching every other test in this module.
@patch.object(ModelFacade, "completion", autospec=True)
def test_generate_strips_response_content(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
    raw_content: str,
    expected: str,
) -> None:
    """Response content from the LLM is stripped of leading/trailing whitespace."""
    # Reuse the module's stub-response helper instead of building StubResponse/StubMessage by hand.
    mock_completion.side_effect = lambda *args, **kwargs: mock_oai_response_object(raw_content)
    result, _ = stub_model_facade.generate(prompt="test", parser=lambda x: x)
    assert result == expected
|
|
|
|
|
|
def test_model_alias_property(stub_model_facade, stub_model_configs):
    """The facade reports the alias of the model config it was built from."""
    expected_alias = stub_model_configs[0].alias
    assert stub_model_facade.model_alias == expected_alias
|
|
|
|
|
|
def test_usage_stats_property(stub_model_facade):
    """The facade exposes a usage-stats object that supports ``model_dump``."""
    stats = stub_model_facade.usage_stats
    assert stats is not None
    assert hasattr(stats, "model_dump")
|
|
|
|
|
|
def test_consolidate_kwargs(stub_model_configs, stub_model_facade):
    """consolidate_kwargs layers call kwargs and provider extras over the config's generate kwargs."""
    base_kwargs = stub_model_configs[0].inference_parameters.generate_kwargs
    provider = stub_model_facade.model_provider

    # The config's generate kwargs form the base; ``purpose`` is dropped.
    assert stub_model_facade.consolidate_kwargs(purpose="test") == base_kwargs

    # Explicit kwargs override the config's generate kwargs.
    overridden = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test")
    assert overridden == {**base_kwargs, "temperature": 0.01}

    # The provider's extra_body is merged with the caller's extra_body.
    provider.extra_body = {"foo_provider": "bar_provider"}
    merged = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test")
    assert merged == {**base_kwargs, "extra_body": {"foo_provider": "bar_provider", "foo": "bar"}}

    # The provider's extra_headers are included in the consolidated kwargs.
    provider.extra_body = None
    provider.extra_headers = {"hello": "world", "hola": "mundo"}
    with_headers = stub_model_facade.consolidate_kwargs()
    assert with_headers == {**base_kwargs, "extra_headers": {"hello": "world", "hola": "mundo"}}
|
|
|
|
|
|
@pytest.mark.parametrize("skip_usage_tracking", [False, True])
@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_success(
    mock_router_completion: Any,
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_expected_completion_response: ModelResponse,
    skip_usage_tracking: bool,
) -> None:
    """completion() sends model name, serialized messages and config kwargs to the router and returns its reply."""
    mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response

    result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)

    expected_router_kwargs = {
        "model": "stub-model-text",
        "messages": [message.to_dict() for message in stub_completion_messages],
        **stub_model_configs[0].inference_parameters.generate_kwargs,
    }
    assert result == stub_expected_completion_response
    assert mock_router_completion.call_count == 1
    assert mock_router_completion.call_args.kwargs == expected_router_kwargs
|
|
|
|
|
|
@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_with_exception(
    mock_router_completion: Any,
    stub_completion_messages: list[ChatMessage],
    stub_model_facade: ModelFacade,
) -> None:
    """An exception raised by the router propagates out of completion()."""
    router_error = Exception("Router error")
    mock_router_completion.side_effect = router_error

    with pytest.raises(Exception, match="Router error"):
        stub_model_facade.completion(stub_completion_messages)
|
|
|
|
|
|
@patch.object(CustomRouter, "completion", autospec=True)
def test_completion_with_kwargs(
    mock_router_completion: Any,
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_expected_completion_response: ModelResponse,
) -> None:
    """Per-call kwargs reach the router and take precedence over the config's generate kwargs."""
    captured_kwargs: dict[str, Any] = {}

    def _record_router_call(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse:
        captured_kwargs.update(kwargs)
        return stub_expected_completion_response

    mock_router_completion.side_effect = _record_router_call

    call_kwargs = {"temperature": 0.7, "max_tokens": 100}
    result = stub_model_facade.completion(stub_completion_messages, **call_kwargs)

    assert result == stub_expected_completion_response
    expected_kwargs = {**stub_model_configs[0].inference_parameters.generate_kwargs, **call_kwargs}
    assert captured_kwargs == expected_kwargs
|
|
|
|
|
|
@patch.object(CustomRouter, "embedding", autospec=True)
def test_generate_text_embeddings_success(
    mock_router_embedding: Any,
    stub_model_facade: ModelFacade,
    stub_expected_embedding_response: EmbeddingResponse,
) -> None:
    """generate_text_embeddings() returns the ``embedding`` vector of each item in the router response."""
    mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response

    embeddings = stub_model_facade.generate_text_embeddings(["test1", "test2"])

    expected_embeddings = [item["embedding"] for item in stub_expected_embedding_response.data]
    assert embeddings == expected_embeddings
|
|
|
|
|
|
@patch.object(CustomRouter, "embedding", autospec=True)
def test_generate_text_embeddings_with_exception(mock_router_embedding: Any, stub_model_facade: ModelFacade) -> None:
    """An exception raised by the router propagates out of generate_text_embeddings()."""
    router_error = Exception("Router error")
    mock_router_embedding.side_effect = router_error

    with pytest.raises(Exception, match="Router error"):
        stub_model_facade.generate_text_embeddings(["test1", "test2"])
|
|
|
|
|
|
@patch.object(CustomRouter, "embedding", autospec=True)
def test_generate_text_embeddings_with_kwargs(
    mock_router_embedding: Any,
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_expected_embedding_response: EmbeddingResponse,
) -> None:
    """Per-call kwargs reach the router's embedding call and override the config's generate kwargs."""
    captured_kwargs: dict[str, Any] = {}

    # Parameter names must match the router's keyword arguments (``model``, ``input``).
    def _record_router_call(self: Any, model: str, input: list[str], **kwargs: Any) -> EmbeddingResponse:
        captured_kwargs.update(kwargs)
        return stub_expected_embedding_response

    mock_router_embedding.side_effect = _record_router_call

    call_kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"}
    stub_model_facade.generate_text_embeddings(["test1", "test2"], **call_kwargs)

    expected_kwargs = {**stub_model_configs[0].inference_parameters.generate_kwargs, **call_kwargs}
    assert captured_kwargs == expected_kwargs
|
|
|
|
|
|
def test_generate_with_mcp_tools(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """A tool call from the model is processed by the MCP facade and its result is fed into the next completion."""
    tool_call = {
        "id": "call-1",
        "type": "function",
        "function": {"name": "lookup", "arguments": '{"query": "foo"}'},
    }
    lookup_schema = {
        "type": "function",
        "function": {"name": "lookup", "description": "Lookup", "parameters": {"type": "object"}},
    }
    scripted_responses = [
        StubResponse(StubMessage(content=None, tool_calls=[tool_call])),
        StubResponse(StubMessage(content="final result")),
    ]
    captured_calls: list[tuple[list[ChatMessage], dict[str, Any]]] = []
    registry_calls: list[tuple[str, str, dict[str, str], None]] = []

    def process_with_tracking(completion_response: Any) -> list[ChatMessage]:
        message = completion_response.choices[0].message
        if message.tool_calls:
            # Record the simulated tool invocation and answer it with a canned tool result.
            registry_calls.append(("tools", "lookup", {"query": "foo"}, None))
            return [
                ChatMessage.as_assistant(content="", tool_calls=[tool_call]),
                ChatMessage.as_tool(content="tool-output", tool_call_id="call-1"),
            ]
        return [ChatMessage.as_assistant(content=message.content or "")]

    facade = StubMCPFacade(tool_schemas=[lookup_schema], process_fn=process_with_tracking)

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        captured_calls.append((messages, kwargs))
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(facade),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "final result"
    assert len(captured_calls) == 2
    first_call_kwargs = captured_calls[0][1]
    second_call_messages = captured_calls[1][0]
    # The tool schemas are advertised on the first request...
    assert "tools" in first_call_kwargs
    assert first_call_kwargs["tools"][0]["function"]["name"] == "lookup"
    # ...and the tool result is part of the follow-up conversation.
    assert any(message.role == "tool" for message in second_call_messages)
    assert registry_calls == [("tools", "lookup", {"query": "foo"}, None)]
|
|
|
|
|
|
def test_generate_with_tools_missing_registry(
    stub_model_configs: Any, stub_secrets_resolver: Any, stub_model_provider_registry: Any
) -> None:
    """Requesting a tool alias from a facade built without an MCP registry raises MCPConfigurationError."""
    facade_without_registry = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=None,
    )

    with pytest.raises(MCPConfigurationError):
        facade_without_registry.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
|
|
|
|
|
|
# =============================================================================
|
|
# Tool calling integration tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_generate_with_tool_alias_multiple_turns(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Two consecutive tool-call turns are executed before the model's final text response is returned."""
    tool_call_1 = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}}
    tool_call_2 = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}}

    responses = [
        StubResponse(StubMessage(content="First lookup", tool_calls=[tool_call_1])),
        StubResponse(StubMessage(content="Second search", tool_calls=[tool_call_2])),
        StubResponse(StubMessage(content="final result after two tool turns")),
    ]
    call_count = 0

    # The turn budget (5) is larger than the two tool turns used, so no refusal is triggered.
    facade = StubMCPFacade(max_tool_call_turns=5)
    registry = StubMCPRegistry(facade)

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        nonlocal call_count
        call_count += 1
        return responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=registry,
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        # The trace is not inspected by this test, so it is discarded.
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "final result after two tool turns"
    assert call_count == 3  # 2 tool turns + 1 final
|
|
|
|
|
|
def test_generate_with_tools_tracks_usage_stats(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Tool usage stats are properly tracked with generations_with_tools incremented."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}}
    search_call = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}}
    scripted_responses = [
        StubResponse(StubMessage(content="First lookup", tool_calls=[lookup_call])),
        StubResponse(StubMessage(content="Second search", tool_calls=[search_call])),
        StubResponse(StubMessage(content="final result")),
    ]

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return scripted_responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade(max_tool_call_turns=5)),
    )

    def _tool_usage() -> tuple[int, int, int, int]:
        # Ordered as (total_tool_calls, total_tool_call_turns, total_generations, generations_with_tools).
        usage = model.usage_stats.tool_usage
        return (
            usage.total_tool_calls,
            usage.total_tool_call_turns,
            usage.total_generations,
            usage.generations_with_tools,
        )

    # Nothing has been generated yet.
    assert _tool_usage() == (0, 0, 0, 0)

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "final result"
    # Two tool calls over two turns, inside one generation that used tools.
    assert _tool_usage() == (2, 2, 1, 1)
|
|
|
|
|
|
def test_generate_with_tools_tracks_multiple_generations(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Tool usage is correctly tracked across multiple generations."""
    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade(max_tool_call_turns=10)),
    )

    def _run_generation(prompt: str, scripted_responses: list[StubResponse]) -> None:
        # Run one tool-enabled generation whose completions are served, in order, from scripted_responses.
        def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
            return scripted_responses.pop(0)

        with patch.object(ModelFacade, "completion", new=_completion):
            model.generate(prompt=prompt, parser=lambda x: x, tool_alias="tools")

    call_a = {"id": "call-a", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "1"}'}}
    call_b = {"id": "call-b", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "2"}'}}
    call_c = {"id": "call-c", "type": "function", "function": {"name": "search", "arguments": '{"q": "3"}'}}
    call_d = {"id": "call-d", "type": "function", "function": {"name": "search", "arguments": '{"q": "4"}'}}

    # Generation 1: 2 tool calls across 1 turn
    _run_generation(
        "q1",
        [
            StubResponse(StubMessage(content="", tool_calls=[call_a, call_b])),
            StubResponse(StubMessage(content="result 1")),
        ],
    )

    # Generation 2: 4 tool calls across 2 turns
    _run_generation(
        "q2",
        [
            StubResponse(StubMessage(content="", tool_calls=[call_a, call_b])),
            StubResponse(StubMessage(content="", tool_calls=[call_c, call_d])),
            StubResponse(StubMessage(content="result 2")),
        ],
    )

    # Generation 3: No tool calls
    _run_generation("q3", [StubResponse(StubMessage(content="result 3"))])

    # Totals: 2 + 4 + 0 = 6 calls, 1 + 2 + 0 = 3 turns, 3 generations in total, 2 of which used tools.
    tool_usage = model.usage_stats.tool_usage
    assert tool_usage.total_tool_calls == 6
    assert tool_usage.total_tool_call_turns == 3
    assert tool_usage.total_generations == 3
    assert tool_usage.generations_with_tools == 2
|
|
|
|
|
|
def test_generate_tool_turn_limit_triggers_refusal(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Once max_tool_call_turns is used up, further tool calls are routed to the facade's refuse_fn."""
    tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}

    # The model keeps requesting tools: turns 1 and 2 fit the budget of 2, turn 3 exceeds it.
    responses = [StubResponse(StubMessage(content="", tool_calls=[tool_call])) for _ in range(3)]
    responses.append(StubResponse(StubMessage(content="final answer after refusal")))

    processed: list[Any] = []
    refused: list[Any] = []

    def custom_process_fn(completion_response: Any) -> list[ChatMessage]:
        processed.append(completion_response)
        requested_calls = completion_response.choices[0].message.tool_calls or []
        return [
            ChatMessage.as_assistant(content="", tool_calls=requested_calls),
            ChatMessage.as_tool(content="tool-result", tool_call_id="call-1"),
        ]

    def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]:
        refused.append(completion_response)
        requested_calls = completion_response.choices[0].message.tool_calls or []
        return [
            ChatMessage.as_assistant(content="", tool_calls=requested_calls),
            ChatMessage.as_tool(content="REFUSED: Budget exceeded", tool_call_id="call-1"),
        ]

    registry = StubMCPRegistry(
        StubMCPFacade(max_tool_call_turns=2, process_fn=custom_process_fn, refuse_fn=custom_refuse_fn)
    )

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=registry,
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "final answer after refusal"
    assert len(processed) == 2  # Turns 1 and 2
    assert len(refused) == 1  # Turn 3 was refused
|
|
|
|
|
|
def test_generate_tool_turn_limit_model_responds_after_refusal(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Model provides final answer after refusal message."""
    tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    refusal_text = "Tool call refused: You have reached the maximum number of tool-calling turns."

    responses = [
        StubResponse(StubMessage(content="", tool_calls=[tool_call])),  # Exceeds on first turn
        StubResponse(StubMessage(content="I understand, here is the answer without tools")),
    ]

    def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]:
        return [
            ChatMessage.as_assistant(content="", tool_calls=[tool_call]),
            ChatMessage.as_tool(content=refusal_text, tool_call_id="call-1"),
        ]

    # With a budget of zero turns, the very first tool call must be refused.
    facade = StubMCPFacade(
        max_tool_call_turns=0,
        process_fn=lambda _: [],  # Should not be called
        refuse_fn=custom_refuse_fn,
    )

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return responses.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(facade),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    assert result == "I understand, here is the answer without tools"
    # The refusal tool message must be recorded in the returned trace.
    tool_messages = [msg for msg in trace if msg.role == "tool"]
    assert any(msg.content and "refused" in msg.content.lower() for msg in tool_messages)
|
|
|
|
|
|
def test_generate_tool_alias_not_in_registry(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Raises MCPConfigurationError ("not registered") when tool_alias is not found in the MCP registry."""

    # Named distinctly so it does not shadow the StubMCPRegistry imported from data_designer.engine.testing.
    class _UnknownAliasRegistry:
        """Registry stub whose lookup always fails, mimicking an alias that was never registered."""

        def get_mcp(self, *, tool_alias: str) -> Any:
            raise ValueError(f"No tool config with alias {tool_alias!r} found!")

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=_UnknownAliasRegistry(),
    )

    with pytest.raises(MCPConfigurationError, match="not registered"):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="nonexistent")
|
|
|
|
|
|
def test_generate_no_tool_alias_ignores_mcp(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """When tool_alias is None, the MCP registry is never consulted and no tools are sent."""
    get_mcp_called = False

    # Named distinctly so it does not shadow the StubMCPRegistry imported from data_designer.engine.testing.
    class _UntouchableRegistry:
        """Registry stub that records and fails any lookup, since none is expected."""

        def get_mcp(self, *, tool_alias: str) -> Any:
            nonlocal get_mcp_called
            get_mcp_called = True
            raise RuntimeError("Should not be called")

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        assert "tools" not in kwargs  # No tools should be passed
        return StubResponse(StubMessage(content="response without tools"))

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=_UntouchableRegistry(),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias=None)

    assert result == "response without tools"
    assert get_mcp_called is False
|
|
|
|
|
|
def test_generate_tool_calls_with_parser_corrections(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """A parser failure after a tool turn is recovered by one correction step."""
    tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    responses = [
        StubResponse(StubMessage(content="", tool_calls=[tool_call])),  # Tool call
        StubResponse(StubMessage(content="bad format")),  # Parser will fail
        StubResponse(StubMessage(content="correct format")),  # Parser will succeed
    ]
    parsed_inputs: list[str] = []

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return responses.pop(0)

    def _parser(text: str) -> str:
        parsed_inputs.append(text)
        if text == "bad format":
            raise ParserException("Invalid format")
        return text

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(prompt="question", parser=_parser, tool_alias="tools", max_correction_steps=1)

    assert result == "correct format"
    assert len(parsed_inputs) == 2  # Failed once, then succeeded
|
|
|
|
|
|
def test_generate_tool_calls_with_conversation_restarts(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Tool calling works correctly with conversation restarts."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    message_counts: list[int] = []

    # Before the restart: a tool call followed by output the parser rejects.
    # After the restart: another tool call followed by acceptable output.
    pending = [
        StubResponse(StubMessage(content="", tool_calls=[lookup_call])),
        StubResponse(StubMessage(content="still bad")),  # Fails parser, triggers restart
        StubResponse(StubMessage(content="", tool_calls=[lookup_call])),  # After restart
        StubResponse(StubMessage(content="good result")),
    ]

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        # Record how many messages each request carried before handing out the next reply.
        message_counts.append(len(messages))
        return pending.pop(0)

    def _parser(text: str) -> str:
        if text == "still bad":
            raise ParserException("Bad format")
        return text

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        result, _ = model.generate(
            prompt="question",
            parser=_parser,
            tool_alias="tools",
            max_correction_steps=0,
            max_conversation_restarts=1,
        )

    assert result == "good result"
    # The first post-restart request carries as many messages as the request that
    # followed the original tool call, i.e. the restart resumes from that checkpoint.
    assert message_counts[2] == message_counts[1]
|
|
|
|
|
|
# =============================================================================
|
|
# Message trace tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_generate_trace_includes_tool_calls(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Returned trace includes tool call messages."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "test"}'}}

    pending = [
        StubResponse(StubMessage(content="Let me look that up", tool_calls=[lookup_call])),
        StubResponse(StubMessage(content="Here is the answer")),
    ]

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return pending.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade()),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    # The trace must keep the assistant turn that requested the tool.
    tool_requesting_turns = list(filter(lambda msg: msg.role == "assistant" and msg.tool_calls, trace))
    assert len(tool_requesting_turns) >= 1
    first_call = tool_requesting_turns[0].tool_calls[0]
    assert first_call["function"]["name"] == "lookup"
|
|
|
|
|
|
def test_generate_trace_includes_tool_responses(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Returned trace includes tool response messages."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}

    pending = [
        StubResponse(StubMessage(content="", tool_calls=[lookup_call])),
        StubResponse(StubMessage(content="final")),
    ]

    def custom_process_fn(completion_response: Any) -> list[ChatMessage]:
        # Emulate tool execution: echo the assistant tool call plus a recognizable tool reply.
        assistant_turn = ChatMessage.as_assistant(content="", tool_calls=[lookup_call])
        tool_turn = ChatMessage.as_tool(content="THE TOOL RESPONSE CONTENT", tool_call_id="call-1")
        return [assistant_turn, tool_turn]

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return pending.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=custom_process_fn)),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    tool_turns = [entry for entry in trace if entry.role == "tool"]
    assert len(tool_turns) >= 1
    first_tool_turn = tool_turns[0]
    assert first_tool_turn.content == "THE TOOL RESPONSE CONTENT"
    assert first_tool_turn.tool_call_id == "call-1"
|
|
|
|
|
|
def test_generate_trace_includes_refusal_messages(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Returned trace includes refusal messages when budget exhausted."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}

    pending = [
        StubResponse(StubMessage(content="", tool_calls=[lookup_call])),  # Will be refused (max_turns=0)
        StubResponse(StubMessage(content="answer without tools")),
    ]

    def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]:
        # The refusal is delivered as a tool message carrying a sentinel string.
        return [
            ChatMessage.as_assistant(content="", tool_calls=[lookup_call]),
            ChatMessage.as_tool(content="BUDGET_EXCEEDED_REFUSAL", tool_call_id="call-1"),
        ]

    # A zero tool-turn budget means any tool call goes through refuse_fn.
    budgetless_facade = StubMCPFacade(
        max_tool_call_turns=0,
        process_fn=lambda _: [],
        refuse_fn=custom_refuse_fn,
    )

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return pending.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(budgetless_facade),
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        _, trace = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")

    # Check for refusal message in trace
    refusal_seen = False
    for entry in trace:
        if entry.role == "tool" and "BUDGET_EXCEEDED_REFUSAL" in entry.content:
            refusal_seen = True
            break
    assert refusal_seen
|
|
|
|
|
|
def test_generate_trace_preserves_reasoning_content(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Trace messages preserve reasoning_content field."""
    reasoning_text = "Let me think about this carefully..."
    canned_reply = StubResponse(StubMessage(content="The answer is 42", reasoning_content=reasoning_text))

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        # Every request receives the same reply.
        return canned_reply

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
    )

    with patch.object(ModelFacade, "completion", new=_completion):
        _, trace = model.generate(prompt="question", parser=lambda x: x)

    assistant_turns = [entry for entry in trace if entry.role == "assistant"]
    assert len(assistant_turns) >= 1
    # The final assistant turn should still carry the model's reasoning.
    assert assistant_turns[-1].reasoning_content == reasoning_text
|
|
|
|
|
|
# =============================================================================
|
|
# Error handling tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_generate_tool_execution_error(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Handles MCP tool execution errors appropriately."""
    lookup_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}
    pending = [StubResponse(StubMessage(content="", tool_calls=[lookup_call]))]

    def error_process_fn(completion_response: Any) -> list[ChatMessage]:
        # Simulate the tool backend failing while executing the call.
        raise MCPToolError("Tool execution failed: Connection refused")

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return pending.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=error_process_fn)),
    )

    with patch.object(ModelFacade, "completion", new=_completion), pytest.raises(
        MCPToolError, match="Connection refused"
    ):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
|
|
|
|
|
|
def test_generate_tool_invalid_arguments(
    stub_model_configs: Any,
    stub_secrets_resolver: Any,
    stub_model_provider_registry: Any,
) -> None:
    """Handles invalid tool arguments from LLM."""
    # The model emits a tool call whose arguments are not valid JSON.
    malformed_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "not valid json"}}
    pending = [StubResponse(StubMessage(content="", tool_calls=[malformed_call]))]

    def error_process_fn(completion_response: Any) -> list[ChatMessage]:
        raise MCPToolError("Invalid tool arguments for 'lookup': not valid json")

    def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse:
        return pending.pop(0)

    model = ModelFacade(
        model_config=stub_model_configs[0],
        secret_resolver=stub_secrets_resolver,
        model_provider_registry=stub_model_provider_registry,
        mcp_registry=StubMCPRegistry(StubMCPFacade(process_fn=error_process_fn)),
    )

    with patch.object(ModelFacade, "completion", new=_completion), pytest.raises(
        MCPToolError, match="Invalid tool arguments"
    ):
        model.generate(prompt="question", parser=lambda x: x, tool_alias="tools")
|
|
|
|
|
|
# =============================================================================
|
|
# Image generation tests
|
|
# =============================================================================
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
def test_generate_image_diffusion_tracks_image_usage(
    mock_image_generation: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image tracks image usage for diffusion models."""
    # Mock response with 3 images
    mock_response = lazy.litellm.types.utils.ImageResponse(
        data=[
            lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"),
            lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"),
            lazy.litellm.types.utils.ImageObject(b64_json="image3_base64"),
        ]
    )
    mock_image_generation.return_value = mock_response

    # Verify initial state
    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    # Generate images
    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = stub_model_facade.generate_image(prompt="test prompt", n=3)

    # Verify results
    assert len(images) == 3
    assert images == ["image1_base64", "image2_base64", "image3_base64"]
    # One diffusion request returns all three images; this matches the
    # call-count check in the async diffusion test.
    assert mock_image_generation.call_count == 1

    # Verify image usage was tracked
    assert stub_model_facade.usage_stats.image_usage.total_images == 3
    assert stub_model_facade.usage_stats.image_usage.has_usage is True
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_image_chat_completion_tracks_image_usage(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image tracks image usage for chat completion models."""
    # Mock response with images attribute (Message requires type and index per ImageURLListItem)
    mock_message = lazy.litellm.types.utils.Message(
        role="assistant",
        content="",
        images=[
            lazy.litellm.types.utils.ImageURLListItem(
                type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0
            ),
            lazy.litellm.types.utils.ImageURLListItem(
                type="image_url", image_url={"url": "data:image/png;base64,image2"}, index=1
            ),
        ],
    )
    mock_response = lazy.litellm.types.utils.ModelResponse(
        choices=[lazy.litellm.types.utils.Choices(message=mock_message)]
    )
    mock_completion.return_value = mock_response

    # Verify initial state
    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    # Generate images
    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = stub_model_facade.generate_image(prompt="test prompt")

    # Verify results
    assert len(images) == 2
    assert images == ["image1", "image2"]
    # A single chat completion carries both images; this matches the
    # call-count check in the async chat-completion test.
    assert mock_completion.call_count == 1

    # Verify image usage was tracked
    assert stub_model_facade.usage_stats.image_usage.total_images == 2
    assert stub_model_facade.usage_stats.image_usage.has_usage is True
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_image_chat_completion_with_dict_format(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image handles images as dicts with image_url string."""
    # Each image is a dict whose "image_url" value is a bare data-URI string.
    dict_images = [
        {"image_url": "data:image/png;base64,image1"},
        {"image_url": "data:image/jpeg;base64,image2"},
    ]
    message = MagicMock(role="assistant", content="", images=dict_images)
    mock_completion.return_value = MagicMock(choices=[MagicMock(message=message)])

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = stub_model_facade.generate_image(prompt="test prompt")

    # The data-URI prefix is stripped regardless of the declared mime type.
    assert len(images) == 2
    assert images == ["image1", "image2"]
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_image_chat_completion_with_plain_strings(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image handles images as plain strings."""
    # One image is a data URI; the other is plain base64 without data URI prefix.
    string_images = [
        "data:image/png;base64,image1",
        "image2",
    ]
    message = MagicMock(role="assistant", content="", images=string_images)
    mock_completion.return_value = MagicMock(choices=[MagicMock(message=message)])

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = stub_model_facade.generate_image(prompt="test prompt")

    assert len(images) == 2
    assert images == ["image1", "image2"]
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
def test_generate_image_skip_usage_tracking(
    mock_image_generation: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image respects skip_usage_tracking flag."""
    image_objects = [lazy.litellm.types.utils.ImageObject(b64_json=f"image{i}_base64") for i in (1, 2)]
    mock_image_generation.return_value = lazy.litellm.types.utils.ImageResponse(data=image_objects)

    image_usage = stub_model_facade.usage_stats.image_usage
    assert image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True)

    assert len(images) == 2

    # With skip_usage_tracking=True, the counters must remain untouched.
    assert stub_model_facade.usage_stats.image_usage.total_images == 0
    assert stub_model_facade.usage_stats.image_usage.has_usage is False
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_image_chat_completion_no_choices(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image raises ImageGenerationError when response has no choices."""
    mock_completion.return_value = lazy.litellm.types.utils.ModelResponse(choices=[])

    not_diffusion = patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False)
    expect_error = pytest.raises(ImageGenerationError, match="Image generation response missing choices")
    with not_diffusion, expect_error:
        stub_model_facade.generate_image(prompt="test prompt")
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
def test_generate_image_chat_completion_no_image_data(
    mock_completion: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image raises ImageGenerationError when no image data in response."""
    text_only = lazy.litellm.types.utils.Message(role="assistant", content="just text, no image")
    choice = lazy.litellm.types.utils.Choices(message=text_only)
    mock_completion.return_value = lazy.litellm.types.utils.ModelResponse(choices=[choice])

    not_diffusion = patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False)
    expect_error = pytest.raises(ImageGenerationError, match="No image data found in image generation response")
    with not_diffusion, expect_error:
        stub_model_facade.generate_image(prompt="test prompt")
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
def test_generate_image_diffusion_no_data(
    mock_image_generation: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image raises ImageGenerationError when diffusion API returns no data."""
    mock_image_generation.return_value = lazy.litellm.types.utils.ImageResponse(data=[])

    is_diffusion = patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True)
    expect_error = pytest.raises(ImageGenerationError, match="Image generation returned no data")
    with is_diffusion, expect_error:
        stub_model_facade.generate_image(prompt="test prompt")
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True)
def test_generate_image_accumulates_usage(
    mock_image_generation: Any,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that generate_image accumulates image usage across multiple calls."""
    image_object = lazy.litellm.types.utils.ImageObject
    image_response = lazy.litellm.types.utils.ImageResponse

    # The first call yields two images; the second yields three.
    mock_image_generation.side_effect = [
        image_response(data=[image_object(b64_json="image1"), image_object(b64_json="image2")]),
        image_response(
            data=[
                image_object(b64_json="image3"),
                image_object(b64_json="image4"),
                image_object(b64_json="image5"),
            ]
        ),
    ]

    assert stub_model_facade.usage_stats.image_usage.total_images == 0

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        first_batch = stub_model_facade.generate_image(prompt="test1")
        assert len(first_batch) == 2
        assert stub_model_facade.usage_stats.image_usage.total_images == 2

        second_batch = stub_model_facade.generate_image(prompt="test2")
        assert len(second_batch) == 3
        # The running total covers both calls.
        assert stub_model_facade.usage_stats.image_usage.total_images == 5
|
|
|
|
|
|
# =============================================================================
|
|
# Async behavior tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.mark.parametrize("skip_usage_tracking", [False, True])
@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_acompletion_success(
    mock_router_acompletion: AsyncMock,
    stub_completion_messages: list[ChatMessage],
    stub_model_configs: Any,
    stub_model_facade: ModelFacade,
    stub_expected_completion_response: ModelResponse,
    skip_usage_tracking: bool,
) -> None:
    """acompletion forwards the serialized messages and inference kwargs to the router."""
    mock_router_acompletion.return_value = stub_expected_completion_response

    result = await stub_model_facade.acompletion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)

    assert result == stub_expected_completion_response
    assert mock_router_acompletion.call_count == 1

    forwarded_kwargs = mock_router_acompletion.call_args[1]
    expected_kwargs = dict(
        model="stub-model-text",
        messages=list(map(lambda message: message.to_dict(), stub_completion_messages)),
    )
    expected_kwargs.update(stub_model_configs[0].inference_parameters.generate_kwargs)
    assert forwarded_kwargs == expected_kwargs
|
|
|
|
|
|
@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_acompletion_with_exception(
    mock_router_acompletion: AsyncMock,
    stub_completion_messages: list[ChatMessage],
    stub_model_facade: ModelFacade,
) -> None:
    """Router failures surface unchanged from acompletion."""
    router_failure = Exception("Router error")
    mock_router_acompletion.side_effect = router_failure

    with pytest.raises(Exception, match="Router error"):
        await stub_model_facade.acompletion(stub_completion_messages)
|
|
|
|
|
|
@patch.object(CustomRouter, "aembedding", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_text_embeddings_success(
    mock_router_aembedding: AsyncMock,
    stub_model_facade: ModelFacade,
    stub_expected_embedding_response: EmbeddingResponse,
) -> None:
    """agenerate_text_embeddings returns the raw embedding vectors from the router response."""
    mock_router_aembedding.return_value = stub_expected_embedding_response

    embeddings = await stub_model_facade.agenerate_text_embeddings(["test1", "test2"])

    expected_vectors = []
    for item in stub_expected_embedding_response.data:
        expected_vectors.append(item["embedding"])
    assert embeddings == expected_vectors
|
|
|
|
|
|
@pytest.mark.parametrize(
    "max_correction_steps,max_conversation_restarts,total_calls",
    [
        (0, 0, 1),
        (1, 1, 4),
        (1, 2, 6),
        (5, 0, 6),
        (0, 5, 6),
        (3, 3, 16),
    ],
)
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_correction_retries(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
    max_correction_steps: int,
    max_conversation_restarts: int,
    total_calls: int,
) -> None:
    """A parser that always fails exhausts every correction step and conversation restart."""
    mock_acompletion.return_value = mock_oai_response_object("bad response")

    def _failing_parser(response: str) -> str:
        raise ParserException("parser exception")

    # Run two identical generations; the mock's call count accumulates across them.
    for attempt in (1, 2):
        with pytest.raises(ModelGenerationValidationFailureError):
            await stub_model_facade.agenerate(
                prompt="foo",
                system_prompt="bar",
                parser=_failing_parser,
                max_correction_steps=max_correction_steps,
                max_conversation_restarts=max_conversation_restarts,
            )
        assert mock_acompletion.call_count == attempt * total_calls
|
|
|
|
|
|
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_success(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """A single successful completion is parsed and recorded in the returned trace."""
    mock_acompletion.return_value = mock_oai_response_object("parsed output")

    result, trace = await stub_model_facade.agenerate(prompt="test", parser=lambda x: x)

    assert result == "parsed output"
    assert mock_acompletion.call_count == 1

    # The trace should hold both the user prompt and the assistant reply.
    roles_seen = {entry.role for entry in trace}
    assert "user" in roles_seen
    assert any(entry.content == "parsed output" for entry in trace if entry.role == "assistant")
|
|
|
|
|
|
# =============================================================================
|
|
# Async image generation tests
|
|
# =============================================================================
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_diffusion_success(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Test async image generation via diffusion API."""
    mock_response = lazy.litellm.types.utils.ImageResponse(
        data=[
            lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"),
            lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"),
        ]
    )
    mock_aimage_generation.return_value = mock_response

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert len(images) == 2
    assert images == ["image1_base64", "image2_base64"]
    assert mock_aimage_generation.call_count == 1
    # Verify image usage was tracked. The has_usage check matches the sync diffusion test.
    assert stub_model_facade.usage_stats.image_usage.total_images == 2
    assert stub_model_facade.usage_stats.image_usage.has_usage is True
|
|
|
|
|
|
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_chat_completion_success(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Test async image generation via chat completion API."""
    mock_message = lazy.litellm.types.utils.Message(
        role="assistant",
        content="",
        images=[
            lazy.litellm.types.utils.ImageURLListItem(
                type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0
            ),
        ],
    )
    mock_response = lazy.litellm.types.utils.ModelResponse(
        choices=[lazy.litellm.types.utils.Choices(message=mock_message)]
    )
    mock_acompletion.return_value = mock_response

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        images = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert len(images) == 1
    assert images == ["image1"]
    assert mock_acompletion.call_count == 1
    # Usage tracking matches the sync chat-completion test: the count and the has_usage flag.
    assert stub_model_facade.usage_stats.image_usage.total_images == 1
    assert stub_model_facade.usage_stats.image_usage.has_usage is True
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_diffusion_no_data(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Test async image generation raises error when diffusion API returns no data."""
    mock_aimage_generation.return_value = lazy.litellm.types.utils.ImageResponse(data=[])

    is_diffusion = patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True)
    expect_error = pytest.raises(ImageGenerationError, match="Image generation returned no data")
    with is_diffusion, expect_error:
        await stub_model_facade.agenerate_image(prompt="test prompt")
|
|
|
|
|
|
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_chat_completion_no_choices(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Test async image generation raises error when response has no choices."""
    mock_acompletion.return_value = lazy.litellm.types.utils.ModelResponse(choices=[])

    not_diffusion = patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False)
    expect_error = pytest.raises(ImageGenerationError, match="Image generation response missing choices")
    with not_diffusion, expect_error:
        await stub_model_facade.agenerate_image(prompt="test prompt")
|
|
|
|
|
|
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_skip_usage_tracking(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Test that async image generation respects skip_usage_tracking flag."""
    mock_response = lazy.litellm.types.utils.ImageResponse(
        data=[lazy.litellm.types.utils.ImageObject(b64_json="image1_base64")]
    )
    mock_aimage_generation.return_value = mock_response

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        images = await stub_model_facade.agenerate_image(prompt="test prompt", skip_usage_tracking=True)

    assert len(images) == 1
    # Usage must stay untouched. This matches the sync skip_usage_tracking test,
    # which also checks that has_usage remains False.
    assert stub_model_facade.usage_stats.image_usage.total_images == 0
    assert stub_model_facade.usage_stats.image_usage.has_usage is False
|