DataDesigner/tests/engine/resources/test_managed_dataset_generator.py
2025-10-27 14:29:12 -04:00

120 lines
4 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
from data_designer.engine.resources.managed_dataset_repository import ManagedDatasetRepository
from data_designer.engine.resources.managed_storage import ManagedBlobStorage
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
from data_designer.engine.sampling_gen.errors import DatasetNotAvailableForLocaleError
@pytest.fixture
def stub_repository():
mock_repo = Mock(spec=ManagedDatasetRepository)
mock_repo.query.return_value = pd.DataFrame({"name": ["John", "Jane"], "age": [25, 30]})
return mock_repo
@pytest.fixture
def stub_blob_storage():
return Mock(spec=ManagedBlobStorage)
@pytest.mark.parametrize(
"dataset_name",
["en_US", "en_GB", "custom_dataset"],
)
def test_managed_dataset_generator_init(dataset_name, stub_repository):
generator = ManagedDatasetGenerator(stub_repository, dataset_name=dataset_name)
assert generator.managed_datasets == stub_repository
assert generator.dataset_name == dataset_name
@pytest.mark.parametrize(
"size,evidence,seed,expected_query_pattern",
[
(2, None, None, "select * from en_US order by random() limit 2"),
(
1,
{"name": "John"},
None,
"select * from en_US where name IN ('John') order by random() limit 1",
),
(
3,
{"name": ["John", "Jane"], "age": [25]},
None,
"select * from en_US where name IN ('John', 'Jane') and age IN ('25') order by random() limit 3",
),
(
1,
{"name": [], "age": None},
None,
"select * from en_US order by random() limit 1",
),
(1, None, 12345, "select * from en_US order by random() limit 1"),
(
None,
None,
None,
"select * from en_US order by random() limit 1",
),
],
)
def test_generate_samples_scenarios(size, evidence, seed, expected_query_pattern, stub_repository):
generator = ManagedDatasetGenerator(stub_repository, dataset_name="en_US")
if size is None:
result = generator.generate_samples(evidence=evidence, seed=seed)
else:
result = generator.generate_samples(size=size, evidence=evidence, seed=seed)
stub_repository.query.assert_called_once()
call_args = stub_repository.query.call_args[0][0]
assert expected_query_pattern in call_args
assert isinstance(result, pd.DataFrame)
def test_generate_samples_different_locale(stub_repository):
generator = ManagedDatasetGenerator(stub_repository, dataset_name="ja_JP")
result = generator.generate_samples(size=1)
expected_query = "select * from ja_JP order by random() limit 1"
stub_repository.query.assert_called_once_with(expected_query)
assert isinstance(result, pd.DataFrame)
@pytest.mark.parametrize(
"locale",
[
"en_US",
"ja_JP",
"en_IN",
],
)
@patch("data_designer.engine.sampling_gen.entities.person.load_managed_dataset_repository", autospec=True)
def test_load_person_data_sampler_scenarios(mock_load_repo, locale, stub_blob_storage):
mock_repo = Mock()
mock_load_repo.return_value = mock_repo
result = load_person_data_sampler(stub_blob_storage, locale=locale)
mock_load_repo.assert_called_once_with(stub_blob_storage)
assert isinstance(result, ManagedDatasetGenerator)
assert result.managed_datasets == mock_repo
assert result.dataset_name == locale
def test_load_person_data_sampler_invalid_locale(stub_blob_storage):
with pytest.raises(DatasetNotAvailableForLocaleError, match="Locale invalid_locale is not supported"):
load_person_data_sampler(stub_blob_storage, locale="invalid_locale")