DataDesigner/packages/data-designer-engine/tests/test_plugin_manager.py

124 lines
4.1 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Generator
from contextlib import contextmanager
from enum import Enum
from unittest.mock import patch
from data_designer.engine.testing.stubs import (
StubPluginConfigModels,
plugin_blobs_and_seeds,
plugin_models,
plugin_models_and_blobs,
)
from data_designer.plugin_manager import PluginManager
from data_designer.plugins.plugin import Plugin
from data_designer.plugins.registry import PluginRegistry
class MockEntryPoint:
    """Fake entry point handed to the patched ``entry_points`` lookup in tests.

    Only ``load()`` is implemented; it returns the plugin object supplied at
    construction time.
    """

    def __init__(self, plugin: Plugin):
        # Stored as a public attribute; ``load`` returns this exact object.
        self.plugin = plugin

    def load(self) -> Plugin:
        """Return the wrapped plugin object unchanged."""
        return self.plugin
@contextmanager
def mock_plugin_system(plugins: list[Plugin]) -> Generator[None, None, None]:
    """Context manager to mock plugin entry points to return the provided plugins.

    This works regardless of whether the actual environment has plugins available or not
    by patching at the module level where PluginManager is instantiated.

    Args:
        plugins: Plugins the patched ``data_designer.plugins.registry.entry_points``
            should return, each wrapped in a ``MockEntryPoint``. ``PLUGINS_DISABLED``
            is forced to ``False`` for the duration of the context.

    ``PluginRegistry.reset()`` is called on exit even when the ``with`` body raises
    (e.g. a failed assertion).
    """
    mock_entry_points = [MockEntryPoint(plugin) for plugin in plugins]
    with (
        patch("data_designer.plugins.registry.entry_points", return_value=mock_entry_points),
        patch("data_designer.plugins.registry.PLUGINS_DISABLED", False),
    ):
        try:
            yield
        finally:
            # With @contextmanager, an exception from the caller's block is
            # re-raised at ``yield``; without ``finally`` this reset would be
            # skipped and registry state from the mocks could leak into later tests.
            PluginRegistry.reset()
def make_test_enum(plugins: list[Plugin]) -> type[Enum]:
    """Create a ``str``-mixin enum named ``TestEnum`` with one member per plugin.

    Member names are the plugin name with ``-`` replaced by ``_`` and upper-cased
    (``"my-plugin"`` -> ``MY_PLUGIN``); member values are the original plugin names.
    If two plugins map to the same member name, the later one wins.
    """
    members: dict[str, str] = {}
    for plugin in plugins:
        member_name = plugin.name.replace("-", "_").upper()
        members[member_name] = plugin.name
    return Enum("TestEnum", members, type=str)
def test_get_column_generator_plugins_with_plugins() -> None:
    """Both mocked column-generator plugins are returned, in entry-point order."""
    expected = [plugin_blobs_and_seeds, plugin_models]
    with mock_plugin_system(expected):
        plugins = PluginManager().get_column_generator_plugins()
        assert len(plugins) == 2
        assert [plugin.name for plugin in plugins] == [plugin.name for plugin in expected]
def test_get_column_generator_plugins_empty() -> None:
    """With no mocked plugins, the manager returns an empty list."""
    with mock_plugin_system([]):
        assert PluginManager().get_column_generator_plugins() == []
def test_get_column_generator_plugin_if_exists_found() -> None:
    """Looking up a registered plugin by its name returns that plugin."""
    with mock_plugin_system([plugin_models]):
        pm = PluginManager()
        found = pm.get_column_generator_plugin_if_exists(plugin_models.name)
        assert found == plugin_models
def test_get_column_generator_plugin_if_exists_not_found() -> None:
    """Looking up a name when no plugins are registered yields ``None``."""
    with mock_plugin_system([]):
        pm = PluginManager()
        missing = pm.get_column_generator_plugin_if_exists(plugin_models.name)
        assert missing is None
def test_get_plugin_column_types_with_plugins() -> None:
    """Three mocked plugins produce three members of the supplied enum type."""
    registered = [plugin_models, plugin_models_and_blobs, plugin_blobs_and_seeds]
    column_type_enum = make_test_enum(registered)
    with mock_plugin_system(registered):
        column_types = PluginManager().get_plugin_column_types(column_type_enum)
        assert len(column_types) == 3
        for column_type in column_types:
            assert isinstance(column_type, column_type_enum)
def test_get_plugin_column_types_empty() -> None:
    """With no mocked plugins, no column types are returned."""
    empty_enum = make_test_enum([])
    with mock_plugin_system([]):
        assert PluginManager().get_plugin_column_types(empty_enum) == []
def test_inject_into_column_config_type_union_with_plugins() -> None:
    """Injecting a mocked plugin extends the given union with ``StubPluginConfigModels``."""

    class BaseType1:
        ...

    class BaseType2:
        ...

    base_union = BaseType1 | BaseType2
    with mock_plugin_system([plugin_models]):
        injected = PluginManager().inject_into_column_config_type_union(base_union)
        assert injected == BaseType1 | BaseType2 | StubPluginConfigModels