DataDesigner/tests/engine/sampling_gen/test_utils.py
Johnny Greco f8c201e085
chore: update header script to check for diffs (#195)
* update script

* update headers

* refactor a bit and add test script

* update headers

* update for edge case

* update headers

* add step to get file creation date

* use git history to get copyright year

* generation type is printed with inference parameters

* fix unit test
2026-01-09 17:10:58 -05:00

36 lines
1.4 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from data_designer.engine.sampling_gen.utils import check_random_state
@pytest.mark.parametrize(
    "test_case,input_value,expected_type,expected_seed",
    [
        ("none_input", None, "np.random.mtrand._rand", None),
        ("np_random_input", np.random, "np.random.mtrand._rand", None),
        ("integer_input", 42, "np.random.RandomState", 42),
        ("random_state_input", np.random.RandomState(123), "np.random.RandomState", 123),
    ],
)
def test_check_random_state_scenarios(test_case, input_value, expected_type, expected_seed):
    """Check that ``check_random_state`` resolves each supported seed kind correctly.

    - ``None`` and the ``np.random`` module resolve to numpy's global
      ``np.random.mtrand._rand`` object.
    - An integer yields a ``RandomState`` whose state was seeded with that integer.
    - An existing ``RandomState`` is returned unchanged (the very same object).

    Args:
        test_case: Scenario label; also selects the identity-only check for
            ``"random_state_input"``.
        input_value: Value passed to ``check_random_state``.
        expected_type: String tag naming the expected kind of result.
        expected_seed: Seed expected in the generator's state, or ``None`` to
            skip the seed check.
    """
    # Call the function under test exactly once; every branch inspects this result.
    result = check_random_state(input_value)

    if test_case == "random_state_input":
        # A caller-supplied generator must be passed through as-is (not copied or
        # reseeded), so identity is the only property checked here; the row's
        # expected_type/expected_seed are intentionally not consulted.
        assert result is input_value
        return

    if expected_type == "np.random.mtrand._rand":
        assert result is np.random.mtrand._rand
    elif expected_type == "np.random.RandomState":
        assert isinstance(result, np.random.RandomState)
        if expected_seed is not None:
            # With numpy's legacy integer seeding, the first word of the MT19937
            # key array equals the seed itself.
            assert result.get_state()[1][0] == expected_seed
    else:
        # Guard against typos in the parametrization silently passing the test.
        pytest.fail(f"unknown expected_type {expected_type!r} in parametrization")
def test_check_random_state_invalid():
    """Check that an unsupported seed value (a string) raises ``ValueError``.

    ``pytest.raises(match=...)`` applies the pattern with ``re.search``, so the
    literal dots in the expected message are escaped; otherwise each ``.``
    would match any character and the assertion would be looser than intended.
    """
    with pytest.raises(ValueError, match=r"'invalid' cannot be used to seed a numpy\.random\.RandomState instance"):
        check_random_state("invalid")