mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
* update script * update headers * refactor a bit and add test script * update headers * update for edge case * update headers * add step to get file creation date * use git history to get copyright year * generation type is printed with inference parameters * fix unit test
51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from data_designer.config.seed_source import DataFrameSeedSource
|
|
from data_designer.engine.resources.seed_reader import (
|
|
DataFrameSeedReader,
|
|
LocalFileSeedReader,
|
|
SeedReaderError,
|
|
SeedReaderRegistry,
|
|
)
|
|
from data_designer.engine.secret_resolver import PlaintextResolver
|
|
|
|
|
|
def test_one_reader_per_seed_type():
    """Check that a registry refuses two ``LocalFileSeedReader`` instances.

    The duplicate must be rejected with ``SeedReaderError`` both when it is
    supplied to the constructor and when it is added to an existing registry.
    """
    first_reader = LocalFileSeedReader()
    duplicate_reader = LocalFileSeedReader()

    # Both readers handed over at construction time.
    with pytest.raises(SeedReaderError):
        SeedReaderRegistry([first_reader, duplicate_reader])

    # Second reader added after the registry already holds the first one.
    single_reader_registry = SeedReaderRegistry([first_reader])
    with pytest.raises(SeedReaderError):
        single_reader_registry.add_reader(duplicate_reader)
|
|
|
|
|
|
def test_get_reader_basic():
    """Check that a ``DataFrameSeedSource`` resolves to the registered ``DataFrameSeedReader``.

    The registry holds both a local-file reader and a DataFrame reader; the
    lookup is expected to hand back the DataFrame reader that was registered.
    """
    file_reader = LocalFileSeedReader()
    frame_reader = DataFrameSeedReader()
    registry = SeedReaderRegistry([file_reader, frame_reader])

    seed_frame = pd.DataFrame(data={"a": [1, 2, 3]})
    frame_seed_source = DataFrameSeedSource(df=seed_frame)

    resolved_reader = registry.get_reader(frame_seed_source, PlaintextResolver())

    assert resolved_reader == frame_reader
|
|
|
|
|
|
def test_get_reader_missing():
    """Check that looking up a seed source with no matching reader fails.

    Only a local-file reader is registered, so requesting a reader for a
    ``DataFrameSeedSource`` is expected to raise ``SeedReaderError``.
    """
    file_only_registry = SeedReaderRegistry([LocalFileSeedReader()])

    seed_frame = pd.DataFrame(data={"a": [1, 2, 3]})
    frame_seed_source = DataFrameSeedSource(df=seed_frame)

    with pytest.raises(SeedReaderError):
        file_only_registry.get_reader(frame_seed_source, PlaintextResolver())
|