mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
feat(results): add export() method and --output-format CLI flag
Adds DatasetCreationResults.export(path, format=) supporting jsonl, csv, and parquet. The CLI create command gains --output-format / -f which writes dataset.<format> alongside the parquet batch files.
This commit is contained in:
parent
4054610147
commit
0bdf24ab67
6 changed files with 174 additions and 3 deletions
|
|
@ -7,6 +7,7 @@ import typer
|
|||
|
||||
from data_designer.cli.controllers.generation_controller import GenerationController
|
||||
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
|
||||
from data_designer.interface.results import SUPPORTED_EXPORT_FORMATS
|
||||
|
||||
|
||||
def create_command(
|
||||
|
|
@ -35,6 +36,16 @@ def create_command(
|
|||
"-o",
|
||||
help="Path where generated artifacts will be stored. Defaults to ./artifacts.",
|
||||
),
|
||||
output_format: str | None = typer.Option(
|
||||
None,
|
||||
"--output-format",
|
||||
"-f",
|
||||
help=(
|
||||
f"Export the dataset to a single file after generation. "
|
||||
f"Supported formats: {', '.join(SUPPORTED_EXPORT_FORMATS)}. "
|
||||
"The file is written to <artifact-path>/<dataset-name>/dataset.<format>."
|
||||
),
|
||||
),
|
||||
) -> None:
|
||||
"""Create a full dataset and save results to disk.
|
||||
|
||||
|
|
@ -60,4 +71,5 @@ def create_command(
|
|||
num_records=num_records,
|
||||
dataset_name=dataset_name,
|
||||
artifact_path=artifact_path,
|
||||
output_format=output_format,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ class GenerationController:
|
|||
num_records: int,
|
||||
dataset_name: str,
|
||||
artifact_path: str | None,
|
||||
output_format: str | None = None,
|
||||
) -> None:
|
||||
"""Load config, create a full dataset, and save results to disk.
|
||||
|
||||
|
|
@ -124,6 +125,8 @@ class GenerationController:
|
|||
num_records: Number of records to generate.
|
||||
dataset_name: Name for the generated dataset folder.
|
||||
artifact_path: Path where generated artifacts will be stored, or None for default.
|
||||
output_format: If set, export the dataset to a single file in this format after
|
||||
generation. One of 'jsonl', 'csv', 'parquet'.
|
||||
"""
|
||||
config_builder = self._load_config(config_source)
|
||||
|
||||
|
|
@ -157,6 +160,24 @@ class GenerationController:
|
|||
console.print()
|
||||
print_success(f"Dataset created — {len(dataset)} record(s) generated")
|
||||
console.print(f" Artifacts saved to: [bold]{results.artifact_storage.base_dataset_path}[/bold]")
|
||||
|
||||
if output_format is not None:
|
||||
from data_designer.interface.results import SUPPORTED_EXPORT_FORMATS
|
||||
|
||||
if output_format not in SUPPORTED_EXPORT_FORMATS:
|
||||
print_error(
|
||||
f"Unsupported export format: {output_format!r}. "
|
||||
f"Choose one of: {', '.join(SUPPORTED_EXPORT_FORMATS)}."
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
export_path = results.artifact_storage.base_dataset_path / f"dataset.{output_format}"
|
||||
try:
|
||||
results.export(export_path, format=output_format) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
print_error(f"Export failed: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
console.print(f" Exported to: [bold]{export_path}[/bold]")
|
||||
|
||||
console.print()
|
||||
|
||||
def _load_config(self, config_source: str) -> DataDesignerConfigBuilder:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
||||
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
||||
|
|
@ -19,6 +19,9 @@ if TYPE_CHECKING:
|
|||
|
||||
from data_designer.engine.dataset_builders.utils.task_model import TaskTrace
|
||||
|
||||
ExportFormat = Literal["jsonl", "csv", "parquet"]
|
||||
SUPPORTED_EXPORT_FORMATS: tuple[str, ...] = ("jsonl", "csv", "parquet")
|
||||
|
||||
|
||||
class DatasetCreationResults(WithRecordSamplerMixin):
|
||||
"""Results container for a Data Designer dataset creation run.
|
||||
|
|
@ -95,6 +98,42 @@ class DatasetCreationResults(WithRecordSamplerMixin):
|
|||
raise ArtifactStorageError(f"Processor {processor_name} has no artifacts.")
|
||||
return self.artifact_storage.processors_outputs_path / processor_name
|
||||
|
||||
def export(self, path: Path | str, *, format: ExportFormat = "jsonl") -> Path:
|
||||
"""Export the generated dataset to a single file.
|
||||
|
||||
Args:
|
||||
path: Output file path. The extension is not inferred from *format* —
|
||||
the exact path is used as-is.
|
||||
format: Output format. One of ``'jsonl'``, ``'csv'``, or ``'parquet'``.
|
||||
Defaults to ``'jsonl'``.
|
||||
|
||||
Returns:
|
||||
Path to the written file.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported format is requested.
|
||||
|
||||
Example:
|
||||
>>> results = data_designer.create(config, num_records=1000)
|
||||
>>> results.export("output.jsonl")
|
||||
PosixPath('output.jsonl')
|
||||
>>> results.export("output.csv", format="csv")
|
||||
PosixPath('output.csv')
|
||||
"""
|
||||
if format not in SUPPORTED_EXPORT_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported export format: {format!r}. Choose one of: {', '.join(SUPPORTED_EXPORT_FORMATS)}."
|
||||
)
|
||||
path = Path(path)
|
||||
df = self.load_dataset()
|
||||
if format == "jsonl":
|
||||
df.to_json(path, orient="records", lines=True, force_ascii=False)
|
||||
elif format == "csv":
|
||||
df.to_csv(path, index=False)
|
||||
elif format == "parquet":
|
||||
df.to_parquet(path, index=False)
|
||||
return path
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_id: str,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non
|
|||
mock_ctrl = MagicMock()
|
||||
mock_ctrl_cls.return_value = mock_ctrl
|
||||
|
||||
create_command(config_source="config.yaml", num_records=10, dataset_name="dataset", artifact_path=None)
|
||||
create_command(config_source="config.yaml", num_records=10, dataset_name="dataset", artifact_path=None, output_format=None)
|
||||
|
||||
mock_ctrl_cls.assert_called_once()
|
||||
mock_ctrl.run_create.assert_called_once_with(
|
||||
|
|
@ -26,6 +26,7 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non
|
|||
num_records=10,
|
||||
dataset_name="dataset",
|
||||
artifact_path=None,
|
||||
output_format=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -40,6 +41,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None:
|
|||
num_records=100,
|
||||
dataset_name="my_data",
|
||||
artifact_path="/custom/output",
|
||||
output_format=None,
|
||||
)
|
||||
|
||||
mock_ctrl.run_create.assert_called_once_with(
|
||||
|
|
@ -47,6 +49,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None:
|
|||
num_records=100,
|
||||
dataset_name="my_data",
|
||||
artifact_path="/custom/output",
|
||||
output_format=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -56,11 +59,35 @@ def test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock)
|
|||
mock_ctrl = MagicMock()
|
||||
mock_ctrl_cls.return_value = mock_ctrl
|
||||
|
||||
create_command(config_source="config.yaml", num_records=5, dataset_name="ds", artifact_path=None)
|
||||
create_command(config_source="config.yaml", num_records=5, dataset_name="ds", artifact_path=None, output_format=None)
|
||||
|
||||
mock_ctrl.run_create.assert_called_once_with(
|
||||
config_source="config.yaml",
|
||||
num_records=5,
|
||||
dataset_name="ds",
|
||||
artifact_path=None,
|
||||
output_format=None,
|
||||
)
|
||||
|
||||
|
||||
@patch("data_designer.cli.commands.create.GenerationController")
|
||||
def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None:
|
||||
"""Test create_command forwards --output-format to the controller."""
|
||||
mock_ctrl = MagicMock()
|
||||
mock_ctrl_cls.return_value = mock_ctrl
|
||||
|
||||
create_command(
|
||||
config_source="config.yaml",
|
||||
num_records=10,
|
||||
dataset_name="dataset",
|
||||
artifact_path=None,
|
||||
output_format="jsonl",
|
||||
)
|
||||
|
||||
mock_ctrl.run_create.assert_called_once_with(
|
||||
config_source="config.yaml",
|
||||
num_records=10,
|
||||
dataset_name="dataset",
|
||||
artifact_path=None,
|
||||
output_format="jsonl",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -84,4 +84,5 @@ def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None:
|
|||
num_records=DEFAULT_NUM_RECORDS,
|
||||
dataset_name="dataset",
|
||||
artifact_path=None,
|
||||
output_format=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -259,6 +259,77 @@ def test_load_dataset_independent_of_record_sampler_cache(stub_dataset_creation_
|
|||
stub_artifact_storage.load_dataset.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fmt", ["jsonl", "csv", "parquet"])
|
||||
def test_export_writes_file(stub_dataset_creation_results, tmp_path, fmt):
|
||||
"""export() writes a file in the requested format."""
|
||||
out = tmp_path / f"out.{fmt}"
|
||||
result = stub_dataset_creation_results.export(out, format=fmt)
|
||||
assert result == out
|
||||
assert out.exists()
|
||||
assert out.stat().st_size > 0
|
||||
|
||||
|
||||
def test_export_jsonl_content(stub_dataset_creation_results, stub_dataframe, tmp_path):
|
||||
"""JSONL export writes one JSON object per line."""
|
||||
import json
|
||||
|
||||
out = tmp_path / "out.jsonl"
|
||||
stub_dataset_creation_results.export(out, format="jsonl")
|
||||
lines = out.read_text(encoding="utf-8").splitlines()
|
||||
assert len(lines) == len(stub_dataframe)
|
||||
# Each line must be valid JSON
|
||||
for line in lines:
|
||||
json.loads(line)
|
||||
|
||||
|
||||
def test_export_csv_content(stub_dataset_creation_results, stub_dataframe, tmp_path):
|
||||
"""CSV export has a header row and one data row per record."""
|
||||
import data_designer.lazy_heavy_imports as lazy
|
||||
|
||||
out = tmp_path / "out.csv"
|
||||
stub_dataset_creation_results.export(out, format="csv")
|
||||
loaded = lazy.pd.read_csv(out)
|
||||
assert list(loaded.columns) == list(stub_dataframe.columns)
|
||||
assert len(loaded) == len(stub_dataframe)
|
||||
|
||||
|
||||
def test_export_parquet_content(stub_dataset_creation_results, stub_dataframe, tmp_path):
|
||||
"""Parquet export round-trips to the original DataFrame."""
|
||||
import data_designer.lazy_heavy_imports as lazy
|
||||
|
||||
out = tmp_path / "out.parquet"
|
||||
stub_dataset_creation_results.export(out, format="parquet")
|
||||
loaded = lazy.pd.read_parquet(out)
|
||||
lazy.pd.testing.assert_frame_equal(loaded.reset_index(drop=True), stub_dataframe.reset_index(drop=True))
|
||||
|
||||
|
||||
def test_export_default_format_is_jsonl(stub_dataset_creation_results, tmp_path):
|
||||
"""export() defaults to JSONL when no format is given."""
|
||||
import json
|
||||
|
||||
out = tmp_path / "out.jsonl"
|
||||
stub_dataset_creation_results.export(out)
|
||||
lines = out.read_text(encoding="utf-8").splitlines()
|
||||
# All lines must be valid JSON
|
||||
for line in lines:
|
||||
json.loads(line)
|
||||
|
||||
|
||||
def test_export_unsupported_format_raises(stub_dataset_creation_results, tmp_path):
|
||||
"""export() raises ValueError for unknown formats."""
|
||||
with pytest.raises(ValueError, match="Unsupported export format"):
|
||||
stub_dataset_creation_results.export(tmp_path / "out.xyz", format="xlsx") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_export_returns_path_object(stub_dataset_creation_results, tmp_path):
|
||||
"""export() returns a Path regardless of whether str or Path was passed."""
|
||||
from pathlib import Path
|
||||
|
||||
out = tmp_path / "out.jsonl"
|
||||
result = stub_dataset_creation_results.export(str(out))
|
||||
assert isinstance(result, Path)
|
||||
|
||||
|
||||
def test_preview_results_dataset_metadata() -> None:
|
||||
"""Test that PreviewResults uses DatasetMetadata in display_sample_record."""
|
||||
config_builder = MagicMock(spec=DataDesignerConfigBuilder)
|
||||
|
|
|
|||
Loading…
Reference in a new issue