DataDesigner/packages/data-designer-engine/tests/engine/test_configurable_task.py
Andre Manoel 8b79b21298 Initialize orphan Fern docs website branch
Preserves tree from previous docs-website head: 5e47d33ea8. This branch is a CI-managed publish artifact like gh-pages; source provenance is tracked in commit messages rather than Git ancestry.
2026-05-14 01:17:51 +00:00

112 lines
3.6 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from unittest.mock import Mock
import pytest
from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
from data_designer.engine.models.registry import ModelRegistry
from data_designer.engine.resources.resource_provider import ResourceProvider
from data_designer.engine.storage.artifact_storage import ArtifactStorage
def test_configurable_task_generic_type_variables() -> None:
# DataT constraints are deferred to TYPE_CHECKING to avoid eagerly importing pandas.
# At runtime it is an unconstrained TypeVar; type checkers still see the constraints.
assert DataT.__constraints__ == ()
assert TaskConfigT.__bound__ == ConfigBase
def test_configurable_task_concrete_implementation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
class TestTask(ConfigurableTask[TestConfig]):
@classmethod
def get_config_type(cls) -> type[TestConfig]:
return TestConfig
def _validate(self) -> None:
pass
def _initialize(self) -> None:
pass
config = TestConfig(value="test")
artifact_storage = ArtifactStorage(artifact_path=tmp_path)
resource_provider = ResourceProvider(artifact_storage=artifact_storage)
task = TestTask(config=config, resource_provider=resource_provider)
assert task._config == config
assert task._resource_provider == resource_provider
def test_configurable_task_config_validation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
class TestTask(ConfigurableTask[TestConfig]):
@classmethod
def get_config_type(cls) -> type[TestConfig]:
return TestConfig
def _validate(self) -> None:
if self._config.value == "invalid":
raise ValueError("Invalid config")
config = TestConfig(value="test")
artifact_storage = ArtifactStorage(artifact_path=tmp_path)
resource_provider = ResourceProvider(artifact_storage=artifact_storage)
task = TestTask(config=config, resource_provider=resource_provider)
assert task._config.value == "test"
invalid_config = TestConfig(value="invalid")
with pytest.raises(ValueError, match="Invalid config"):
TestTask(config=invalid_config, resource_provider=resource_provider)
def test_configurable_task_resource_validation(tmp_path) -> None:
class TestConfig(ConfigBase):
value: str
class TestTask(ConfigurableTask[TestConfig]):
@classmethod
def get_config_type(cls) -> type[TestConfig]:
return TestConfig
def _validate(self) -> None:
pass
def _initialize(self) -> None:
pass
config = TestConfig(value="test")
artifact_storage = ArtifactStorage(artifact_path=tmp_path)
mock_model_registry = Mock(spec=ModelRegistry)
resource_provider = ResourceProvider(artifact_storage=artifact_storage, model_registry=mock_model_registry)
task = TestTask(config=config, resource_provider=resource_provider)
assert task._resource_provider == resource_provider
def test_configurable_task_resource_provider_is_none() -> None:
class TestConfig(ConfigBase):
value: str
class TestTask(ConfigurableTask[TestConfig]):
def _validate(self) -> None:
pass
def _initialize(self) -> None:
pass
config = TestConfig(value="test")
task = TestTask(config=config, resource_provider=None)
assert task._resource_provider is None