mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
243 lines
10 KiB
Python
243 lines
10 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from collections import Counter
|
|
import json
|
|
import tempfile
|
|
|
|
from pydantic import ValidationError
|
|
import pytest
|
|
import yaml
|
|
|
|
from data_designer.config.errors import InvalidConfigError
|
|
from data_designer.config.models import (
|
|
ImageContext,
|
|
ImageFormat,
|
|
InferenceParameters,
|
|
ManualDistribution,
|
|
ManualDistributionParams,
|
|
ModalityDataType,
|
|
ModelConfig,
|
|
UniformDistribution,
|
|
UniformDistributionParams,
|
|
load_model_configs,
|
|
)
|
|
|
|
|
|
def test_image_context_get_context():
|
|
image_context = ImageContext(
|
|
column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG
|
|
)
|
|
assert image_context.get_context({"image_base64": "somebase64encodedimagestring"}) == {
|
|
"type": "image_url",
|
|
"image_url": {"url": "data:image/png;base64,somebase64encodedimagestring", "format": "png"},
|
|
}
|
|
|
|
image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL)
|
|
assert image_context.get_context({"image_url": "https://example.com/examle_image.png"}) == {
|
|
"type": "image_url",
|
|
"image_url": "https://example.com/examle_image.png",
|
|
}
|
|
|
|
|
|
def test_image_context_validate_image_format():
|
|
with pytest.raises(ValueError, match="image_format is required when data_type is base64"):
|
|
ImageContext(column_name="image_base64", data_type=ModalityDataType.BASE64)
|
|
|
|
|
|
def test_inference_parameters_generate_kwargs():
|
|
assert InferenceParameters(
|
|
temperature=0.95,
|
|
top_p=0.95,
|
|
max_tokens=100,
|
|
max_parallel_requests=40,
|
|
timeout=10,
|
|
extra_body={"reasoning_effort": "high"},
|
|
).generate_kwargs == {
|
|
"temperature": 0.95,
|
|
"top_p": 0.95,
|
|
"max_tokens": 100,
|
|
"timeout": 10,
|
|
"extra_body": {"reasoning_effort": "high"},
|
|
}
|
|
|
|
assert InferenceParameters().generate_kwargs == {}
|
|
|
|
inference_parameters_kwargs = InferenceParameters(
|
|
temperature=UniformDistribution(params=UniformDistributionParams(low=0.0, high=1.0)),
|
|
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.0, 1.0], weights=[0.5, 0.5])),
|
|
).generate_kwargs
|
|
assert inference_parameters_kwargs["temperature"] is not None
|
|
assert inference_parameters_kwargs["top_p"] is not None
|
|
|
|
|
|
def test_uniform_distribution_low_lt_high_validation():
|
|
with pytest.raises(ValueError, match="`low` must be less than `high`"):
|
|
UniformDistribution(params=UniformDistributionParams(low=0.8, high=0.8))
|
|
|
|
|
|
def test_uniform_distribution_sampling():
|
|
dist = UniformDistribution(params=UniformDistributionParams(low=0.2, high=0.8))
|
|
samples = [dist.sample() for _ in range(1000)]
|
|
assert all(0.2 <= s <= 0.8 for s in samples)
|
|
|
|
|
|
def test_manual_distribution_sampling():
|
|
values = [0.1, 0.5, 0.9]
|
|
weights = [0.2, 0.5, 0.3]
|
|
dist = ManualDistribution(params=ManualDistributionParams(values=values, weights=weights))
|
|
samples = [dist.sample() for _ in range(10000)]
|
|
assert set(samples) == set(values)
|
|
|
|
sample_counts = Counter(samples)
|
|
total_samples = sum(sample_counts.values())
|
|
for value, weight in zip(values, weights, strict=False):
|
|
observed_freq = sample_counts[value] / total_samples
|
|
expected_freq = weight
|
|
# Allow small margin for randomness
|
|
assert abs(observed_freq - expected_freq) < 0.05
|
|
|
|
|
|
def test_manual_distribution_equal_length_validation():
|
|
with pytest.raises(ValueError, match="`values` and `weights` must have the same length"):
|
|
ManualDistribution(params=ManualDistributionParams(values=[0.1, 0.2], weights=[0.5]))
|
|
|
|
|
|
def test_manual_distribution_weight_normalization():
|
|
dist = ManualDistribution(params=ManualDistributionParams(values=[0.2, 0.8], weights=[2, 2]))
|
|
assert dist.params.weights == [0.5, 0.5]
|
|
|
|
|
|
def test_invalid_manual_distribution():
|
|
# Empty list should fail
|
|
with pytest.raises(ValidationError):
|
|
ManualDistribution(params=ManualDistributionParams(values=[], weights=[]))
|
|
|
|
|
|
def test_manual_distribution_sampling_without_weights():
|
|
dist = ManualDistribution(params=ManualDistributionParams(values=[0.1, 0.4, 0.7]))
|
|
samples = [dist.sample() for _ in range(1000)]
|
|
assert all(0.0 <= s <= 1.0 for s in samples)
|
|
|
|
|
|
def test_inference_parameters_temperature_validation():
|
|
expected_error_msg = "temperature defined in model config must be between 0.0 and 2.0"
|
|
|
|
# All temp values provide in a manual destribution should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(
|
|
temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.5], weights=[0.5, 0.5]))
|
|
)
|
|
|
|
# High and low values of uniform distribution should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5)))
|
|
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0)))
|
|
|
|
# Static values should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(temperature=3.0)
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(temperature=-1.0)
|
|
|
|
# Valid temperature values shouldn't raise validation errors
|
|
try:
|
|
InferenceParameters(temperature=0.1)
|
|
InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)))
|
|
InferenceParameters(
|
|
temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.0], weights=[0.5, 0.5]))
|
|
)
|
|
except Exception:
|
|
pytest.fail("Unexpected exception raised during InferenceParameters temperature validation")
|
|
|
|
|
|
def test_generation_parameters_top_p_validation():
|
|
expected_error_msg = "top_p defined in model config must be between 0.0 and 1.0"
|
|
|
|
# All top_p values provide in a manual destribution should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(
|
|
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.5], weights=[0.5, 0.5]))
|
|
)
|
|
|
|
# High and low values of uniform distribution should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5)))
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0)))
|
|
|
|
# Static values should be valid
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(top_p=1.5)
|
|
with pytest.raises(ValidationError, match=expected_error_msg):
|
|
InferenceParameters(top_p=-0.1)
|
|
|
|
# Valid top_p values shouldn't raise validation errors
|
|
try:
|
|
InferenceParameters(top_p=0.1)
|
|
InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0)))
|
|
InferenceParameters(
|
|
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5]))
|
|
)
|
|
except Exception:
|
|
pytest.fail("Unexpected exception raised during InferenceParameters top_p validation")
|
|
|
|
|
|
def test_generation_parameters_max_tokens_validation():
|
|
with pytest.raises(
|
|
ValidationError,
|
|
match="Input should be greater than or equal to 1",
|
|
):
|
|
InferenceParameters(max_tokens=0)
|
|
|
|
# Valid max_tokens values shouldn't raise validation errors
|
|
try:
|
|
InferenceParameters(max_tokens=128_000)
|
|
InferenceParameters(max_tokens=4096)
|
|
InferenceParameters(max_tokens=1)
|
|
except Exception:
|
|
pytest.fail("Unexpected exception raised during InferenceParameters max_tokens validation")
|
|
|
|
|
|
def test_load_model_configs():
|
|
stub_model_configs = [
|
|
ModelConfig(alias="test", model="test", inference_parameters=InferenceParameters()),
|
|
ModelConfig(alias="test2", model="test2", inference_parameters=InferenceParameters()),
|
|
]
|
|
stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs]
|
|
assert load_model_configs([]) == []
|
|
assert load_model_configs(None) == []
|
|
assert load_model_configs(stub_model_configs) == stub_model_configs
|
|
|
|
with tempfile.NamedTemporaryFile(prefix="model_configs", suffix=".yaml") as tmp_file:
|
|
model_configs = {"model_configs": stub_model_configs_dict_list}
|
|
tmp_file.write(yaml.safe_dump(model_configs).encode("utf-8"))
|
|
tmp_file.flush()
|
|
assert load_model_configs(tmp_file.name) == stub_model_configs
|
|
|
|
with tempfile.NamedTemporaryFile(prefix="model_configs", suffix=".json") as tmp_file:
|
|
model_configs = {"model_configs": stub_model_configs_dict_list}
|
|
tmp_file.write(json.dumps(model_configs).encode("utf-8"))
|
|
tmp_file.flush()
|
|
assert load_model_configs(tmp_file.name) == stub_model_configs
|
|
|
|
with pytest.raises(
|
|
InvalidConfigError,
|
|
match="The list of model configs must be provided under model_configs in the configuration file.",
|
|
):
|
|
# test failure when the list is not grouped under model_configs in the yaml file
|
|
with tempfile.NamedTemporaryFile(prefix="model_configs", suffix=".yaml") as tmp_file:
|
|
model_configs = {"some_other_key": "invalid_config"}
|
|
tmp_file.write(yaml.safe_dump(model_configs).encode("utf-8"))
|
|
tmp_file.flush()
|
|
load_model_configs(tmp_file.name)
|
|
|
|
with pytest.raises(ValidationError):
|
|
# test failure when model config validation fails (tests line 220)
|
|
with tempfile.NamedTemporaryFile(prefix="model_configs", suffix=".json") as tmp_file:
|
|
invalid_model_configs = {"model_configs": [{"invalid": "config"}]}
|
|
tmp_file.write(json.dumps(invalid_model_configs).encode("utf-8"))
|
|
tmp_file.flush()
|
|
load_model_configs(tmp_file.name)
|