mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
* support loading config files from http urls - allow config builder and CLI loader to load YAML/JSON configs from HTTP(S) URLs - reject unsupported URL extensions and remote Python module URLs - update CLI help text and add tests for URL success/failure paths * harden remote config loading and deduplicate URL validation - Add size limit (10 MB) when fetching configs from URLs - Validate parsed YAML is a dict before returning - Make is_http_url public and reuse it in CLI validate_url - Replace local CONFIG_FILE_EXTENSIONS with shared constant - Add tests for is_http_url, URL-with-no-extension edge cases * use requests for remote config loading - replace urllib URL fetching with requests and status checks - parse remote payloads via smart_load_yaml for consistent validation - expand tests for HTTP errors, size limits, and non-dict payloads * lower remote config size limit to 1 MB * improve config URL HTTP error reporting Add granular 401/403/404 and generic HTTP status errors for remote config fetching to make failures actionable. Clarify that authenticated config URL loading is not currently supported and update tests for status-aware behavior. * rewrite github blob URLs for remote loading Handle GitHub blob links by rewriting them to raw content URLs for config and dataframe HTTP loaders, preserving query params but avoiding query token leaks in logs. This also fixes extension detection for URLs with query strings and adds coverage for rewrite behavior. * remove validate_url wrapper in favor of is_http_url The validate_url function in cli/utils was just a thin wrapper around is_http_url from io_helpers. Remove it and have callers use is_http_url directly for clarity and reduced indirection. * fix optional type for artifact_path CLI option * fix URL recursion in smart_load_yaml - avoid treating remote payload strings as new URL inputs - add regression test for URL string payloads from remote config * rewrite huggingface blob URLs for remote loading
188 lines
5.6 KiB
Python
188 lines
5.6 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import subprocess
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from data_designer.cli.utils import (
|
|
check_ngc_cli_available,
|
|
get_ngc_version,
|
|
validate_numeric_range,
|
|
)
|
|
from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError
|
|
from data_designer.config.utils.io_helpers import (
|
|
ensure_config_dir_exists,
|
|
is_http_url,
|
|
load_config_file,
|
|
save_config_file,
|
|
)
|
|
|
|
|
|
def test_ensure_config_dir_exists(tmp_path: Path) -> None:
    """Check that ensure_config_dir_exists turns a missing path into a directory."""
    target_dir = tmp_path / "test_config"
    # Precondition: pytest's tmp_path starts out without this child.
    assert not target_dir.exists()

    ensure_config_dir_exists(target_dir)

    assert target_dir.exists()
    assert target_dir.is_dir()
|
|
|
|
|
|
def test_save_and_load_config_file(tmp_path: Path) -> None:
    """Round-trip a config mapping through save_config_file and load_config_file."""
    target = tmp_path / "test_config.yaml"
    expected = {
        "providers": [{"name": "test", "endpoint": "https://test.com", "provider_type": "openai"}],
        "default": "test",
    }

    # Saving must leave a file on disk.
    save_config_file(target, expected)
    assert target.exists()

    # Loading it back must compare equal to what was saved.
    round_tripped = load_config_file(target)
    assert round_tripped == expected
|
|
|
|
|
|
def test_load_config_file_not_found(tmp_path: Path) -> None:
    """Expect InvalidFilePathError when the requested file was never created."""
    missing_path = tmp_path / "nonexistent.yaml"

    with pytest.raises(InvalidFilePathError):
        load_config_file(missing_path)
|
|
|
|
|
|
def test_load_config_file_invalid_yaml(tmp_path: Path) -> None:
    """Expect InvalidFileFormatError when the file content is malformed YAML."""
    broken_file = tmp_path / "invalid.yaml"
    malformed_text = "invalid: yaml: content:\n - unmatched"
    broken_file.write_text(malformed_text)

    with pytest.raises(InvalidFileFormatError):
        load_config_file(broken_file)
|
|
|
|
|
|
def test_load_config_file_empty(tmp_path: Path) -> None:
    """Expect InvalidConfigError when the file exists but holds no content."""
    blank_file = tmp_path / "empty.yaml"
    blank_file.write_text("")

    with pytest.raises(InvalidConfigError):
        load_config_file(blank_file)
|
|
|
|
|
|
def test_is_http_url() -> None:
    """Check which strings is_http_url accepts and which it rejects."""
    # Strings the helper is expected to accept.
    accepted = (
        "https://api.example.com/v1",
        "http://localhost:8000",
        "https://nvidia.com",
    )
    # Empty input, a bare word, a non-HTTP scheme, and scheme-only strings.
    rejected = (
        "",
        "not-a-url",
        "ftp://example.com",
        "https://",
        "http://",
    )

    for candidate in accepted:
        assert is_http_url(candidate)

    for candidate in rejected:
        assert not is_http_url(candidate)
|
|
|
|
|
|
def test_validate_numeric_range() -> None:
    """Check validate_numeric_range against the bounds 0.0 and 1.0."""
    lower, upper = 0.0, 1.0

    # Inputs inside the range (both endpoints included) paired with the
    # value the helper is expected to hand back.
    accepted_cases = (
        ("0.5", 0.5),
        ("0", 0.0),
        ("1.0", 1.0),
    )
    for raw, expected in accepted_cases:
        ok, parsed = validate_numeric_range(raw, lower, upper)
        assert ok
        assert parsed == expected

    # Below the range, above the range, and not a number at all.
    rejected_inputs = ("-0.1", "1.1", "abc")
    for raw in rejected_inputs:
        ok, parsed = validate_numeric_range(raw, lower, upper)
        assert not ok
        assert parsed is None
|
|
|
|
|
|
@patch("data_designer.cli.utils.shutil.which")
@patch("data_designer.cli.utils.get_ngc_version")
def test_check_ngc_cli_available_returns_true(mock_get_version: MagicMock, mock_which: MagicMock) -> None:
    """Expect True when `which` finds an ngc binary and a version string is reported."""
    # Stub the version lookup and the PATH lookup with successful results.
    mock_get_version.return_value = "NGC CLI 3.41.4"
    mock_which.return_value = "/usr/local/bin/ngc"

    available = check_ngc_cli_available()

    assert available is True
    mock_which.assert_called_once_with("ngc")
|
|
|
|
|
|
@patch("data_designer.cli.utils.shutil.which")
def test_check_ngc_cli_available_returns_false(mock_which: MagicMock) -> None:
    """Expect False when `which` cannot locate an ngc binary."""
    # A None result from `which` stands in for "ngc is not on PATH".
    mock_which.return_value = None

    available = check_ngc_cli_available()

    assert available is False
    mock_which.assert_called_once_with("ngc")
|
|
|
|
|
|
@patch("data_designer.cli.utils.subprocess.run")
def test_get_ngc_version_success(mock_run: MagicMock) -> None:
    """Expect the command's stdout without its trailing newline, and the exact run() call."""
    # Stand-in for the completed process; only stdout is configured.
    mock_run.return_value = MagicMock(stdout="NGC CLI 3.41.4\n")

    reported = get_ngc_version()

    assert reported == "NGC CLI 3.41.4"
    expected_options = {
        "capture_output": True,
        "text": True,
        "check": True,
        "timeout": 5,
    }
    mock_run.assert_called_once_with(["ngc", "--version"], **expected_options)
|
|
|
|
|
|
@patch("data_designer.cli.utils.subprocess.run")
def test_get_ngc_version_handles_error(mock_run: MagicMock) -> None:
    """Expect None when subprocess.run raises CalledProcessError."""
    failure = subprocess.CalledProcessError(1, "ngc")
    mock_run.side_effect = failure

    reported = get_ngc_version()

    assert reported is None
|
|
|
|
|
|
@patch("data_designer.cli.utils.subprocess.run")
def test_get_ngc_version_handles_timeout(mock_run: MagicMock) -> None:
    """Expect None when subprocess.run raises TimeoutExpired."""
    timed_out = subprocess.TimeoutExpired("ngc", 5)
    mock_run.side_effect = timed_out

    reported = get_ngc_version()

    assert reported is None
|
|
|
|
|
|
@patch("data_designer.cli.utils.subprocess.run")
def test_get_ngc_version_handles_file_not_found(mock_run: MagicMock) -> None:
    """Expect None when subprocess.run raises FileNotFoundError."""
    missing_binary = FileNotFoundError()
    mock_run.side_effect = missing_binary

    reported = get_ngc_version()

    assert reported is None
|