DataDesigner/tests/plugins/test_plugin_registry.py
Johnny Greco f8c201e085
chore: update header script to check for diffs (#195)
* update script

* update headers

* refactor a bit and add test script

* update headers

* update for edge case

* update headers

* add step to get file creation date

* use git history to get copyright year

* generation type is printed with inference parameters

* fix unit test
2026-01-09 17:10:58 -05:00

253 lines
9 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import threading
from contextlib import contextmanager
from importlib.metadata import EntryPoint
from unittest.mock import MagicMock, patch
import pytest
from data_designer.config.base import ConfigBase
from data_designer.plugins.errors import PluginNotFoundError
from data_designer.plugins.plugin import Plugin, PluginType
from data_designer.plugins.registry import PluginRegistry
from data_designer.plugins.testing.stubs import MODULE_NAME, StubPluginConfigA, StubPluginConfigB
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def plugin_a() -> Plugin:
return Plugin(
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskA",
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigA",
plugin_type=PluginType.COLUMN_GENERATOR,
)
@pytest.fixture
def plugin_b() -> Plugin:
return Plugin(
impl_qualified_name=f"{MODULE_NAME}.StubPluginTaskB",
config_qualified_name=f"{MODULE_NAME}.StubPluginConfigB",
plugin_type=PluginType.COLUMN_GENERATOR,
)
@pytest.fixture(autouse=True)
def clean_plugin_registry() -> None:
"""Reset PluginRegistry singleton state before and after each test."""
PluginRegistry.reset()
yield
PluginRegistry.reset()
@pytest.fixture
def mock_plugin_discovery():
"""Mock plugin discovery to test with specific entry points."""
@contextmanager
def _mock_discovery(entry_points_list):
with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False):
with patch("data_designer.plugins.registry.entry_points", return_value=entry_points_list):
yield
return _mock_discovery
@pytest.fixture
def mock_entry_points(plugin_a: Plugin, plugin_b: Plugin) -> list[MagicMock]:
"""Create mock entry points for plugin_a and plugin_b."""
mock_ep_a = MagicMock(spec=EntryPoint)
mock_ep_a.name = "test-plugin-a"
mock_ep_a.load.return_value = plugin_a
mock_ep_b = MagicMock(spec=EntryPoint)
mock_ep_b.name = "test-plugin-b"
mock_ep_b.load.return_value = plugin_b
return [mock_ep_a, mock_ep_b]
# =============================================================================
# PluginRegistry Singleton Tests
# =============================================================================
def test_plugin_registry_is_singleton(mock_plugin_discovery) -> None:
"""Test PluginRegistry returns same instance."""
with mock_plugin_discovery([]):
manager1 = PluginRegistry()
manager2 = PluginRegistry()
assert manager1 is manager2
def test_plugin_registry_singleton_thread_safety(mock_plugin_discovery) -> None:
"""Test PluginRegistry singleton creation is thread-safe."""
instances: list[PluginRegistry] = []
with mock_plugin_discovery([]):
def create_manager() -> None:
instances.append(PluginRegistry())
threads = [threading.Thread(target=create_manager) for _ in range(10)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert all(instance is instances[0] for instance in instances)
# =============================================================================
# PluginRegistry Discovery Tests
# =============================================================================
def test_plugin_registry_discovers_plugins(
mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin
) -> None:
"""Test PluginRegistry discovers and loads plugins from entry points."""
with mock_plugin_discovery(mock_entry_points):
manager = PluginRegistry()
assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 2
assert manager.get_plugin("test-plugin-a") == plugin_a
assert manager.get_plugin("test-plugin-b") == plugin_b
def test_plugin_registry_skips_invalid_plugins(mock_plugin_discovery, plugin_a: Plugin) -> None:
"""Test PluginRegistry skips non-Plugin objects during discovery."""
mock_ep_valid = MagicMock(spec=EntryPoint)
mock_ep_valid.name = "test-plugin-a"
mock_ep_valid.load.return_value = plugin_a
mock_ep_invalid = MagicMock(spec=EntryPoint)
mock_ep_invalid.name = "invalid-plugin"
mock_ep_invalid.load.return_value = "not a plugin"
with mock_plugin_discovery([mock_ep_valid, mock_ep_invalid]):
manager = PluginRegistry()
assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1
assert manager.get_plugin("test-plugin-a") == plugin_a
def test_plugin_registry_handles_loading_errors(mock_plugin_discovery, plugin_a: Plugin) -> None:
"""Test PluginRegistry gracefully handles plugin loading errors."""
mock_ep_valid = MagicMock(spec=EntryPoint)
mock_ep_valid.name = "test-plugin-a"
mock_ep_valid.load.return_value = plugin_a
mock_ep_error = MagicMock(spec=EntryPoint)
mock_ep_error.name = "error-plugin"
mock_ep_error.load.side_effect = Exception("Loading failed")
with mock_plugin_discovery([mock_ep_valid, mock_ep_error]):
manager = PluginRegistry()
assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 1
assert manager.get_plugin("test-plugin-a") == plugin_a
def test_plugin_registry_discovery_runs_once() -> None:
"""Test discovery runs once even with multiple PluginRegistry instances."""
mock_entry_points = MagicMock(return_value=[])
with patch("data_designer.plugins.registry.PLUGINS_DISABLED", False):
with patch("data_designer.plugins.registry.entry_points", mock_entry_points):
PluginRegistry()
PluginRegistry()
PluginRegistry()
assert mock_entry_points.call_count == 1
def test_plugin_registry_respects_disabled_flag() -> None:
"""Test PluginRegistry respects DISABLE_DATA_DESIGNER_PLUGINS flag."""
mock_entry_points = MagicMock(return_value=[])
with patch("data_designer.plugins.registry.PLUGINS_DISABLED", True):
with patch("data_designer.plugins.registry.entry_points", mock_entry_points):
manager = PluginRegistry()
assert mock_entry_points.call_count == 0
assert manager.num_plugins(PluginType.COLUMN_GENERATOR) == 0
# =============================================================================
# PluginRegistry Query Methods Tests
# =============================================================================
def test_plugin_registry_get_plugin_raises_error(mock_plugin_discovery) -> None:
"""Test get_plugin() raises error for nonexistent plugin."""
with mock_plugin_discovery([]):
manager = PluginRegistry()
with pytest.raises(PluginNotFoundError, match="Plugin 'nonexistent' not found"):
manager.get_plugin("nonexistent")
def test_plugin_registry_get_plugins_by_type(
mock_plugin_discovery, mock_entry_points: list[MagicMock], plugin_a: Plugin, plugin_b: Plugin
) -> None:
"""Test get_plugins() filters by plugin type."""
with mock_plugin_discovery(mock_entry_points):
manager = PluginRegistry()
plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR)
assert len(plugins) == 2
assert plugin_a in plugins
assert plugin_b in plugins
def test_plugin_registry_get_plugins_empty(mock_plugin_discovery) -> None:
"""Test get_plugins() returns empty list when no plugins match."""
with mock_plugin_discovery([]):
manager = PluginRegistry()
plugins = manager.get_plugins(PluginType.COLUMN_GENERATOR)
assert plugins == []
def test_plugin_registry_get_plugin_names(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None:
"""Test get_plugin_names() returns plugin names by type."""
with mock_plugin_discovery(mock_entry_points):
manager = PluginRegistry()
names = manager.get_plugin_names(PluginType.COLUMN_GENERATOR)
assert set(names) == {"test-plugin-a", "test-plugin-b"}
# =============================================================================
# PluginRegistry Type Union Tests
# =============================================================================
def test_plugin_registry_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None:
"""Test update_type_union() adds plugin config types to union."""
from typing_extensions import TypeAlias
class DummyConfig(ConfigBase):
pass
with mock_plugin_discovery(mock_entry_points):
manager = PluginRegistry()
# Create a Union with at least 2 types so it has __args__
type_union: TypeAlias = ConfigBase | DummyConfig
updated_union = manager.add_plugin_types_to_union(type_union, PluginType.COLUMN_GENERATOR)
assert StubPluginConfigA in updated_union.__args__
assert StubPluginConfigB in updated_union.__args__
assert ConfigBase in updated_union.__args__
assert DummyConfig in updated_union.__args__