# DataDesigner/tests/cli/controllers/test_model_controller.py
#
# 160 lines
# 6.5 KiB
# Python
# Raw Normal View History

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from data_designer.cli.controllers.model_controller import ModelController
from data_designer.cli.repositories.model_repository import ModelConfigRegistry
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
# History note: "feat: support native embedding generation (#106)", 2025-12-15 18:03:33 +00:00
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
@pytest.fixture
def controller(tmp_path: Path, stub_model_providers: list) -> ModelController:
    """Return a ModelController backed by a config dir that already has providers.

    The stub providers are saved into ``tmp_path`` through a
    ``ProviderRepository`` (the first stub provider's name is used as the
    default) before the controller is constructed on the same directory.
    """
    repository = ProviderRepository(tmp_path)
    registry = ModelProviderRegistry(
        providers=stub_model_providers,
        default=stub_model_providers[0].name,
    )
    repository.save(registry)
    return ModelController(tmp_path)
@pytest.fixture
def controller_with_models(controller: ModelController, stub_model_configs: list[ModelConfig]) -> ModelController:
    """Return the ``controller`` fixture after saving the stub model configs.

    The stub configs are wrapped in a ``ModelConfigRegistry`` and written via
    the controller's own model repository, so the same controller instance is
    handed back with models already persisted.
    """
    registry = ModelConfigRegistry(model_configs=stub_model_configs)
    controller.model_repository.save(registry)
    return controller
def test_init(tmp_path: Path) -> None:
    """Check that a fresh controller wires its repositories and services to ``tmp_path``."""
    ctrl = ModelController(tmp_path)

    # The controller itself remembers the directory it was given.
    assert ctrl.config_dir == tmp_path

    # Model side: repository points at the directory, service wraps that repository.
    assert ctrl.model_repository.config_dir == tmp_path
    assert ctrl.model_service.repository == ctrl.model_repository

    # Provider side: same wiring.
    assert ctrl.provider_repository.config_dir == tmp_path
    assert ctrl.provider_service.repository == ctrl.provider_repository
@patch("data_designer.cli.controllers.model_controller.print_error")
@patch("data_designer.cli.controllers.model_controller.print_info")
@patch("data_designer.cli.controllers.model_controller.print_header")
def test_run_with_no_providers(
    mock_print_header: MagicMock, mock_print_info: MagicMock, mock_print_error: MagicMock, tmp_path: Path
) -> None:
    """Check the messages ``run`` prints when the config dir has no providers saved.

    Stacked ``@patch`` decorators inject mocks bottom-up, so ``print_header``
    arrives first and ``print_error`` last.
    """
    # No provider registry is written to tmp_path before running.
    ModelController(tmp_path).run()

    mock_print_header.assert_called_once_with("Configure Models")
    mock_print_error.assert_called_once_with("No providers available!")
    mock_print_info.assert_called_once_with("Please run 'data-designer config providers' first")
def test_run_with_no_models_and_user_cancels(controller: ModelController) -> None:
    """Check that no model is stored when the form builder's ``run`` returns ``None``."""
    # A builder whose run() yields None stands in for the user cancelling the form.
    cancelled_builder = MagicMock(**{"run.return_value": None})

    with patch(
        "data_designer.cli.controllers.model_controller.ModelFormBuilder",
        return_value=cancelled_builder,
    ):
        controller.run()

    # Nothing should have been persisted.
    assert len(controller.model_service.list_all()) == 0
@patch("data_designer.cli.controllers.model_controller.select_with_arrows", return_value="no")
def test_run_with_no_models_adds_new_model(
    mock_select: MagicMock,
    controller: ModelController,
    stub_new_model_config: ModelConfig,
) -> None:
    """Check that the config returned by the form builder ends up stored.

    ``select_with_arrows`` is patched to answer ``"no"`` for every prompt
    (presumably declining a follow-up question after the add -- confirm against
    the controller).
    """
    # The patched form builder hands back the stub config from run().
    form_builder = MagicMock()
    form_builder.run.return_value = stub_new_model_config

    with patch(
        "data_designer.cli.controllers.model_controller.ModelFormBuilder",
        return_value=form_builder,
    ):
        controller.run()

    # Inspect the outcome through the service rather than the mocks.
    stored = controller.model_service.list_all()
    assert len(stored) == 1
    assert stored[0].alias == stub_new_model_config.alias
@patch("data_designer.cli.controllers.model_controller.select_with_arrows", return_value="exit")
def test_run_with_existing_models_and_exit(
    mock_select: MagicMock,
    controller_with_models: ModelController,
) -> None:
    """Check that choosing ``"exit"`` leaves the number of stored models unchanged."""
    count_before = len(controller_with_models.model_service.list_all())

    controller_with_models.run()

    count_after = len(controller_with_models.model_service.list_all())
    # The model count must be the same before and after the run.
    assert count_after == count_before
@patch("data_designer.cli.controllers.model_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.model_controller.select_with_arrows")
def test_run_deletes_model(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_models: ModelController,
) -> None:
    """Check that selecting ``"delete"`` then ``"test-alias-1"`` removes that model.

    ``confirm_action`` is patched to always return ``True``. The assertions
    expect exactly one model, ``"test-alias-2"``, to remain afterwards.
    """
    # Successive select_with_arrows calls return the mode, then the alias to delete.
    mock_select.side_effect = ["delete", "test-alias-1"]

    controller_with_models.run()

    # Only the other stub model should be left.
    leftover = controller_with_models.model_service.list_all()
    assert len(leftover) == 1
    assert leftover[0].alias == "test-alias-2"
@patch("data_designer.cli.controllers.model_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.model_controller.select_with_arrows", return_value="delete_all")
def test_run_deletes_all_models(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_models: ModelController,
) -> None:
    """Check that selecting ``"delete_all"`` (with confirmation patched to ``True``) empties the model list."""
    controller_with_models.run()

    # No model may remain afterwards.
    remaining = controller_with_models.model_service.list_all()
    assert len(remaining) == 0
@patch("data_designer.cli.controllers.model_controller.select_with_arrows")
def test_run_updates_model(
    mock_select: MagicMock,
    controller_with_models: ModelController,
) -> None:
    """Test run can update an existing model through update mode.

    ``select_with_arrows`` answers ``"update"`` and then ``"test-alias-1"``;
    the patched ``ModelFormBuilder`` returns a replacement config whose alias
    and model name both differ from the original. Afterwards the new alias
    must resolve, the old alias must not, and the model count must still be 2.
    """
    # Successive select_with_arrows calls return the mode, then the alias to update.
    mock_select.side_effect = ["update", "test-alias-1"]

    updated_config = ModelConfig(
        alias="test-alias-1-updated",
        model="test-model-1-updated",
        provider="test-provider-1",
        inference_parameters=ChatCompletionInferenceParams(temperature=0.8, top_p=0.95, max_tokens=1024),
    )

    mock_builder = MagicMock()
    mock_builder.run.return_value = updated_config

    with patch("data_designer.cli.controllers.model_controller.ModelFormBuilder", return_value=mock_builder):
        controller_with_models.run()

    # Verify model was actually updated: count unchanged, so this was a
    # replacement rather than an addition.
    models = controller_with_models.model_service.list_all()
    assert len(models) == 2

    updated_model = controller_with_models.model_service.get_by_alias("test-alias-1-updated")
    assert updated_model is not None
    assert updated_model.model == "test-model-1-updated"

    # The old alias must no longer resolve.
    assert controller_with_models.model_service.get_by_alias("test-alias-1") is None