# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import Mock, patch

import pytest

from data_designer.config.base import SkipConfig
from data_designer.config.column_configs import (
    ExpressionColumnConfig,
    LLMCodeColumnConfig,
    LLMJudgeColumnConfig,
    LLMTextColumnConfig,
    SamplerColumnConfig,
    Score,
    SeedDatasetColumnConfig,
    ValidationColumnConfig,
)
from data_designer.config.models import ImageContext, ModalityDataType
from data_designer.config.processors import (
    DropColumnsProcessorConfig,
    SchemaTransformProcessorConfig,
)
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.validation import (
    Violation,
    ViolationLevel,
    ViolationType,
    rich_print_violations,
    validate_code_validation,
    validate_columns_not_all_dropped,
    validate_data_designer_config,
    validate_drop_columns_processor,
    validate_expression_references,
    validate_prompt_templates,
    validate_schema_transform_processor,
    validate_skip_references,
)

STUB_MODEL_ALIAS = "stub-alias"

# Shared fixture: a sampler column, a text column that references it, and a
# code column that references the text column. Tests must treat this list as
# read-only; copy it before adding columns.
VALID_COLUMNS = [
    SamplerColumnConfig(
        name="random_number",
        sampler_type="uniform",
        params={"low": 0, "high": 10},
    ),
    LLMTextColumnConfig(
        name="valid_reference",
        prompt="Why is {{ random_number }} your favorite number?",
        model_alias=STUB_MODEL_ALIAS,
    ),
    LLMCodeColumnConfig(
        name="code_column_python",
        prompt="Generate some python about {{ valid_reference }}.",
        code_lang="python",
        model_alias=STUB_MODEL_ALIAS,
    ),
]

# Shared fixture: columns the prompt-template and code-validation tests below
# expect to produce violations.
INVALID_COLUMNS = [
    LLMTextColumnConfig(
        name="text_no_references",
        prompt="Generate a name for the person",
        model_alias=STUB_MODEL_ALIAS,
    ),
    LLMTextColumnConfig(
        name="text_invalid_reference",
        prompt="Generate a name for the person: {{ this_column_does_not_exist }}",
        model_alias=STUB_MODEL_ALIAS,
    ),
    LLMJudgeColumnConfig(
        name="judge_no_references",
        prompt="Judge the name for the person.",
        scores=[Mock(spec=Score)],
        model_alias=STUB_MODEL_ALIAS,
    ),
    LLMJudgeColumnConfig(
        name="judge_invalid_reference",
        prompt="Judge the name for the person: {{ this_column_does_not_exist }}",
        scores=[Mock(spec=Score)],
        model_alias=STUB_MODEL_ALIAS,
    ),
    ValidationColumnConfig(
        name="code_validation_python",
        target_columns=["code_column_missing"],
        validator_type="code",
        validator_params=CodeValidatorParams(code_lang=CodeLang.SQL_ANSI),
    ),
    ValidationColumnConfig(
        name="code_validation_ansi",
        target_columns=["code_column_python"],
        validator_type="code",
        validator_params=CodeValidatorParams(code_lang=CodeLang.SQL_ANSI),
    ),
    ValidationColumnConfig(
        name="code_validation_not_code",
        target_columns=["text_no_references"],
        validator_type="code",
        validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON),
    ),
]

# Concatenation creates a new list, independent of VALID_COLUMNS and
# INVALID_COLUMNS.
COLUMNS = VALID_COLUMNS + INVALID_COLUMNS

# Processor configs that name a column / template reference not present in COLUMNS.
PROCESSOR_CONFIGS = [
    DropColumnsProcessorConfig(
        name="drop_columns_processor",
        column_names=["inexistent_column"],
    ),
    SchemaTransformProcessorConfig(
        name="schema_transform_processor_invalid_reference",
        template={"text": "{{ invalid_reference }}"},
    ),
]

ALLOWED_REFERENCE = [c.name for c in COLUMNS]


# NOTE: stacked @patch decorators are applied bottom-up, so the mock for the
# bottom-most decorator is the first positional argument.
@patch("data_designer.engine.validation.validate_prompt_templates")
@patch("data_designer.engine.validation.validate_code_validation")
@patch("data_designer.engine.validation.validate_expression_references")
@patch("data_designer.engine.validation.validate_skip_references")
@patch("data_designer.engine.validation.validate_columns_not_all_dropped")
@patch("data_designer.engine.validation.validate_drop_columns_processor")
@patch("data_designer.engine.validation.validate_schema_transform_processor")
def test_validate_data_designer_config(
    mock_validate_schema_transform_processor: Mock,
    mock_validate_drop_columns_processor: Mock,
    mock_validate_columns_not_all_dropped: Mock,
    mock_validate_skip_references: Mock,
    mock_validate_expression_references: Mock,
    mock_validate_code_validation: Mock,
    mock_validate_prompt_templates: Mock,
) -> None:
    """Check the top-level validator with every sub-validator patched.

    Each of the seven patched validators returns exactly one stub violation;
    the test asserts that seven violations come back in total and that every
    patched validator was called exactly once.
    """
    mock_validate_columns_not_all_dropped.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.ALL_COLUMNS_DROPPED,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_expression_references.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.EXPRESSION_REFERENCE_MISSING,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_code_validation.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.CODE_COLUMN_MISSING,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_prompt_templates.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.PROMPT_WITHOUT_REFERENCES,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_drop_columns_processor.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.INVALID_COLUMN,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_schema_transform_processor.return_value = [
        Violation(
            column="text",
            type=ViolationType.INVALID_REFERENCE,
            message="Ancillary dataset processor attempts to reference columns 'invalid_reference' in the template for 'text', but the columns are not defined in the dataset.",
            level=ViolationLevel.ERROR,
        )
    ]
    mock_validate_skip_references.return_value = [
        Violation(
            column="test_column",
            type=ViolationType.SKIP_REFERENCE_MISSING,
            message="test error message",
            level=ViolationLevel.ERROR,
        )
    ]

    violations = validate_data_designer_config(COLUMNS, PROCESSOR_CONFIGS, ALLOWED_REFERENCE)

    assert len(violations) == 7
    mock_validate_columns_not_all_dropped.assert_called_once()
    mock_validate_expression_references.assert_called_once()
    mock_validate_skip_references.assert_called_once()
    mock_validate_code_validation.assert_called_once()
    mock_validate_prompt_templates.assert_called_once()
    mock_validate_drop_columns_processor.assert_called_once()
    mock_validate_schema_transform_processor.assert_called_once()


def test_validate_prompt_templates() -> None:
    """Expect four prompt violations over COLUMNS, alternating no-reference / invalid-reference."""
    violations = validate_prompt_templates(COLUMNS, ALLOWED_REFERENCE)

    assert len(violations) == 4
    assert violations[0].type == ViolationType.PROMPT_WITHOUT_REFERENCES
    assert violations[1].type == ViolationType.INVALID_REFERENCE
    assert violations[2].type == ViolationType.PROMPT_WITHOUT_REFERENCES
    assert violations[3].type == ViolationType.INVALID_REFERENCE


def test_validate_code_validation() -> None:
    """Expect missing-target, language-mismatch and not-a-code-column violations, in that order."""
    violations = validate_code_validation(COLUMNS)

    assert len(violations) == 3
    assert violations[0].type == ViolationType.CODE_COLUMN_MISSING
    assert violations[1].type == ViolationType.CODE_LANG_MISMATCH
    assert violations[2].type == ViolationType.CODE_COLUMN_NOT_CODE


def test_validate_detect_f_string_syntax() -> None:
    """Expect a single F_STRING_SYNTAX warning for a prompt containing `{random_number}`."""
    # Build a fresh list rather than appending to VALID_COLUMNS: mutating the
    # module-level fixture would leak the extra column into later uses and
    # make this test fail if it runs twice in the same process.
    columns = [
        *VALID_COLUMNS,
        LLMTextColumnConfig(
            name="f_string_ref",
            prompt="Why is {random_number} your favorite number? {{ valid_reference }}",
            model_alias=STUB_MODEL_ALIAS,
        ),
    ]

    violations = validate_prompt_templates(columns, [c.name for c in columns])

    assert len(violations) == 1
    assert violations[0].type == ViolationType.F_STRING_SYNTAX
    assert violations[0].level == ViolationLevel.WARNING


def test_validate_column_config_with_multi_modal_context() -> None:
    """Expect no violations for a reference-free prompt when multi-modal context is set."""
    column = LLMTextColumnConfig(
        name="image_description",
        prompt="Describe the image in no less that 10 sentences.",
        model_alias=STUB_MODEL_ALIAS,
        multi_modal_context=[ImageContext(column_name="image_url", data_type=ModalityDataType.URL)],
    )

    violations = validate_prompt_templates([column], [column.name])

    # There should be no violations: the prompt does not reference any columns,
    # and a reference is not necessary when multi-modal context is provided.
    assert len(violations) == 0


def test_validate_columns_not_all_dropped() -> None:
    """Expect ALL_COLUMNS_DROPPED when every column is configured with drop=True."""
    violations = validate_columns_not_all_dropped(
        [
            SamplerColumnConfig(
                name="random_number",
                sampler_type="uniform",
                params={"low": 0, "high": 10},
                drop=True,
            ),
            LLMTextColumnConfig(
                name="valid_reference",
                prompt="Why is {{ random_number }} your favorite number?",
                model_alias=STUB_MODEL_ALIAS,
                drop=True,
            ),
        ]
    )

    assert len(violations) == 1
    assert violations[0].type == ViolationType.ALL_COLUMNS_DROPPED


def test_validate_columns_not_all_dropped_allows_seeded_processor_only_config() -> None:
    """Expect no violation for a seed column combined with a schema-transform processor."""
    violations = validate_columns_not_all_dropped(
        [SeedDatasetColumnConfig(name="seed_text")],
        processor_configs=[
            SchemaTransformProcessorConfig(name="format", template={"text": "{{ seed_text }}"}),
        ],
    )

    assert violations == []


def test_validate_columns_not_all_dropped_rejects_seeded_processor_only_config_with_no_output_columns() -> None:
    """Expect ALL_COLUMNS_DROPPED when the only processor drops the only seed column."""
    violations = validate_columns_not_all_dropped(
        [SeedDatasetColumnConfig(name="seed_text")],
        processor_configs=[
            DropColumnsProcessorConfig(name="drop_seed", column_names=["seed_text"]),
        ],
    )

    assert len(violations) == 1
    assert violations[0].type == ViolationType.ALL_COLUMNS_DROPPED


def test_validate_columns_not_all_dropped_allows_generated_columns_dropped_by_processors() -> None:
    """Expect no violation when dropped columns are accompanied by a schema-transform processor."""
    violations = validate_columns_not_all_dropped(
        [
            LLMTextColumnConfig(name="question", prompt="Generate a question.", model_alias=STUB_MODEL_ALIAS),
            LLMTextColumnConfig(name="answer", prompt="Answer {{ question }}.", model_alias=STUB_MODEL_ALIAS),
        ],
        processor_configs=[
            DropColumnsProcessorConfig(name="drop_raw", column_names=["question", "answer"]),
            SchemaTransformProcessorConfig(name="format", template={"messages": "{{ question }} {{ answer }}"}),
        ],
    )

    assert violations == []


def test_validate_columns_not_all_dropped_still_rejects_seed_only_config() -> None:
    """Expect ALL_COLUMNS_DROPPED for a lone seed column with no processors."""
    violations = validate_columns_not_all_dropped([SeedDatasetColumnConfig(name="seed_text")])

    assert len(violations) == 1
    assert violations[0].type == ViolationType.ALL_COLUMNS_DROPPED


def test_validate_expression_references() -> None:
    """Expect EXPRESSION_REFERENCE_MISSING when the expression names a column not in allowed_references."""
    violations = validate_expression_references(
        [
            ExpressionColumnConfig(
                name="expression_column",
                expr="{{ random_number }}",
                dtype="int",
            ),
        ],
        allowed_references=["some_other_column"],
    )

    assert len(violations) == 1
    assert violations[0].type == ViolationType.EXPRESSION_REFERENCE_MISSING


def test_validate_schema_transform_processor() -> None:
    """Expect one INVALID_REFERENCE error, with no column set, for the unknown template reference."""
    violations = validate_schema_transform_processor(COLUMNS, PROCESSOR_CONFIGS)

    assert len(violations) == 1
    assert violations[0].type == ViolationType.INVALID_REFERENCE
    assert violations[0].column is None
    assert (
        violations[0].message
        == "Ancillary dataset processor attempts to reference columns 'invalid_reference' in the template for 'text', but the columns are not defined in the dataset."
    )
    assert violations[0].level == ViolationLevel.ERROR


@pytest.mark.parametrize(
    "extract_reasoning, expected_violations",
    [
        (True, 0),
        (False, 1),
    ],
)
def test_validate_drop_columns_processor_reasoning_column(extract_reasoning, expected_violations) -> None:
    """Dropping `answer__reasoning_content` yields no violation only when extract_reasoning_content is True."""
    columns = [
        LLMTextColumnConfig(
            name="answer",
            prompt="Answer the question.",
            model_alias=STUB_MODEL_ALIAS,
            extract_reasoning_content=extract_reasoning,
        ),
    ]
    processor_configs = [
        DropColumnsProcessorConfig(
            name="drop_reasoning",
            column_names=["answer__reasoning_content"],
        ),
    ]

    violations = validate_drop_columns_processor(columns, processor_configs)

    assert len(violations) == expected_violations


@pytest.mark.parametrize(
    "pattern, expected_violations, expected_level",
    [
        ("*__reasoning_content", 0, None),
        ("zzz_*", 1, ViolationLevel.WARNING),
    ],
)
def test_validate_drop_columns_processor_glob(pattern, expected_violations, expected_level) -> None:
    """Check drop-column glob patterns: the first case expects no violation, the second one WARNING."""
    columns = [
        LLMTextColumnConfig(
            name="answer",
            prompt="Answer the question.",
            model_alias=STUB_MODEL_ALIAS,
            extract_reasoning_content=True,
        ),
    ]
    processor_configs = [
        DropColumnsProcessorConfig(name="drop_glob", column_names=[pattern]),
    ]

    violations = validate_drop_columns_processor(columns, processor_configs)

    assert len(violations) == expected_violations
    # Only inspect the level for cases that expect a violation.
    if expected_level is not None:
        assert violations[0].level == expected_level


@patch("data_designer.engine.validation.Console.print")
def test_rich_print_violations(mock_console_print: Mock) -> None:
    """Console.print is not called for an empty list and is called once for a single violation."""
    rich_print_violations([])
    mock_console_print.assert_not_called()

    rich_print_violations(
        [
            Violation(
                column="test_column",
                type=ViolationType.EXPRESSION_REFERENCE_MISSING,
                message="test error message",
                level=ViolationLevel.ERROR,
            )
        ]
    )
    mock_console_print.assert_called_once()


def test_validate_skip_references_missing_column() -> None:
    """Expect SKIP_REFERENCE_MISSING when the skip condition names a column not in allowed_references."""
    columns = [
        LLMTextColumnConfig(
            name="with_skip",
            prompt="test {{ real_col }}",
            model_alias=STUB_MODEL_ALIAS,
            skip=SkipConfig(when="{{ ghost }}"),
        ),
    ]

    violations = validate_skip_references(columns, allowed_references=["real_col"])

    assert len(violations) == 1
    assert violations[0].type == ViolationType.SKIP_REFERENCE_MISSING
    assert violations[0].column == "with_skip"


def test_validate_skip_references_valid() -> None:
    """Expect no violations when the skip condition only references an allowed column."""
    columns = [
        LLMTextColumnConfig(
            name="with_skip",
            prompt="test {{ gate }}",
            model_alias=STUB_MODEL_ALIAS,
            skip=SkipConfig(when="{{ gate == 0 }}"),
        ),
    ]

    violations = validate_skip_references(columns, allowed_references=["gate", "with_skip"])

    assert len(violations) == 0


def test_validate_skip_on_sampler_seed() -> None:
    """Expect SKIP_ON_SAMPLER_SEED for a sampler column that carries a skip config."""
    # The column is built with model_construct rather than the normal constructor.
    # NOTE(review): presumably because regular validation would reject `skip`
    # on a sampler column - confirm against SamplerColumnConfig.
    col = SamplerColumnConfig.model_construct(
        name="sampler_with_skip",
        column_type="sampler",
        sampler_type="uniform",
        params={"low": 0, "high": 10},
        skip=SkipConfig(when="{{ y }}"),
        drop=False,
        allow_resize=False,
        propagate_skip=True,
    )

    violations = validate_skip_references([col], allowed_references=["y"])

    assert len(violations) == 1
    assert violations[0].type == ViolationType.SKIP_ON_SAMPLER_SEED
    assert violations[0].column == "sampler_with_skip"


def test_validate_skip_with_allow_resize() -> None:
    """Expect SKIP_WITH_ALLOW_RESIZE for a column that sets both skip and allow_resize=True."""
    # The column is built with model_construct rather than the normal constructor.
    # NOTE(review): presumably because regular validation would reject this
    # combination - confirm against LLMTextColumnConfig.
    col = LLMTextColumnConfig.model_construct(
        name="with_skip",
        column_type="llm-text",
        prompt="test {{ gate }}",
        model_alias=STUB_MODEL_ALIAS,
        skip=SkipConfig(when="{{ gate == 0 }}"),
        allow_resize=True,
        drop=False,
        propagate_skip=True,
    )

    violations = validate_skip_references([col], allowed_references=["gate"])

    assert len(violations) == 1
    assert violations[0].type == ViolationType.SKIP_WITH_ALLOW_RESIZE
    assert violations[0].column == "with_skip"