mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from enum import Enum
|
|
from unittest.mock import patch
|
|
|
|
from data_designer.engine.testing.stubs import (
|
|
StubPluginConfigModels,
|
|
plugin_blobs_and_seeds,
|
|
plugin_models,
|
|
plugin_models_and_blobs,
|
|
)
|
|
from data_designer.plugin_manager import PluginManager
|
|
from data_designer.plugins.plugin import Plugin
|
|
from data_designer.plugins.registry import PluginRegistry
|
|
|
|
|
|
class MockEntryPoint:
    """Fake entry point that hands back a pre-built plugin object.

    Used as the element type of the list returned by the patched
    ``data_designer.plugins.registry.entry_points``. Only ``load()`` is
    implemented, which is the single method this test module needs.
    """

    def __init__(self, plugin: Plugin):
        # Keep the wrapped plugin exposed as a public attribute, as before.
        self.plugin = plugin

    def load(self) -> Plugin:
        """Return the wrapped plugin unchanged."""
        wrapped = self.plugin
        return wrapped
|
|
|
|
|
|
@contextmanager
def mock_plugin_system(plugins: list[Plugin]) -> Generator[None, None, None]:
    """Context manager to mock plugin entry points to return the provided plugins.

    This works regardless of whether the actual environment has plugins available or not
    by patching at the module level where PluginManager is instantiated.

    While active, ``data_designer.plugins.registry.entry_points`` is patched to
    return one ``MockEntryPoint`` per plugin, and
    ``data_designer.plugins.registry.PLUGINS_DISABLED`` is forced to ``False``.
    ``PluginRegistry.reset()`` is called both on entry and on exit. The exit
    reset also runs when the ``with`` body raises.

    Args:
        plugins: Plugins the mocked ``entry_points`` call should hand out, each
            wrapped in a ``MockEntryPoint``.
    """
    mock_entry_points = [MockEntryPoint(plugin) for plugin in plugins]
    # Drop any registry state created before this context, so that discovery
    # inside the block sees the mocked entry points.
    # NOTE(review): assumes PluginRegistry caches discovery until reset() — confirm.
    PluginRegistry.reset()
    try:
        with (
            patch("data_designer.plugins.registry.entry_points", return_value=mock_entry_points),
            patch("data_designer.plugins.registry.PLUGINS_DISABLED", False),
        ):
            yield
    finally:
        # @contextmanager re-raises a body exception at the ``yield``. Without
        # ``finally``, a failing test would skip this reset and leak state
        # into subsequent tests.
        PluginRegistry.reset()
|
|
|
|
|
|
def make_test_enum(plugins: list[Plugin]) -> type[Enum]:
    """Create a ``str``-mixin Enum named ``TestEnum`` with one member per plugin.

    Each member's value is ``plugin.name``. Its member name is that value with
    ``-`` replaced by ``_`` and then upper-cased.
    """
    members: dict[str, str] = {}
    for plugin in plugins:
        member_name = plugin.name.replace("-", "_").upper()
        members[member_name] = plugin.name
    return Enum("TestEnum", members, type=str)
|
|
|
|
|
|
def test_get_column_generator_plugins_with_plugins() -> None:
    """Column generator plugins come back in the order they were registered."""
    registered = [plugin_blobs_and_seeds, plugin_models]
    with mock_plugin_system(registered):
        found = PluginManager().get_column_generator_plugins()

    assert len(found) == 2
    assert [plugin.name for plugin in found] == [expected.name for expected in registered]
|
|
|
|
|
|
def test_get_column_generator_plugins_empty() -> None:
    """An empty list is returned when no plugins are registered."""
    with mock_plugin_system([]):
        plugins_found = PluginManager().get_column_generator_plugins()

    assert plugins_found == []
|
|
|
|
|
|
def test_get_column_generator_plugin_if_exists_found() -> None:
    """A registered plugin can be looked up by its name."""
    target = plugin_models
    with mock_plugin_system([target]):
        lookup = PluginManager().get_column_generator_plugin_if_exists(target.name)

    assert lookup == target
|
|
|
|
|
|
def test_get_column_generator_plugin_if_exists_not_found() -> None:
    """Looking up a plugin name that was never registered yields ``None``."""
    missing_name = plugin_models.name
    with mock_plugin_system([]):
        lookup = PluginManager().get_column_generator_plugin_if_exists(missing_name)

    assert lookup is None
|
|
|
|
|
|
def test_get_plugin_column_types_with_plugins() -> None:
    """Test getting plugin column types when plugins are available.

    The enum passed in has one member per registered plugin (value = plugin
    name). The result must contain one member for each of those plugins.
    """
    all_plugins = [plugin_models, plugin_models_and_blobs, plugin_blobs_and_seeds]
    TestEnum = make_test_enum(all_plugins)

    with mock_plugin_system(all_plugins):
        manager = PluginManager()
        result = manager.get_plugin_column_types(TestEnum)

    assert len(result) == 3
    assert all(isinstance(item, TestEnum) for item in result)
    # Length + isinstance alone would also accept e.g. the same member three
    # times. Pin which members come back: every registered plugin name must
    # be represented.
    assert {item.value for item in result} == {plugin.name for plugin in all_plugins}
|
|
|
|
|
|
def test_get_plugin_column_types_empty() -> None:
    """No column types are reported when the plugin registry is empty."""
    no_plugins: list[Plugin] = []
    EmptyEnum = make_test_enum(no_plugins)

    with mock_plugin_system(no_plugins):
        column_types = PluginManager().get_plugin_column_types(EmptyEnum)

    assert column_types == []
|
|
|
|
|
|
def test_inject_into_column_config_type_union_with_plugins() -> None:
    """Injection extends an existing type union with the plugin's config type.

    The expected result is the original two-member union plus
    ``StubPluginConfigModels``.
    """

    class BaseType1:
        pass

    class BaseType2:
        pass

    with mock_plugin_system([plugin_models]):
        injected = PluginManager().inject_into_column_config_type_union(BaseType1 | BaseType2)

    expected = BaseType1 | BaseType2 | StubPluginConfigModels
    assert injected == expected
|