mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from unittest.mock import patch
|
|
|
|
from jinja2 import TemplateSyntaxError
|
|
import pytest
|
|
|
|
from data_designer.config.utils.errors import UserJinjaTemplateSyntaxError
|
|
from data_designer.config.utils.misc import (
|
|
assert_valid_jinja2_template,
|
|
get_prompt_template_keywords,
|
|
json_indent_list_of_strings,
|
|
kebab_to_snake,
|
|
template_error_handler,
|
|
)
|
|
|
|
|
|
def test_kebab_to_snake():
|
|
assert kebab_to_snake("kebab-case-string") == "kebab_case_string"
|
|
assert kebab_to_snake("simple") == "simple"
|
|
|
|
|
|
def test_template_error_handler():
|
|
val_set = False
|
|
with template_error_handler():
|
|
val_set = True
|
|
# shouldn't raise any errors
|
|
assert val_set
|
|
|
|
|
|
def test_template_error_handler_invalid_template():
|
|
with pytest.raises(UserJinjaTemplateSyntaxError):
|
|
with template_error_handler():
|
|
raise TemplateSyntaxError("Invalid template", 1)
|
|
|
|
|
|
def test_template_error_handler_catch_non_template_error():
|
|
with pytest.raises(ValueError):
|
|
with template_error_handler():
|
|
raise ValueError("Invalid value")
|
|
|
|
|
|
def test_assert_valid_jinja2_template():
|
|
assert_valid_jinja2_template("{% if name %}Hello, {{ name }}!{% endif %}")
|
|
assert_valid_jinja2_template("{% for item in items %}{{ item }}{% endfor %}")
|
|
|
|
with pytest.raises(UserJinjaTemplateSyntaxError):
|
|
assert_valid_jinja2_template("{{ name }")
|
|
|
|
with patch("data_designer.config.utils.misc.ImmutableSandboxedEnvironment.parse") as mock_sandbox_env_parse:
|
|
mock_sandbox_env_parse.side_effect = ValueError("Invalid Value")
|
|
with pytest.raises(ValueError):
|
|
assert_valid_jinja2_template("{hello}")
|
|
|
|
|
|
def test_get_prompt_template_keywords():
|
|
assert get_prompt_template_keywords("{{ first_name }} {{last_name}}") == {"first_name", "last_name"}
|
|
assert get_prompt_template_keywords("{% if first_name %}Hello, {{ last_name }}!{% endif %}") == {
|
|
"first_name",
|
|
"last_name",
|
|
}
|
|
|
|
with pytest.raises(UserJinjaTemplateSyntaxError):
|
|
get_prompt_template_keywords("{{ name }")
|
|
|
|
|
|
def test_json_indent_list_of_strings():
|
|
assert json_indent_list_of_strings([]) is None
|
|
assert json_indent_list_of_strings(["hello", "world"]) == ["hello", "world"]
|
|
assert json_indent_list_of_strings(["hello", "world"], indent=2) == ["hello", "world"]
|
|
assert (
|
|
json_indent_list_of_strings(["hello", "world", "foo", "bar"], indent=2)
|
|
== '[\n "hello",\n "world",\n "foo",\n "bar"\n]'
|
|
)
|