DataDesigner/packages/data-designer/tests/interface/test_results.py
Przemysław Boruta 0afe287a5f
feat(results): add export() method and --output-format CLI flag (#540)
* feat(results): add export() method and --output-format CLI flag

Adds DatasetCreationResults.export(path, format=) supporting jsonl,
csv, and parquet. The CLI create command gains --output-format / -f
which writes dataset.<format> alongside the parquet batch files.

* fix(cli): validate output_format before dataset generation

* fix(cli): remove top-level results import from create.py to preserve lazy loading

* fix(results): address andreatgretel review — error types, UX ordering, import hygiene

- Derive SUPPORTED_EXPORT_FORMATS from get_args(ExportFormat) so the two can't drift apart
- Replace ValueError with InvalidFileFormatError in export() — consistent with project error conventions
- Add date_format="iso" to to_json() for consistent datetime serialization across formats
- Add click.Choice(SUPPORTED_EXPORT_FORMATS) to --output-format CLI option for parse-time
  validation, better --help output, and tab completion
- Fix double load_dataset() in run_create: inline len() so the DataFrame ref dies before export
- Move success message after the export block to avoid "Dataset created" followed by "Export failed"
- Move imports to module level in test_results.py (json, Path, lazy already imported)
- Add controller-level tests for output_format happy path, bad format rejection, and export failure
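
The first bullet in the list above (deriving the supported formats from the type alias) can be sketched as below. `ExportFormat` and `SUPPORTED_EXPORT_FORMATS` are the names used in this commit message; the exact `Literal` definition is an assumption, since the message does not show it.

```python
from typing import Literal, get_args

# Assumed definition: the commit message names ExportFormat but does not show it.
ExportFormat = Literal["jsonl", "csv", "parquet"]

# Derived from the Literal, so adding a format in one place updates both.
SUPPORTED_EXPORT_FORMATS: tuple[str, ...] = get_args(ExportFormat)
```

Because the tuple is computed from the alias, a type checker and the runtime validation always agree on the accepted values.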

* fix(results): correct Raises docstring — ValueError -> InvalidFileFormatError

* feat(results): stream batch files in export() to avoid OOM on large datasets

- Rewrite export() to read batch parquet files one at a time instead of
  materialising the full dataset via load_dataset(); peak memory is now
  proportional to a single batch regardless of dataset size
- Infer output format from file extension by default; format= parameter
  kept as an explicit override (e.g. writing .txt as JSONL)
- _export_parquet unifies schemas across batches (pa.unify_schemas) to
  handle type drift (e.g. int64 vs float64 in the same column)
- Drop format= from the controller's export() call — path already carries
  the correct extension
- Rewrite export tests around real batch parquet files (stub_batch_dir
  fixture); add tests for multi-batch output, schema unification, unknown
  extension, empty batch directory, and explicit format override
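
A minimal sketch of the extension-based inference with an explicit override, as described in the bullets above. `resolve_export_format` is a hypothetical helper, not the project's function; `ValueError` stands in for the project's `InvalidFileFormatError`, and the lowercase step reflects the case-insensitive handling mentioned later in this commit message.

```python
from __future__ import annotations

from pathlib import Path

SUPPORTED = ("jsonl", "csv", "parquet")


def resolve_export_format(path: str | Path, format: str | None = None) -> str:
    """Return the explicit format if given, otherwise the lowercased file extension."""
    # Hypothetical helper: the real export() raises InvalidFileFormatError instead.
    resolved = format if format is not None else Path(path).suffix.lstrip(".").lower()
    if resolved not in SUPPORTED:
        raise ValueError(f"Unsupported export format: {resolved!r}")
    return resolved
```

With this shape, `resolve_export_format("data.txt", format="jsonl")` selects JSONL even though the extension says otherwise, which is the override case the tests exercise.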

* fix(results): address nabinchha review — memory safety, error wrapping, UX

- Replace load_dataset() with count_records() in CLI to avoid OOM on
  large datasets; add count_records() method using pq.read_metadata
  (reads file metadata only, no data pages loaded)
- Remove redundant format validation in controller — click.Choice in
  create.py already rejects invalid values at parse time; dead code
  removed along with corresponding test
- Wrap pa.unify_schemas / table.cast ArrowInvalid as InvalidFileFormatError
  to normalize third-party exceptions at module boundaries per AGENTS.md
- Lowercase file extension before format lookup so .JSONL/.CSV/.PARQUET
  are accepted without error
- Add clarifying comment to trailing-newline guard in _export_jsonl
- Add tests: count_records(), uppercase extension, incompatible schemas

* fix(results): fix parquet export schema unification and controller path bug

- Use promote_options="permissive" in pa.unify_schemas so minor numeric
  type drift (int64 vs float64) is handled by promotion instead of raising
- Also catch ArrowTypeError from unify_schemas and ValueError from
  table.cast() — the actual exception types thrown by pyarrow for these
  cases (ArrowInvalid alone is not sufficient)
- Wrap base_dataset_path in Path() in generation_controller.run_create
  to guard against callers that return a str (the mock returns a str, and
  str does not support the / operator)
- Update test_export_parquet_incompatible_schemas_raises to match the
  new error source: with permissive unification, different-column-name
  batches fail at cast() not at unify_schemas(), so the match string
  changes from "Cannot unify batch schemas" to "Cannot cast batch"

* fix(results,cli): address nabinchha review round 2

- Use public pa.ArrowInvalid/ArrowTypeError instead of pa.lib.* in _export_parquet
- Drop dead trailing-newline guard in _export_jsonl; skip empty batches with `if content`
- Rename num_records→actual_record_count after count_records() call to avoid shadowing
- Unlink partial export file before re-raising on export failure in run_create
- Export filename now uses dataset_name (<dataset-name>.<format>) instead of literal "dataset"
- Update help text and tests to match new export filename convention
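
The partial-file cleanup mentioned above can be sketched as follows. `export_with_cleanup` and `failing_export` are hypothetical names for illustration; the actual logic lives in `run_create` and is not shown in this commit message.

```python
from pathlib import Path
import tempfile


def export_with_cleanup(export_fn, path: Path) -> Path:
    """Run export_fn(path); on failure remove any partial file and re-raise."""
    try:
        export_fn(path)
    except Exception:
        # A half-written file must not be mistaken for a finished export.
        path.unlink(missing_ok=True)
        raise
    return path


def failing_export(path: Path) -> None:
    """Simulate an export that writes some bytes and then fails."""
    path.write_text("partial")
    raise RuntimeError("disk full")
```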

---------

Co-authored-by: Andre Manoel <165937436+andreatgretel@users.noreply.github.com>
2026-05-06 17:13:57 -06:00


# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.dataset_metadata import DatasetMetadata
from data_designer.config.errors import InvalidFileFormatError
from data_designer.config.preview_results import PreviewResults
from data_designer.config.utils.errors import DatasetSampleDisplayError
from data_designer.config.utils.visualization import display_sample_record as display_fn
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
from data_designer.engine.storage.artifact_storage import ArtifactStorage
from data_designer.interface.results import DatasetCreationResults


@pytest.fixture
def stub_artifact_storage(stub_dataframe):
    """Mock artifact storage that returns a test DataFrame."""
    storage = MagicMock(spec=ArtifactStorage)
    storage.load_dataset.return_value = stub_dataframe
    return storage


@pytest.fixture
def stub_dataset_metadata():
    """Fixture providing a DatasetMetadata instance."""
    return DatasetMetadata()


@pytest.fixture
def stub_dataset_creation_results(
    stub_artifact_storage, stub_dataset_profiler_results, stub_complete_builder, stub_dataset_metadata
):
    """Fixture providing a DatasetCreationResults instance."""
    return DatasetCreationResults(
        artifact_storage=stub_artifact_storage,
        analysis=stub_dataset_profiler_results,
        config_builder=stub_complete_builder,
        dataset_metadata=stub_dataset_metadata,
    )


def test_init(stub_artifact_storage, stub_dataset_profiler_results, stub_complete_builder, stub_dataset_metadata):
    """Test DatasetCreationResults initialization."""
    results = DatasetCreationResults(
        artifact_storage=stub_artifact_storage,
        analysis=stub_dataset_profiler_results,
        config_builder=stub_complete_builder,
        dataset_metadata=stub_dataset_metadata,
    )
    assert results.artifact_storage == stub_artifact_storage
    assert results._analysis == stub_dataset_profiler_results
    assert results._config_builder == stub_complete_builder
    assert results.dataset_metadata == stub_dataset_metadata


def test_load_dataset(stub_dataset_creation_results, stub_artifact_storage, stub_dataframe):
    """Test loading the dataset."""
    dataset = stub_dataset_creation_results.load_dataset()
    assert isinstance(dataset, lazy.pd.DataFrame)
    stub_artifact_storage.load_dataset.assert_called_once()
    lazy.pd.testing.assert_frame_equal(dataset, stub_dataframe)


def test_load_analysis(stub_dataset_creation_results, stub_dataset_profiler_results):
    """Test loading the analysis results."""
    analysis = stub_dataset_creation_results.load_analysis()
    assert isinstance(analysis, DatasetProfilerResults)
    assert analysis == stub_dataset_profiler_results


def test_load_analysis_returns_same_instance(stub_dataset_creation_results):
    """Test that load_analysis returns the same analysis instance."""
    analysis1 = stub_dataset_creation_results.load_analysis()
    analysis2 = stub_dataset_creation_results.load_analysis()
    assert analysis1 is analysis2


def test_record_sampler_dataset_initialization(stub_dataset_creation_results, stub_artifact_storage):
    """Test that _record_sampler_dataset cached property loads dataset correctly."""
    # Access the cached property
    dataset = stub_dataset_creation_results._record_sampler_dataset
    # Verify load_dataset was called
    stub_artifact_storage.load_dataset.assert_called_once()
    lazy.pd.testing.assert_frame_equal(dataset, stub_artifact_storage.load_dataset.return_value)


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_with_default_params(
    mock_display_sample_record, stub_dataset_creation_results, stub_dataframe
):
    """Test display_sample_record with default parameters."""
    stub_dataset_creation_results.display_sample_record()
    # Verify the underlying display_sample_record function was called
    mock_display_sample_record.assert_called_once()
    call_kwargs = mock_display_sample_record.call_args.kwargs
    assert call_kwargs["syntax_highlighting_theme"] == "dracula"
    assert call_kwargs["background_color"] is None
    assert call_kwargs["record_index"] == 0
    # Verify the record passed is the first row of the dataframe
    lazy.pd.testing.assert_series_equal(mock_display_sample_record.call_args.kwargs["record"], stub_dataframe.iloc[0])


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_with_custom_index(
    mock_display_sample_record, stub_dataset_creation_results, stub_dataframe
):
    """Test display_sample_record with a specific index."""
    stub_dataset_creation_results.display_sample_record(index=5)
    mock_display_sample_record.assert_called_once()
    call_kwargs = mock_display_sample_record.call_args.kwargs
    assert call_kwargs["record_index"] == 5
    assert call_kwargs["syntax_highlighting_theme"] == "dracula"
    assert call_kwargs["background_color"] is None
    # Verify the record passed is the correct row
    lazy.pd.testing.assert_series_equal(mock_display_sample_record.call_args.kwargs["record"], stub_dataframe.iloc[5])


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_with_custom_theme(mock_display_sample_record, stub_dataset_creation_results):
    """Test display_sample_record with custom syntax highlighting theme."""
    stub_dataset_creation_results.display_sample_record(syntax_highlighting_theme="monokai")
    mock_display_sample_record.assert_called_once()
    call_kwargs = mock_display_sample_record.call_args.kwargs
    assert call_kwargs["syntax_highlighting_theme"] == "monokai"
    assert call_kwargs["background_color"] is None


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_with_custom_background_color(mock_display_sample_record, stub_dataset_creation_results):
    """Test display_sample_record with custom background color."""
    stub_dataset_creation_results.display_sample_record(background_color="#282a36")
    mock_display_sample_record.assert_called_once()
    call_kwargs = mock_display_sample_record.call_args.kwargs
    assert call_kwargs["syntax_highlighting_theme"] == "dracula"
    assert call_kwargs["background_color"] == "#282a36"


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_with_all_custom_params(mock_display_sample_record, stub_dataset_creation_results):
    """Test display_sample_record with all parameters customized."""
    stub_dataset_creation_results.display_sample_record(
        index=3,
        syntax_highlighting_theme="monokai",
        background_color="#1e1e1e",
    )
    mock_display_sample_record.assert_called_once()
    call_kwargs = mock_display_sample_record.call_args.kwargs
    assert call_kwargs["record_index"] == 3
    assert call_kwargs["syntax_highlighting_theme"] == "monokai"
    assert call_kwargs["background_color"] == "#1e1e1e"


@patch("data_designer.config.utils.visualization.display_sample_record", autospec=True)
def test_display_sample_record_multiple_calls(
    mock_display_sample_record, stub_dataset_creation_results, stub_dataframe
):
    """Test that display_sample_record cycles through records on multiple calls."""
    num_records = len(stub_dataframe)
    # Call multiple times to test cycling
    for _ in range(5):
        stub_dataset_creation_results.display_sample_record()
    assert mock_display_sample_record.call_count == 5
    # Verify that record indices cycle through 0, 1, 2, ..., num_records-1, 0, ...
    for i in range(5):
        call_kwargs = mock_display_sample_record.call_args_list[i].kwargs
        expected_index = i % num_records
        assert call_kwargs["record_index"] == expected_index


def test_display_sample_record_with_empty_dataset():
    """Test display_sample_record behavior with empty dataset."""
    empty_storage = MagicMock(spec=ArtifactStorage)
    empty_storage.load_dataset.return_value = lazy.pd.DataFrame()
    results = DatasetCreationResults(
        artifact_storage=empty_storage,
        analysis=MagicMock(spec=DatasetProfilerResults),
        config_builder=MagicMock(spec=DataDesignerConfigBuilder),
        dataset_metadata=DatasetMetadata(),
    )
    # Empty DataFrame is still a valid DataFrame, so accessing _record_sampler_dataset succeeds
    # but display_sample_record fails when trying to access index 0
    # Note: Currently raises UnboundLocalError due to bug in error handling, but tests that it fails
    with pytest.raises((DatasetSampleDisplayError, UnboundLocalError)):
        results.display_sample_record()


def test_display_sample_record_with_none_dataset():
    """Test display_sample_record behavior when dataset is None."""
    none_storage = MagicMock(spec=ArtifactStorage)
    none_storage.load_dataset.return_value = None
    results = DatasetCreationResults(
        artifact_storage=none_storage,
        analysis=MagicMock(spec=DatasetProfilerResults),
        config_builder=MagicMock(spec=DataDesignerConfigBuilder),
        dataset_metadata=DatasetMetadata(),
    )
    # Mixin raises DatasetSampleDisplayError when dataset is None
    with pytest.raises(DatasetSampleDisplayError, match="No valid dataset found"):
        results.display_sample_record()


def test_results_protocol_conformance(stub_dataset_creation_results):
    """Test that DatasetCreationResults conforms to ResultsProtocol."""
    # ResultsProtocol requires these methods
    assert hasattr(stub_dataset_creation_results, "load_dataset")
    assert hasattr(stub_dataset_creation_results, "load_analysis")
    assert hasattr(stub_dataset_creation_results, "display_sample_record")
    assert callable(stub_dataset_creation_results.load_dataset)
    assert callable(stub_dataset_creation_results.load_analysis)
    assert callable(stub_dataset_creation_results.display_sample_record)


def test_artifact_storage_load_dataset_called_once_for_caching(stub_dataset_creation_results, stub_artifact_storage):
    """Test that artifact_storage.load_dataset is called once when _record_sampler_dataset is cached."""
    # First access to _record_sampler_dataset
    _ = stub_dataset_creation_results._record_sampler_dataset
    # Second access to _record_sampler_dataset (should use cached value)
    _ = stub_dataset_creation_results._record_sampler_dataset
    # Should only be called once due to caching
    assert stub_artifact_storage.load_dataset.call_count == 1


def test_load_dataset_independent_of_record_sampler_cache(stub_dataset_creation_results, stub_artifact_storage):
    """Test that load_dataset calls artifact_storage.load_dataset independently of cache."""
    # Access _record_sampler_dataset to trigger caching
    _ = stub_dataset_creation_results._record_sampler_dataset
    # Reset the call count
    stub_artifact_storage.load_dataset.reset_mock()
    # Call load_dataset
    stub_dataset_creation_results.load_dataset()
    # Should call load_dataset again (independent of cache)
    stub_artifact_storage.load_dataset.assert_called_once()


@pytest.fixture
def stub_batch_dir(stub_dataframe, tmp_path):
    """Directory with two batch parquet files split from stub_dataframe.

    Splitting into two batches exercises the multi-batch streaming path in export().
    """
    batch_dir = tmp_path / "parquet-files"
    batch_dir.mkdir()
    mid = len(stub_dataframe) // 2
    stub_dataframe.iloc[:mid].to_parquet(batch_dir / "batch_00000.parquet", index=False)
    stub_dataframe.iloc[mid:].to_parquet(batch_dir / "batch_00001.parquet", index=False)
    return batch_dir


@pytest.mark.parametrize("fmt", ["jsonl", "csv", "parquet"])
def test_export_writes_file(stub_dataset_creation_results, stub_batch_dir, tmp_path, fmt) -> None:
    """export() writes a non-empty file for each supported format."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / f"out.{fmt}"
    result = stub_dataset_creation_results.export(out)
    assert result == out
    assert out.exists()
    assert out.stat().st_size > 0


def test_export_jsonl_content(stub_dataset_creation_results, stub_dataframe, stub_batch_dir, tmp_path) -> None:
    """JSONL export writes one valid JSON object per line, covering all records."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.jsonl"
    stub_dataset_creation_results.export(out)
    lines = out.read_text(encoding="utf-8").splitlines()
    assert len(lines) == len(stub_dataframe)
    for line in lines:
        json.loads(line)


def test_export_csv_content(stub_dataset_creation_results, stub_dataframe, stub_batch_dir, tmp_path) -> None:
    """CSV export produces a single header row and one data row per record."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.csv"
    stub_dataset_creation_results.export(out)
    loaded = lazy.pd.read_csv(out)
    assert list(loaded.columns) == list(stub_dataframe.columns)
    assert len(loaded) == len(stub_dataframe)


def test_export_parquet_content(stub_dataset_creation_results, stub_dataframe, stub_batch_dir, tmp_path) -> None:
    """Parquet export round-trips to the original DataFrame across two batches."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.parquet"
    stub_dataset_creation_results.export(out)
    loaded = lazy.pd.read_parquet(out)
    lazy.pd.testing.assert_frame_equal(
        loaded.reset_index(drop=True),
        stub_dataframe.reset_index(drop=True),
    )


def test_export_infers_format_from_extension(stub_dataset_creation_results, stub_batch_dir, tmp_path) -> None:
    """export() infers the output format from the file extension when format is omitted."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.jsonl"
    stub_dataset_creation_results.export(out)
    lines = out.read_text(encoding="utf-8").splitlines()
    for line in lines:
        json.loads(line)


def test_export_explicit_format_overrides_extension(
    stub_dataset_creation_results, stub_dataframe, stub_batch_dir, tmp_path
) -> None:
    """Passing format= explicitly overrides extension-based inference."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "data.txt"
    stub_dataset_creation_results.export(out, format="jsonl")
    lines = out.read_text(encoding="utf-8").splitlines()
    assert len(lines) == len(stub_dataframe)
    for line in lines:
        json.loads(line)


def test_export_parquet_schema_unification(stub_dataset_creation_results, tmp_path) -> None:
    """Parquet export unifies schemas across batches with diverging column types."""
    batch_dir = tmp_path / "parquet-files"
    batch_dir.mkdir()
    # Batch 0: 'value' as int64; Batch 1: 'value' as float64 (type drift)
    lazy.pd.DataFrame({"value": lazy.pd.array([1, 2], dtype="int64")}).to_parquet(
        batch_dir / "batch_00000.parquet", index=False
    )
    lazy.pd.DataFrame({"value": lazy.pd.array([3.0, 4.0], dtype="float64")}).to_parquet(
        batch_dir / "batch_00001.parquet", index=False
    )
    stub_dataset_creation_results.artifact_storage.final_dataset_path = batch_dir
    out = tmp_path / "out.parquet"
    stub_dataset_creation_results.export(out)
    loaded = lazy.pd.read_parquet(out)
    assert list(loaded["value"]) == [1.0, 2.0, 3.0, 4.0]


def test_export_unknown_extension_raises(stub_dataset_creation_results, tmp_path) -> None:
    """export() raises InvalidFileFormatError when the extension is not a supported format."""
    with pytest.raises(InvalidFileFormatError, match="Unsupported export format"):
        stub_dataset_creation_results.export(tmp_path / "out.xyz")


def test_export_unsupported_explicit_format_raises(stub_dataset_creation_results, tmp_path) -> None:
    """export() raises InvalidFileFormatError for an explicit unsupported format override."""
    with pytest.raises(InvalidFileFormatError, match="Unsupported export format"):
        stub_dataset_creation_results.export(tmp_path / "out.jsonl", format="xlsx")  # type: ignore[arg-type]


def test_export_no_batch_files_raises(stub_dataset_creation_results, tmp_path) -> None:
    """export() raises ArtifactStorageError when the batch directory is empty."""
    empty_dir = tmp_path / "parquet-files"
    empty_dir.mkdir()
    stub_dataset_creation_results.artifact_storage.final_dataset_path = empty_dir
    with pytest.raises(ArtifactStorageError, match="No batch parquet files found"):
        stub_dataset_creation_results.export(tmp_path / "out.jsonl")


def test_count_records(stub_dataset_creation_results, stub_dataframe, stub_batch_dir) -> None:
    """count_records() returns the total row count without loading data pages."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    assert stub_dataset_creation_results.count_records() == len(stub_dataframe)


def test_export_uppercase_extension_is_recognised(stub_dataset_creation_results, stub_batch_dir, tmp_path) -> None:
    """export() treats file extensions case-insensitively (e.g. .JSONL → jsonl)."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.JSONL"
    result = stub_dataset_creation_results.export(out)
    assert result == out
    assert out.exists()
    lines = out.read_text(encoding="utf-8").splitlines()
    for line in lines:
        json.loads(line)


def test_export_parquet_incompatible_schemas_raises(stub_dataset_creation_results, tmp_path) -> None:
    """_export_parquet wraps schema cast failures (incompatible column names) as InvalidFileFormatError.

    With promote_options="permissive", pa.unify_schemas merges the two schemas into a superset
    {col_a, col_b}. The cast step then raises ValueError because batch_00000 only has col_a.
    """
    batch_dir = tmp_path / "parquet-files"
    batch_dir.mkdir()
    lazy.pd.DataFrame({"col_a": [1, 2]}).to_parquet(batch_dir / "batch_00000.parquet", index=False)
    lazy.pd.DataFrame({"col_b": [3, 4]}).to_parquet(batch_dir / "batch_00001.parquet", index=False)
    stub_dataset_creation_results.artifact_storage.final_dataset_path = batch_dir
    with pytest.raises(InvalidFileFormatError, match="Cannot cast batch"):
        stub_dataset_creation_results.export(tmp_path / "out.parquet")


def test_export_returns_path_object(stub_dataset_creation_results, stub_batch_dir, tmp_path) -> None:
    """export() returns a Path regardless of whether str or Path was passed."""
    stub_dataset_creation_results.artifact_storage.final_dataset_path = stub_batch_dir
    out = tmp_path / "out.jsonl"
    result = stub_dataset_creation_results.export(str(out))
    assert isinstance(result, Path)


def test_preview_results_dataset_metadata() -> None:
    """Test that PreviewResults uses DatasetMetadata in display_sample_record."""
    config_builder = MagicMock(spec=DataDesignerConfigBuilder)
    config_builder.get_columns_of_type.return_value = []
    dataset_metadata = DatasetMetadata(seed_column_names=["name", "age"])
    results = PreviewResults(
        config_builder=config_builder,
        dataset=lazy.pd.DataFrame({"name": ["Alice"], "age": [25], "greeting": ["Hello"]}),
        dataset_metadata=dataset_metadata,
    )
    # Verify metadata is stored as public attribute
    assert results.dataset_metadata == dataset_metadata
    assert results.dataset_metadata.seed_column_names == ["name", "age"]
    # Patch display_sample_record to capture the seed_column_names argument
    with patch("data_designer.config.utils.visualization.display_sample_record", wraps=display_fn) as mock_display:
        results.display_sample_record(index=0)
        # Verify seed_column_names was passed to the display function
        mock_display.assert_called_once()
        call_kwargs = mock_display.call_args.kwargs
        assert call_kwargs["seed_column_names"] == ["name", "age"]