mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
* update script * update headers * refactor a bit and add test script * update headers * update for edge case * update headers * add step to get file creation date * use git history to get copyright year * generation type is printed with inference parameters * fix unit test
33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from data_designer.config.validator_params import LocalCallableValidatorParams
|
|
from data_designer.engine.validators.local_callable import LocalCallableValidator
|
|
|
|
|
|
@pytest.fixture()
def stub_data() -> list[dict]:
    """Provide a single-record dataset for the validator tests.

    The lone record carries the ``text`` value that the callback in
    ``test_validate_with_callback_validator`` treats as valid, plus an ``id``.
    """
    record = {"text": "Sample text", "id": 1}
    return [record]
|
|
|
|
|
|
def test_validate_with_callback_validator(stub_data: list[dict]):
    """Run a LocalCallableValidator backed by a Python callback on the stub data.

    The callback looks only at the first row of the DataFrame it receives.
    It reports that row valid with confidence "0.98" when its ``text`` is
    "Sample text", and invalid with confidence "0.0" otherwise. The stub
    record matches, so exactly one valid result is expected.
    """

    def callback_fn(df: pd.DataFrame) -> pd.DataFrame:
        # Only the first row decides the verdict for the whole batch.
        first_text_matches = df.iloc[0]["text"] == "Sample text"
        verdict = (
            {"is_valid": True, "confidence": "0.98"}
            if first_text_matches
            else {"is_valid": False, "confidence": "0.0"}
        )
        return pd.DataFrame([verdict])

    params = LocalCallableValidatorParams(validation_function=callback_fn)
    validator = LocalCallableValidator(params)

    results = validator.run_validation(stub_data)

    assert len(results) == 1
    first = results[0]
    assert first.is_valid is True
    # The callback emits confidence as a string, so compare against a string.
    assert first.confidence == "0.98"
|