mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
226 lines
9.1 KiB
Python
226 lines
9.1 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from data_designer.cli.controllers.provider_controller import ProviderController
|
|
from data_designer.cli.repositories.model_repository import ModelConfigRegistry
|
|
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry
|
|
from data_designer.config.models import ModelConfig, ModelProvider
|
|
|
|
|
|
@pytest.fixture
def controller(tmp_path: Path) -> ProviderController:
    """Provide a ProviderController whose config directory is pytest's ``tmp_path``."""
    provider_controller = ProviderController(tmp_path)
    return provider_controller
|
|
|
|
|
|
@pytest.fixture
def controller_with_providers(
    controller: ProviderController, stub_model_providers: list[ModelProvider]
) -> ProviderController:
    """Provide a controller whose provider registry holds the stub providers.

    The first stub provider is recorded as the registry default.
    """
    first_provider = stub_model_providers[0]
    registry = ModelProviderRegistry(providers=stub_model_providers, default=first_provider.name)
    controller.repository.save(registry)
    return controller
|
|
|
|
|
|
@pytest.fixture
def controller_with_providers_and_models(
    controller_with_providers: ProviderController, stub_model_configs: list[ModelConfig]
) -> ProviderController:
    """Provide a provider-populated controller that also has the stub model configs saved."""
    model_registry = ModelConfigRegistry(model_configs=stub_model_configs)
    controller_with_providers.model_repository.save(model_registry)
    return controller_with_providers
|
|
|
|
|
|
def test_init(tmp_path: Path) -> None:
    """Test controller initialization sets up repositories and services correctly."""
    ctrl = ProviderController(tmp_path)

    # The controller and both repositories are rooted at the directory it was given.
    assert ctrl.config_dir == tmp_path
    assert ctrl.repository.config_dir == tmp_path
    assert ctrl.model_repository.config_dir == tmp_path

    # Each service is wired to its matching repository.
    assert ctrl.service.repository == ctrl.repository
    assert ctrl.model_service.repository == ctrl.model_repository
|
|
|
|
|
|
def test_run_with_no_providers_and_user_cancels(controller: ProviderController) -> None:
    """Test run with no existing providers prompts for add and handles cancellation.

    The provider form is stubbed to return ``None``, which is how this test simulates
    the user cancelling the form.
    """
    mock_builder = MagicMock()
    mock_builder.run.return_value = None

    with patch("data_designer.cli.controllers.provider_controller.ProviderFormBuilder", return_value=mock_builder):
        controller.run()

    # The provider list starts empty, so the final assertion alone would also pass
    # if run() never showed the form; confirm the form was actually presented.
    mock_builder.run.assert_called_once()

    # Verify no providers were added since user cancelled
    assert len(controller.service.list_all()) == 0
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows", return_value="no")
def test_run_with_no_providers_adds_new_provider(
    mock_select: MagicMock,
    controller: ProviderController,
    stub_new_model_provider: ModelProvider,
) -> None:
    """Test run with no existing providers successfully adds a new provider.

    The provider form is stubbed to yield ``stub_new_model_provider``, and every
    arrow-key selection prompt answers ``"no"``.
    """
    form_builder = MagicMock()
    form_builder.run.return_value = stub_new_model_provider
    builder_patch = patch(
        "data_designer.cli.controllers.provider_controller.ProviderFormBuilder",
        return_value=form_builder,
    )

    with builder_patch:
        controller.run()

    # Inspect persisted state through the service rather than through the mock.
    saved = controller.service.list_all()
    assert len(saved) == 1
    only_provider = saved[0]
    assert only_provider.name == stub_new_model_provider.name
    assert only_provider.endpoint == stub_new_model_provider.endpoint
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows", return_value="exit")
def test_run_with_existing_providers_and_exit(
    mock_select: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run with existing providers shows config and respects exit choice."""
    count_before = len(controller_with_providers.service.list_all())

    controller_with_providers.run()

    # Choosing "exit" must leave the stored providers untouched.
    count_after = len(controller_with_providers.service.list_all())
    assert count_after == count_before
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows")
def test_run_deletes_provider_without_models(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run can delete a provider through delete mode when no models are associated.

    The first selection prompt picks ``"delete"``, the second picks
    ``"test-provider-1"``, and the confirmation prompt answers ``True``.
    """
    mock_select.side_effect = ["delete", "test-provider-1"]

    controller_with_providers.run()

    # Only the provider that was not selected should survive.
    survivors = controller_with_providers.service.list_all()
    assert [provider.name for provider in survivors] == ["test-provider-2"]
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows")
def test_run_deletes_provider_with_associated_models(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_providers_and_models: ProviderController,
) -> None:
    """Test run deletes provider and associated models when confirmed.

    The first selection prompt picks ``"delete"``, the second picks
    ``"test-provider-1"``, and the confirmation prompt answers ``True``.
    """
    mock_select.side_effect = ["delete", "test-provider-1"]

    ctrl = controller_with_providers_and_models
    ctrl.run()

    remaining_providers = ctrl.service.list_all()
    remaining_models = ctrl.model_service.list_all()
    assert [provider.name for provider in remaining_providers] == ["test-provider-2"]
    # Expected to be empty because the stub model configs all reference
    # test-provider-1 (per the original test's note; confirm against the fixtures).
    assert len(remaining_models) == 0
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows", return_value="delete_all")
def test_run_deletes_all_providers_without_models(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run can delete all providers through delete_all mode.

    Every selection prompt answers ``"delete_all"`` and the confirmation answers ``True``.
    """
    controller_with_providers.run()

    # Nothing should be left in the provider registry.
    leftover = controller_with_providers.service.list_all()
    assert len(leftover) == 0
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=True)
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows", return_value="delete_all")
def test_run_deletes_all_providers_with_models(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_providers_and_models: ProviderController,
) -> None:
    """Test run deletes all providers and associated models when confirmed.

    Every selection prompt answers ``"delete_all"`` and the confirmation answers ``True``.
    """
    ctrl = controller_with_providers_and_models
    ctrl.run()

    # Both the provider registry and the model registry should now be empty.
    leftover_providers = ctrl.service.list_all()
    assert len(leftover_providers) == 0
    leftover_models = ctrl.model_service.list_all()
    assert len(leftover_models) == 0
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows")
def test_run_updates_provider(
    mock_select: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run can update an existing provider through update mode.

    The first selection prompt picks ``"update"`` and the second picks
    ``"test-provider-1"``. The provider form is stubbed to return a renamed
    provider with a new endpoint.
    """
    mock_select.side_effect = ["update", "test-provider-1"]

    replacement = ModelProvider(
        name="test-provider-1-updated",
        endpoint="https://api.example.com/updated",
        provider_type="openai",
        api_key="updated-key",
    )
    form_builder = MagicMock()
    form_builder.run.return_value = replacement

    with patch("data_designer.cli.controllers.provider_controller.ProviderFormBuilder", return_value=form_builder):
        controller_with_providers.run()

    service = controller_with_providers.service
    # Updating replaces the provider rather than adding one, so the total stays at two.
    assert len(service.list_all()) == 2
    renamed = service.get_by_name("test-provider-1-updated")
    assert renamed is not None
    assert renamed.endpoint == "https://api.example.com/updated"
    # The old name must no longer resolve.
    assert service.get_by_name("test-provider-1") is None
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows")
def test_run_changes_default_provider(
    mock_select: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run can change the default provider through change_default mode.

    The first selection prompt picks ``"change_default"`` and the second picks
    ``"test-provider-2"``.
    """
    mock_select.side_effect = ["change_default", "test-provider-2"]

    # Precondition: the fixture registered the first stub provider as the default.
    assert controller_with_providers.service.get_default() == "test-provider-1"

    controller_with_providers.run()

    # Postcondition: the selected provider is now the default.
    assert controller_with_providers.service.get_default() == "test-provider-2"
|
|
|
|
|
|
@patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=False)
@patch("data_designer.cli.controllers.provider_controller.select_with_arrows")
def test_run_respects_delete_cancellation(
    mock_select: MagicMock,
    mock_confirm: MagicMock,
    controller_with_providers: ProviderController,
) -> None:
    """Test run respects user's choice to cancel deletion.

    The first selection prompt picks ``"delete"``, the second picks
    ``"test-provider-1"``, and the confirmation prompt is stubbed to answer ``False``.
    """
    mock_select.side_effect = ["delete", "test-provider-1"]

    initial_count = len(controller_with_providers.service.list_all())

    controller_with_providers.run()

    # Ensure the cancellation path was actually exercised. Without this check the
    # test would also pass if run() never reached the confirmation prompt.
    mock_confirm.assert_called()

    # Verify no providers were deleted, including the one that was selected.
    assert len(controller_with_providers.service.list_all()) == initial_count
    assert controller_with_providers.service.get_by_name("test-provider-1") is not None
|