DataDesigner/tests/engine/dataset_builders/test_column_wise_builder.py
2025-10-27 14:29:12 -04:00

155 lines
6.6 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock
import pandas as pd
import pytest
from data_designer.config.columns import LLMTextColumnConfig, SamplerColumnConfig
from data_designer.engine.dataset_builders.column_wise_builder import (
MAX_CONCURRENCY_PER_NON_LLM_GENERATOR,
ColumnWiseDatasetBuilder,
)
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
@pytest.fixture
def stub_test_column_configs():
    """Provide a one-element list holding a single LLM text column config."""
    llm_text_config = LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model")
    return [llm_text_config]
@pytest.fixture
def stub_batch_manager():
    """Provide a Mock batch manager configured with 2 batches of 10 records.

    Each stubbed method is an explicit child ``Mock`` so tests can assert on its
    calls. ``get_current_batch`` returns a three-row DataFrame and
    ``get_current_batch_number`` returns 1.
    """
    current_batch = pd.DataFrame({"existing": [1, 2, 3]})
    # Mock's constructor kwargs are applied via configure_mock (plain setattr),
    # which is equivalent to assigning each attribute after construction.
    return Mock(
        num_batches=2,
        num_records_batch=10,
        finish=Mock(),
        write=Mock(),
        add_records=Mock(),
        update_records=Mock(),
        update_record=Mock(),
        get_current_batch=Mock(return_value=current_batch),
        get_current_batch_number=Mock(return_value=1),
    )
@pytest.fixture
def stub_column_wise_builder(stub_resource_provider, stub_test_column_configs):
    """Provide a ColumnWiseDatasetBuilder wired to the stub resource provider and column configs."""
    builder = ColumnWiseDatasetBuilder(
        column_configs=stub_test_column_configs,
        resource_provider=stub_resource_provider,
    )
    return builder
def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_column_configs):
    """Constructor stores its inputs and falls back to a DataDesignerRegistry when none is given."""
    constructor_kwargs = {
        "column_configs": stub_test_column_configs,
        "resource_provider": stub_resource_provider,
    }
    builder = ColumnWiseDatasetBuilder(**constructor_kwargs)

    assert builder._column_configs == constructor_kwargs["column_configs"]
    assert builder._resource_provider == constructor_kwargs["resource_provider"]
    # No registry argument was passed, so the builder must have created its own.
    assert isinstance(builder._registry, DataDesignerRegistry)
def test_column_wise_dataset_builder_creation_with_custom_registry(stub_resource_provider, stub_test_column_configs):
    """An explicitly supplied registry is stored instead of the default one."""
    injected_registry = Mock(spec=DataDesignerRegistry)

    builder = ColumnWiseDatasetBuilder(
        column_configs=stub_test_column_configs,
        resource_provider=stub_resource_provider,
        registry=injected_registry,
    )

    assert builder._registry == injected_registry
def test_column_wise_dataset_builder_artifact_storage_property(stub_column_wise_builder, stub_resource_provider):
    """The builder's artifact_storage property mirrors the resource provider's artifact storage."""
    expected_storage = stub_resource_provider.artifact_storage
    assert stub_column_wise_builder.artifact_storage == expected_storage
def test_column_wise_dataset_builder_records_to_drop_initialization(stub_column_wise_builder):
    """A freshly constructed builder has an empty set of records to drop."""
    records_to_drop = stub_column_wise_builder._records_to_drop
    assert records_to_drop == set()
def test_column_wise_dataset_builder_batch_manager_initialization(stub_column_wise_builder, stub_resource_provider):
    """The builder creates a batch manager that shares the provider's artifact storage."""
    manager = stub_column_wise_builder.batch_manager
    assert manager is not None
    assert manager.artifact_storage == stub_resource_provider.artifact_storage
@pytest.mark.parametrize(
    "config_type,expected_single_configs",
    [
        ("single", [LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model")]),
        (
            "multi",
            [SamplerColumnConfig(name="sampler_col", sampler_type="category", params={"values": ["A", "B", "C"]})],
        ),
    ],
)
def test_column_wise_dataset_builder_single_column_configs_property(
    stub_resource_provider, config_type, expected_single_configs
):
    """single_column_configs flattens the builder's configs into single-column configs.

    - "single": the configs are passed to the builder directly.
    - "multi": the same configs are wrapped in a SamplerMultiColumnConfig first.

    In both cases the property must return exactly ``expected_single_configs``.
    The parametrized data is the single source of truth: the original body rebuilt
    identical configs inline and ignored ``expected_single_configs``.
    """
    if config_type == "single":
        column_configs = list(expected_single_configs)
    else:
        # Copy the list so the shared parametrize value is never handed to the model.
        column_configs = [SamplerMultiColumnConfig(columns=list(expected_single_configs))]

    builder = ColumnWiseDatasetBuilder(column_configs=column_configs, resource_provider=stub_resource_provider)

    assert builder.single_column_configs == expected_single_configs
def test_column_wise_dataset_builder_build_method_basic_flow(
    stub_column_wise_builder,
    stub_batch_manager,
    stub_resource_provider,
):
    """Run build() against mocked collaborators.

    Checks that the model health check runs once, the batch manager is finished
    once, and build() returns the artifact storage's final dataset path.
    """
    registry = stub_resource_provider.model_registry
    registry.run_health_check = Mock()
    registry.get_model_usage_stats = Mock(return_value={"test": "stats"})

    # Give the model config a concrete max_parallel_requests value.
    # NOTE(review): presumably build() needs a real number here rather than an
    # auto-created Mock attribute.
    model_config = Mock()
    model_config.inference_parameters.max_parallel_requests = 4
    registry.get_model_config.return_value = model_config

    # iter_current_batch yields a single (index, record) pair; install the stub manager.
    stub_batch_manager.iter_current_batch.return_value = [(0, {"test": "data"})]
    stub_column_wise_builder.batch_manager = stub_batch_manager

    final_path = stub_column_wise_builder.build(num_records=100, buffer_size=50)

    registry.run_health_check.assert_called_once()
    stub_batch_manager.finish.assert_called_once()
    assert final_path == stub_resource_provider.artifact_storage.final_dataset_path
@pytest.mark.parametrize(
    "column_configs,use_mock_registry,expected_error",
    [
        ([], False, "No column configs provided"),
        (
            [LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model")],
            True,
            "The first column config must be a from-scratch column generator",
        ),
    ],
    ids=["empty_column_configs", "first_column_not_from_scratch"],
)
def test_column_wise_dataset_builder_validate_column_configs(
    stub_resource_provider, column_configs, use_mock_registry, expected_error
):
    """Constructor rejects invalid column config lists with DatasetGenerationError.

    Args:
        column_configs: The configs passed to the builder.
        use_mock_registry: When True, inject a registry whose generator class reports
            ``can_generate_from_scratch = False``. This forces the "first column must
            be from-scratch" validation to fail. When False, the builder's default
            registry is used.
        expected_error: Regex matched against the raised error message.
    """
    # Choose the setup with an explicit flag rather than by comparing the error-message
    # string, so rewording a message cannot silently switch the test's setup path.
    builder_kwargs = {}
    if use_mock_registry:
        non_from_scratch_generator = Mock()
        non_from_scratch_generator.can_generate_from_scratch = False
        mock_registry = Mock()
        mock_registry.column_generators.get_for_config_type.return_value = non_from_scratch_generator
        builder_kwargs["registry"] = mock_registry

    with pytest.raises(DatasetGenerationError, match=expected_error):
        ColumnWiseDatasetBuilder(
            column_configs=column_configs, resource_provider=stub_resource_provider, **builder_kwargs
        )
def test_constants_max_concurrency_constant():
    """Pin the per-generator concurrency cap used for non-LLM column generators."""
    expected_max_concurrency = 4
    assert MAX_CONCURRENCY_PER_NON_LLM_GENERATOR == expected_max_concurrency