DataDesigner/tests/cli/commands/test_reset_command.py
Johnny Greco f8c201e085
chore: update header script to check for diffs (#195)
* update script

* update headers

* refactor a bit and add test script

* update headers

* update for edge case

* update headers

* add step to get file creation date

* use git history to get copyright year

* generation type is printed with inference parameters

* fix unit test
2026-01-09 17:10:58 -05:00

287 lines
10 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable
from pathlib import Path
from unittest.mock import Mock, patch
import pytest
import typer
from data_designer.cli.commands.reset import reset_command
from data_designer.config.utils.constants import DATA_DESIGNER_HOME
# Type alias for the callable produced by the ``mock_repositories_factory`` fixture.
# Parameters, in positional order: provider_exists, model_exists,
# provider_delete_error, model_delete_error.
# Return value: (mock_provider_repo, mock_provider_instance,
# mock_model_repo, mock_model_instance).
MockRepositoryFactory = Callable[
[bool, bool, Exception | None, Exception | None],
tuple[Mock, Mock, Mock, Mock],
]
# Fixtures for common test data
@pytest.fixture
def stub_fake_provider_path() -> Path:
    """Return the placeholder path used as the provider config file in these tests."""
    provider_config = Path("/fake/providers.json")
    return provider_config
@pytest.fixture
def stub_fake_model_path() -> Path:
    """Return the placeholder path used as the model config file in these tests."""
    model_config = Path("/fake/models.json")
    return model_config
# Helper functions for mock setup
def setup_mock_repository(
exists: bool = True,
config_file: Path | None = None,
delete_side_effect: Exception | None = None,
) -> Mock:
"""Create a mock repository instance with common configuration.
Args:
exists: Whether the config file exists
config_file: Path to the config file
delete_side_effect: Optional exception to raise on delete()
"""
mock_instance = Mock()
mock_instance.exists.return_value = exists
if config_file:
mock_instance.config_file = config_file
if delete_side_effect:
mock_instance.delete.side_effect = delete_side_effect
return mock_instance
@pytest.fixture
def mock_repositories_factory(stub_fake_provider_path: Path, stub_fake_model_path: Path) -> MockRepositoryFactory:
    """Fixture returning a factory that builds provider/model repository mocks on demand."""

    def _factory(
        provider_exists: bool = False,
        model_exists: bool = False,
        provider_delete_error: Exception | None = None,
        model_delete_error: Exception | None = None,
    ) -> tuple[Mock, Mock, Mock, Mock]:
        """Build the repository class mocks together with the instances they return.

        Returns:
            Tuple of (mock_provider_repo, mock_provider_instance,
            mock_model_repo, mock_model_instance)
        """

        def _repo_and_instance(present: bool, path: Path, failure: Exception | None) -> tuple[Mock, Mock]:
            # The config path is only attached when the file is supposed to exist.
            instance = setup_mock_repository(
                exists=present,
                config_file=path if present else None,
                delete_side_effect=failure,
            )
            return Mock(return_value=instance), instance

        provider_repo, provider_instance = _repo_and_instance(
            provider_exists, stub_fake_provider_path, provider_delete_error
        )
        model_repo, model_instance = _repo_and_instance(model_exists, stub_fake_model_path, model_delete_error)
        return provider_repo, provider_instance, model_repo, model_instance

    return _factory
# Tests
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_no_config_files_exist(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
) -> None:
    """With neither config file present, reset exits with code 0 without prompting or deleting."""
    built = mock_repositories_factory(provider_exists=False, model_exists=False)
    provider_instance, model_instance = built[1], built[3]
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance

    with pytest.raises(typer.Exit) as exc_info:
        reset_command()

    # Checked after the context manager: the raise ends the ``with`` body.
    assert exc_info.value.exit_code == 0
    for untouched in (mock_confirm, provider_instance.delete, model_instance.delete):
        untouched.assert_not_called()
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_both_files_exist_user_confirms_both(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
) -> None:
    """With both config files present and every prompt confirmed, both repositories are deleted once."""
    built = mock_repositories_factory(provider_exists=True, model_exists=True)
    provider_instance, model_instance = built[1], built[3]
    mock_confirm.return_value = True
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance

    reset_command()

    assert mock_confirm.call_count == 2
    for repository in (provider_instance, model_instance):
        repository.delete.assert_called_once()
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_both_files_exist_user_declines_both(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
) -> None:
    """With both config files present and every prompt declined, nothing is deleted."""
    built = mock_repositories_factory(provider_exists=True, model_exists=True)
    provider_instance, model_instance = built[1], built[3]
    mock_confirm.return_value = False
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance

    reset_command()

    # The user is still asked once per existing file.
    assert mock_confirm.call_count == 2
    for repository in (provider_instance, model_instance):
        repository.delete.assert_not_called()
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_mixed_confirmation(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
) -> None:
    """When the first prompt is confirmed and the second declined, only the provider config is deleted."""
    built = mock_repositories_factory(provider_exists=True, model_exists=True)
    provider_instance, model_instance = built[1], built[3]
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance
    # Successive confirm_action calls answer True, then False.
    answers = [True, False]
    mock_confirm.side_effect = answers

    reset_command()

    assert mock_confirm.call_count == 2
    provider_instance.delete.assert_called_once()
    model_instance.delete.assert_not_called()
@pytest.mark.parametrize(
    "provider_error,model_error,expected_provider_calls,expected_model_calls",
    [
        pytest.param(Exception("Permission denied"), None, 1, 1, id="provider_fails"),
        pytest.param(None, OSError("Disk error"), 1, 1, id="model_fails"),
        pytest.param(Exception("Error 1"), Exception("Error 2"), 1, 1, id="both_fail"),
    ],
)
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_deletion_failures(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
    provider_error: Exception | None,
    model_error: Exception | None,
    expected_provider_calls: int,
    expected_model_calls: int,
) -> None:
    """When a confirmed deletion raises, reset exits with code 1 after attempting the expected deletes."""
    built = mock_repositories_factory(
        provider_exists=True,
        model_exists=True,
        provider_delete_error=provider_error,
        model_delete_error=model_error,
    )
    provider_instance, model_instance = built[1], built[3]
    mock_confirm.return_value = True
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance

    with pytest.raises(typer.Exit) as exc_info:
        reset_command()

    assert exc_info.value.exit_code == 1
    observed_calls = (provider_instance.delete.call_count, model_instance.delete.call_count)
    assert observed_calls[0] == expected_provider_calls
    assert observed_calls[1] == expected_model_calls
@pytest.mark.parametrize(
    "provider_exists,model_exists,expected_confirms,expected_provider_deletes,expected_model_deletes",
    [
        pytest.param(True, False, 1, 1, 0, id="only_provider"),
        pytest.param(False, True, 1, 0, 1, id="only_model"),
    ],
)
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_single_file_exists(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
    provider_exists: bool,
    model_exists: bool,
    expected_confirms: int,
    expected_provider_deletes: int,
    expected_model_deletes: int,
) -> None:
    """With only one config file present, only that file is prompted for and deleted."""
    built = mock_repositories_factory(provider_exists=provider_exists, model_exists=model_exists)
    provider_instance, model_instance = built[1], built[3]
    mock_confirm.return_value = True
    mock_provider_repo.return_value = provider_instance
    mock_model_repo.return_value = model_instance

    reset_command()

    assert mock_confirm.call_count == expected_confirms
    observed_deletes = (provider_instance.delete.call_count, model_instance.delete.call_count)
    assert observed_deletes[0] == expected_provider_deletes
    assert observed_deletes[1] == expected_model_deletes
@patch("data_designer.cli.commands.reset.ModelRepository")
@patch("data_designer.cli.commands.reset.ProviderRepository")
@patch("data_designer.cli.commands.reset.confirm_action")
def test_reset_uses_default_config_dir_when_none_provided(
    mock_confirm: Mock,
    mock_provider_repo: Mock,
    mock_model_repo: Mock,
    mock_repositories_factory: MockRepositoryFactory,
) -> None:
    """Calling reset_command with no arguments constructs both repositories with DATA_DESIGNER_HOME."""
    built = mock_repositories_factory(provider_exists=False, model_exists=False)
    mock_provider_repo.return_value = built[1]
    mock_model_repo.return_value = built[3]

    # This test only needs typer.Exit to be raised; it does not check the exit code.
    with pytest.raises(typer.Exit):
        reset_command()

    for repository_cls in (mock_provider_repo, mock_model_repo):
        repository_cls.assert_called_once_with(DATA_DESIGNER_HOME)