diff --git a/packages/data-designer-config/src/data_designer/config/base.py b/packages/data-designer-config/src/data_designer/config/base.py index a4e55fa2..26a3fd03 100644 --- a/packages/data-designer-config/src/data_designer/config/base.py +++ b/packages/data-designer-config/src/data_designer/config/base.py @@ -6,7 +6,10 @@ from __future__ import annotations +import re from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Literal, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field @@ -20,6 +23,31 @@ class ConfigBase(BaseModel): json_schema_mode_override="validation", ) + @classmethod + def schema_text(cls) -> str: + """Return an agent-friendly text summary of the model's fields and defaults.""" + lines: list[str] = [f"{cls.__name__}:"] + docstring = _get_docstring_summary(cls.__doc__) + if docstring: + lines.append(f" {docstring}") + lines.append("") + for name, field_info in cls.model_fields.items(): + annotation = _format_annotation(field_info.annotation) + if field_info.is_required(): + lines.append(f" {name}: {annotation} [required]") + else: + if field_info.default_factory is not None: + factory_name = getattr(field_info.default_factory, "__name__", repr(field_info.default_factory)) + lines.append(f" {name}: {annotation} = {factory_name}()") + else: + default = field_info.default + if isinstance(default, Enum): + default = default.value + lines.append(f" {name}: {annotation} = {default!r}") + if field_info.description: + lines.append(f" {field_info.description}") + return "\n".join(lines) + class SingleColumnConfig(ConfigBase, ABC): """Abstract base class for all single-column configuration types. @@ -83,3 +111,52 @@ class ProcessorConfig(ConfigBase, ABC): description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.", ) processor_type: str + + +def _format_annotation(annotation: Any) -> str: + """Convert a type annotation to a readable string, stripping module paths.""" + if get_origin(annotation) is Literal: + args = get_args(annotation) + if args: + values = ", ".join(repr(a.value) if isinstance(a, Enum) else repr(a) for a in args) + return f"Literal[{values}]" + raw = str(annotation) if not hasattr(annotation, "__name__") else annotation.__name__ + return re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], raw) + + +_GOOGLE_SECTION_HEADERS = frozenset( + { + "args:", + "arguments:", + "attributes:", + "example:", + "examples:", + "keyword args:", + "keyword arguments:", + "note:", + "notes:", + "raises:", + "references:", + "returns:", + "see also:", + "todo:", + "warns:", + "yields:", + } +) + + +def _get_docstring_summary(docstring: str | None) -> str | None: + """Extract the first paragraph of a docstring, before any Google-style section header.""" + if not docstring: + return None + lines: list[str] = [] + for line in docstring.strip().splitlines(): + stripped = line.strip() + if stripped.lower() in _GOOGLE_SECTION_HEADERS: + break + if not stripped and lines: + break + if stripped: + lines.append(stripped) + return " ".join(lines) if lines else None diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 661bbfa2..fe8e32e5 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -56,10 +56,19 @@ class SamplerColumnConfig(SingleColumnConfig): ``` """ - sampler_type: SamplerType - params: Annotated[SamplerParamsT, Discriminator("sampler_type")] - conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {} - convert_to: str | None = None + sampler_type: SamplerType = Field( + description="Type of sampler to use (e.g., uuid, category, uniform, gaussian, person, datetime)" + ) + params: Annotated[SamplerParamsT, Discriminator("sampler_type")] = Field( + description="Parameters specific to the chosen sampler type" + ) + conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = Field( + default_factory=dict, + description="Optional dictionary for conditional parameters; keys are conditions, values are params to use when met", + ) + convert_to: str | None = Field( + default=None, description="Optional type conversion after sampling: 'float', 'int', or 'str'" + ) column_type: Literal["sampler"] = "sampler" @staticmethod @@ -136,13 +145,25 @@ class LLMTextColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "llm-text" for this configuration type. """ - prompt: str - model_alias: str - system_prompt: str | None = None - multi_modal_context: list[ImageContext] | None = None - tool_alias: str | None = None - with_trace: TraceType = TraceType.NONE - extract_reasoning_content: bool = False + prompt: str = Field( + description="Jinja2 template for the LLM prompt; can reference other columns via {{ column_name }}" + ) + model_alias: str = Field(description="Alias of the model configuration to use for generation") + system_prompt: str | None = Field( + default=None, description="Optional system prompt to set model behavior and constraints" + ) + multi_modal_context: list[ImageContext] | None = Field( + default=None, description="Optional list of ImageContext for vision model inputs" + ) + tool_alias: str | None = Field( + default=None, description="Optional alias of the tool configuration to use for MCP tool calls" + ) + with_trace: TraceType = Field( + default=TraceType.NONE, description="Trace capture mode: NONE, LAST_MESSAGE, or ALL_MESSAGES" + ) + extract_reasoning_content: bool = Field( + default=False, description="If True, capture chain-of-thought in {name}__reasoning_content column" + ) column_type: Literal["llm-text"] = "llm-text" @staticmethod @@ -219,7 +240,9 @@ class LLMCodeColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. """ - code_lang: CodeLang + code_lang: CodeLang = Field( + description="Target programming language or SQL dialect for code extraction from LLM response" + ) column_type: Literal["llm-code"] = "llm-code" @staticmethod @@ -252,7 +275,9 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. """ - output_format: dict | type[BaseModel] + output_format: dict | type[BaseModel] = Field( + description="Pydantic model or JSON schema dict defining the expected structured output shape" + ) column_type: Literal["llm-structured"] = "llm-structured" @staticmethod @@ -317,7 +342,9 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. """ - scores: list[Score] = Field(..., min_length=1) + scores: list[Score] = Field( + ..., min_length=1, description="List of Score objects defining rubric criteria for LLM judge evaluation" + ) column_type: Literal["llm-judge"] = "llm-judge" @staticmethod @@ -342,8 +369,10 @@ class ExpressionColumnConfig(SingleColumnConfig): """ name: str - expr: str - dtype: Literal["int", "float", "str", "bool"] = "str" + expr: str = Field(description="Jinja2 expression to compute the column value from other columns") + dtype: Literal["int", "float", "str", "bool"] = Field( + default="str", description="Data type for expression result: 'int', 'float', 'str', or 'bool'" + ) column_type: Literal["expression"] = "expression" @staticmethod @@ -410,9 +439,11 @@ class ValidationColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "validation" for this configuration type. """ - target_columns: list[str] - validator_type: ValidatorType - validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")] + target_columns: list[str] = Field(description="List of column names to validate") + validator_type: ValidatorType = Field(description="Validation method: 'code', 'local_callable', or 'remote'") + validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")] = Field( + description="Validator-specific parameters (e.g., CodeValidatorParams)" + ) batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") column_type: Literal["validation"] = "validation" @@ -479,8 +510,8 @@ class EmbeddingColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "embedding" for this configuration type. """ - target_column: str - model_alias: str + target_column: str = Field(description="Name of the text column to generate embeddings for") + model_alias: str = Field(description="Alias of the model to use for embedding generation") column_type: Literal["embedding"] = "embedding" @staticmethod @@ -513,9 +544,13 @@ class ImageColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "image" for this configuration type. """ - prompt: str - model_alias: str - multi_modal_context: list[ImageContext] | None = None + prompt: str = Field( + description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}" + ) + model_alias: str = Field(description="Alias of the model to use for image generation") + multi_modal_context: list[ImageContext] | None = Field( + default=None, description="Optional list of ImageContext for multi-modal image-to-image generation" + ) column_type: Literal["image"] = "image" @staticmethod diff --git a/packages/data-designer-config/tests/config/test_schema_text.py b/packages/data-designer-config/tests/config/test_schema_text.py new file mode 100644 index 00000000..c916c264 --- /dev/null +++ b/packages/data-designer-config/tests/config/test_schema_text.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum +from typing import Literal + +import pytest +from pydantic import Field + +from data_designer.config.base import ConfigBase + + +class Color(Enum): + RED = "red" + GREEN = "green" + + +class RequiredOnlyModel(ConfigBase): + name: str + count: int + + +class DefaultFactoryModel(ConfigBase): + tags: list[str] = Field(default_factory=list) + metadata: dict[str, int] = Field(default_factory=dict) + + +class DefaultFactoryWithDescriptionModel(ConfigBase): + items: list[str] = Field(default_factory=list, description="Collection of items") + + +class EnumDefaultModel(ConfigBase): + color: Color = Color.RED + + +class NoneDefaultModel(ConfigBase): + label: str | None = None + + +class ScalarDefaultModel(ConfigBase): + threshold: float = 0.5 + enabled: bool = True + tag: str = "default" + + +class DescribedFieldsModel(ConfigBase): + name: str = Field(description="The name of the thing") + count: int = Field(default=0, description="How many things") + + +class DocstringModel(ConfigBase): + """A model with a docstring summary.""" + + value: int + + +class NoDocstringModel(ConfigBase): + value: int + + +class LiteralModel(ConfigBase): + tag: Literal["fixed"] = "fixed" + name: str + + +class MixedModel(ConfigBase): + """Model exercising every field variant.""" + + required_field: str + optional_none: int | None = None + with_default: float = 3.14 + enum_field: Color = Color.GREEN + factory_list: list[str] = Field(default_factory=list) + factory_dict: dict[str, int] = Field(default_factory=dict, description="A mapping") + + +# --- Required fields --- + + +def test_required_fields_marked() -> None: + text = RequiredOnlyModel.schema_text() + assert "name: str [required]" in text + assert "count: int [required]" in text + + +# --- default_factory fields --- + + +def test_default_factory_list_shows_factory_call() -> None: + text = DefaultFactoryModel.schema_text() + assert "tags:" in text + assert "= list()" in text + + +def test_default_factory_dict_shows_factory_call() -> None: + text = DefaultFactoryModel.schema_text() + assert "metadata:" in text + assert "= dict()" in text + + +def test_default_factory_does_not_show_pydantic_undefined() -> None: + text = DefaultFactoryModel.schema_text() + assert "PydanticUndefined" not in text + + +def test_default_factory_with_description() -> None: + text = DefaultFactoryWithDescriptionModel.schema_text() + assert "items:" in text + assert "= list()" in text + assert "Collection of items" in text + + +# --- Enum defaults --- + + +def test_enum_default_shows_value() -> None: + text = EnumDefaultModel.schema_text() + assert "color: Color = 'red'" in text + assert "Color.RED" not in text + + +# --- None defaults --- + + +def test_none_default() -> None: + text = NoneDefaultModel.schema_text() + assert "label: str | None = None" in text + + +# --- Scalar defaults --- + + +@pytest.mark.parametrize( + ("field_name", "expected"), + [ + ("threshold", "threshold: float = 0.5"), + ("enabled", "enabled: bool = True"), + ("tag", "tag: str = 'default'"), + ], + ids=["float", "bool", "str"], +) +def test_scalar_defaults(field_name: str, expected: str) -> None: + text = ScalarDefaultModel.schema_text() + assert expected in text + + +# --- Field descriptions --- + + +def test_description_appears_below_field() -> None: + text = DescribedFieldsModel.schema_text() + lines = text.splitlines() + name_idx = next(i for i, line in enumerate(lines) if "name: str" in line) + assert "The name of the thing" in lines[name_idx + 1] + + +# --- Docstrings --- + + +def test_docstring_included() -> None: + text = DocstringModel.schema_text() + assert "A model with a docstring summary." in text + + +def test_no_docstring_still_works() -> None: + text = NoDocstringModel.schema_text() + assert text.startswith("NoDocstringModel:") + assert "value: int [required]" in text + + +# --- Header format --- + + +def test_header_is_class_name() -> None: + text = RequiredOnlyModel.schema_text() + assert text.startswith("RequiredOnlyModel:") + + +# --- Single-value Literal fields --- + + +def test_literal_field_includes_value_and_default() -> None: + text = LiteralModel.schema_text() + assert "tag: Literal['fixed'] = 'fixed'" in text + + +# --- Mixed model exercises all variants together --- + + +def test_mixed_model_all_variants() -> None: + text = MixedModel.schema_text() + assert "Model exercising every field variant." in text + assert "required_field: str [required]" in text + assert "optional_none: int | None = None" in text + assert "with_default: float = 3.14" in text + assert "enum_field: Color = 'green'" in text + assert "factory_list:" in text + assert "= list()" in text + assert "factory_dict:" in text + assert "= dict()" in text + assert "A mapping" in text + assert "PydanticUndefined" not in text diff --git a/packages/data-designer/src/data_designer/cli/agent_command_defs.py b/packages/data-designer/src/data_designer/cli/agent_command_defs.py new file mode 100644 index 00000000..c14b9516 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/agent_command_defs.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class AgentCommandDef: + name: str + attr: str + help: str + command_pattern: str + returns: str + + +AGENT_COMMANDS: tuple[AgentCommandDef, ...] = ( + AgentCommandDef( + name="context", + attr="context_command", + help="Bootstrap payload with types, state, and builder.", + command_pattern="data-designer agent context", + returns="agent_context", + ), + AgentCommandDef( + name="types", + attr="types_command", + help="Type names and import paths for one or all families.", + command_pattern="data-designer agent types [family]", + returns="agent_types", + ), + AgentCommandDef( + name="schema", + attr="schema_command", + help="Schema for a type or entire family.", + command_pattern="data-designer agent schema | --all", + returns="agent_schema", + ), + AgentCommandDef( + name="builder", + attr="builder_command", + help="ConfigBuilder method surface with signatures.", + command_pattern="data-designer agent builder", + returns="agent_builder", + ), + AgentCommandDef( + name="state.model-aliases", + attr="state_model_aliases_command", + help="Model aliases and usability status.", + command_pattern="data-designer agent state model-aliases", + returns="agent_state_model_aliases", + ), + AgentCommandDef( + name="state.persona-datasets", + attr="state_persona_datasets_command", + help="Persona locales and install status.", + command_pattern="data-designer agent state persona-datasets", + returns="agent_state_persona_datasets", + ), +) diff --git a/packages/data-designer/src/data_designer/cli/commands/agent.py b/packages/data-designer/src/data_designer/cli/commands/agent.py new file mode 100644 index 00000000..8a15a423 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/commands/agent.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import typer + +from data_designer.cli.utils.agent_introspection import ( + AgentIntrospectionError, + get_builder_api, + get_context, + get_model_aliases_state, + get_persona_datasets_state, + get_schema, + get_types, +) +from data_designer.cli.utils.agent_text_formatter import ( + format_builder_text, + format_context_text, + format_model_aliases_text, + format_persona_datasets_text, + format_schema_text, + format_types_text, +) +from data_designer.config.utils.constants import DATA_DESIGNER_HOME + + +def context_command() -> None: + """Return a bootstrap payload with types, local state, and builder summary.""" + _run(lambda: get_context(DATA_DESIGNER_HOME), format_context_text) + + +def types_command( + family: str | None = typer.Argument(None, help="Optional schema family name."), +) -> None: + """Return available type names and import paths for one family or all families.""" + _run(lambda: get_types(family), format_types_text) + + +def schema_command( + family: str = typer.Argument(..., help="Schema family name."), + type_name: str | None = typer.Argument(None, help="Type name within the selected family."), + all_types: bool = typer.Option(False, "--all", help="Return every schema in the selected family."), +) -> None: + """Return schema for a specific type or every type in a family.""" + _run(lambda: get_schema(family, type_name, all_types=all_types), format_schema_text) + + +def builder_command() -> None: + """Return the DataDesignerConfigBuilder method surface with signatures and docstrings.""" + _run(get_builder_api, format_builder_text) + + +def state_model_aliases_command() -> None: + """Return configured model aliases and whether each one is currently usable.""" + _run(lambda: get_model_aliases_state(DATA_DESIGNER_HOME), format_model_aliases_text) + + +def state_persona_datasets_command() -> None: + """Return built-in persona locales and whether each dataset is installed locally.""" + _run(lambda: get_persona_datasets_state(DATA_DESIGNER_HOME), format_persona_datasets_text) + + +def _run(get_data: Callable[[], Any], format_text: Callable[[Any], str]) -> None: + try: + data = get_data() + typer.echo(format_text(data)) + except AgentIntrospectionError as exc: + typer.echo(f"Error [{exc.code}]: {exc.message}", err=True) + raise typer.Exit(code=1) + except Exception as exc: + typer.echo(f"Error [internal_error]: {exc}", err=True) + raise typer.Exit(code=1) diff --git a/packages/data-designer/src/data_designer/cli/main.py b/packages/data-designer/src/data_designer/cli/main.py index b4ab7dde..21cdbc41 100644 --- a/packages/data-designer/src/data_designer/cli/main.py +++ b/packages/data-designer/src/data_designer/cli/main.py @@ -5,6 +5,7 @@ from __future__ import annotations import typer +from data_designer.cli.agent_command_defs import AGENT_COMMANDS from data_designer.cli.lazy_group import create_lazy_typer_group from data_designer.cli.runtime import ensure_cli_default_model_settings @@ -99,9 +100,37 @@ download_app = typer.Typer( no_args_is_help=True, ) +_AGENT_CMD = f"{_CMD}.agent" + + +def _build_agent_lazy_group(prefix: str) -> dict[str, dict[str, str]]: + return { + cmd.name.removeprefix(f"{prefix}."): {"module": _AGENT_CMD, "attr": cmd.attr, "help": cmd.help} + for cmd in AGENT_COMMANDS + if (prefix == "" and "." not in cmd.name) or cmd.name.startswith(f"{prefix}.") + } + + +agent_app = typer.Typer( + name="agent", + help="Agent-only interface for dynamic Data Designer introspection", + cls=create_lazy_typer_group(_build_agent_lazy_group("")), + no_args_is_help=True, +) + +agent_state_app = typer.Typer( + name="state", + help="Return current local state relevant to agents", + cls=create_lazy_typer_group(_build_agent_lazy_group("state")), + no_args_is_help=True, +) + +agent_app.add_typer(agent_state_app, name="state") + # Add setup command groups app.add_typer(config_app, name="config", rich_help_panel="Setup") app.add_typer(download_app, name="download", rich_help_panel="Setup") +app.add_typer(agent_app, name="agent", rich_help_panel="Agent") def main() -> None: diff --git a/packages/data-designer/src/data_designer/cli/utils/agent_introspection.py b/packages/data-designer/src/data_designer/cli/utils/agent_introspection.py new file mode 100644 index 00000000..0fb69808 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/utils/agent_introspection.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.metadata +import inspect +import re +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Literal, get_args, get_origin + +import data_designer.config as dd +from data_designer.cli.agent_command_defs import AGENT_COMMANDS +from data_designer.cli.repositories.model_repository import ModelRepository +from data_designer.cli.repositories.persona_repository import PersonaRepository +from data_designer.cli.repositories.provider_repository import ProviderRepository +from data_designer.cli.services.download_service import DownloadService +from data_designer.config.column_types import ColumnConfigT +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.default_model_settings import get_providers_with_missing_api_keys +from data_designer.config.processor_types import ProcessorConfigT +from data_designer.config.sampler_constraints import ColumnConstraintT +from data_designer.config.sampler_params import SamplerParamsT +from data_designer.config.validator_params import ValidatorParamsT + + +@dataclass(frozen=True) +class FamilySpec: + name: str + type_union: Any + discriminator_field: str + + +class AgentIntrospectionError(Exception): + def __init__(self, code: str, message: str, details: dict[str, Any] | None = None) -> None: + super().__init__(message) + self.code = code + self.message = message + self.details = details or {} + + +_FAMILY_SPECS: dict[str, FamilySpec] = { + "columns": FamilySpec(name="columns", type_union=ColumnConfigT, discriminator_field="column_type"), + "samplers": FamilySpec(name="samplers", type_union=SamplerParamsT, discriminator_field="sampler_type"), + "validators": FamilySpec(name="validators", type_union=ValidatorParamsT, discriminator_field="validator_type"), + "processors": FamilySpec(name="processors", type_union=ProcessorConfigT, discriminator_field="processor_type"), + "constraints": FamilySpec(name="constraints", type_union=ColumnConstraintT, discriminator_field="constraint_type"), +} + + +def get_family_names() -> list[str]: + return sorted(_FAMILY_SPECS) + + +def get_library_version() -> str: + try: + return importlib.metadata.version("data-designer") + except importlib.metadata.PackageNotFoundError: + return "unknown" + + +def get_family_spec(family: str) -> FamilySpec: + spec = _FAMILY_SPECS.get(_normalize_family_name(family)) + if spec is None: + raise AgentIntrospectionError( + code="unknown_family", + message=f"Unknown family {family!r}.", + details={"available_families": get_family_names()}, + ) + return spec + + +def discover_family_types(family: str) -> dict[str, type]: + spec = get_family_spec(family) + discovered: dict[str, type] = {} + for model in get_args(spec.type_union): + type_name = _extract_literal_value(model.model_fields[spec.discriminator_field].annotation) + if type_name in discovered and discovered[type_name] is not model: + raise AgentIntrospectionError( + code="duplicate_discriminator_value", + message=f"Duplicate discriminator {type_name!r} in family {family!r}.", + details={"family": family, "type_name": type_name}, + ) + discovered[type_name] = model + return dict(sorted(discovered.items())) + + +def get_import_path(cls: type) -> str: + exported = getattr(dd, cls.__name__, None) + if exported is cls: + return f"data_designer.config.{cls.__name__}" + return f"{cls.__module__}.{cls.__qualname__}" + + +def get_family_catalog(family: str) -> list[dict[str, str]]: + return [ + {"type_name": type_name, "class_name": cls.__name__, "import_path": get_import_path(cls)} + for type_name, cls in discover_family_types(family).items() + ] + + +def get_family_schema(family: str, type_name: str) -> dict[str, Any]: + types_map = discover_family_types(family) + cls = types_map.get(type_name) + if cls is None: + raise AgentIntrospectionError( + code="unknown_type", + message=f"Unknown type {type_name!r} for family {family!r}.", + details={"family": family, "available_types": list(types_map)}, + ) + return _build_schema_dict(get_family_spec(family).name, type_name, cls) + + +def get_family_schemas(family: str) -> dict[str, Any]: + spec = get_family_spec(family) + types_map = discover_family_types(family) + items = [_build_schema_dict(spec.name, tn, cls) for tn, cls in types_map.items()] + return {"family": spec.name, "items": items} + + +def get_builder_api() -> dict[str, Any]: + return { + "class_name": DataDesignerConfigBuilder.__name__, + "import_path": get_import_path(DataDesignerConfigBuilder), + "methods": _get_builder_methods(), + } + + +def get_operations() -> list[dict[str, str]]: + return [ + {"name": c.name, "command_pattern": c.command_pattern, "description": c.help, "returns": c.returns} + for c in AGENT_COMMANDS + ] + + +def get_context(config_dir: Path) -> dict[str, Any]: + catalogs = {f: get_family_catalog(f) for f in get_family_names()} + return { + "operations": get_operations(), + "families": [{"family": f, "count": len(items)} for f, items in catalogs.items()], + "types": catalogs, + "state": { + "model_aliases": get_model_aliases_state(config_dir), + "persona_datasets": get_persona_datasets_state(config_dir), + }, + "builder": get_builder_api(), + } + + +def get_types(family: str | None) -> dict[str, Any]: + if family is None: + catalogs = {f: get_family_catalog(f) for f in get_family_names()} + return { + "families": [{"family": f, "count": len(items)} for f, items in catalogs.items()], + "items": catalogs, + } + return {"family": get_family_spec(family).name, "items": get_family_catalog(family)} + + +def get_schema(family: str, type_name: str | None, *, all_types: bool) -> dict[str, Any]: + if all_types and type_name is not None: + raise AgentIntrospectionError( + code="invalid_schema_request", + message="Provide either a type name or --all, but not both.", + details={"family": family, "type_name": type_name, "all": all_types}, + ) + if all_types: + return get_family_schemas(family) + if type_name is None: + raise AgentIntrospectionError( + code="missing_type_name", + message="A type name is required unless --all is provided.", + details={"family": family}, + ) + return get_family_schema(family, type_name) + + +def get_model_aliases_state(config_dir: Path) -> dict[str, Any]: + model_registry = _load_registry(ModelRepository(config_dir)) + provider_registry = _load_registry(ProviderRepository(config_dir)) + + items: list[dict[str, Any]] = [] + if model_registry is None: + return { + "model_config_present": False, + "provider_config_present": provider_registry is not None, + "default_provider": None if provider_registry is None else provider_registry.default, + "items": items, + } + + providers_by_name: dict[str, Any] = {} + missing_key_names: set[str] = set() + default_provider: str | None = None + if provider_registry is not None: + providers_by_name = {p.name: p for p in provider_registry.providers} + default_provider = provider_registry.default or ( + provider_registry.providers[0].name if provider_registry.providers else None + ) + missing_key_names = {p.name for p in get_providers_with_missing_api_keys(provider_registry.providers)} + + for mc in sorted(model_registry.model_configs, key=lambda m: m.alias): + effective = mc.provider or default_provider + usable = True + reason: str | None = None + if effective is None: + usable, reason = False, "No model provider is configured." + elif effective not in providers_by_name: + usable, reason = False, f"Provider {effective!r} is not configured." + elif effective in missing_key_names: + usable, reason = False, f"Provider {effective!r} is missing an API key." + + items.append( + { + "model_alias": mc.alias, + "model": mc.model, + "generation_type": getattr(mc.generation_type, "value", str(mc.generation_type)), + "configured_provider": mc.provider, + "effective_provider": effective, + "usable": usable, + "reason": reason, + } + ) + + return { + "model_config_present": True, + "provider_config_present": provider_registry is not None, + "default_provider": default_provider, + "items": items, + } + + +def get_persona_datasets_state(config_dir: Path) -> dict[str, Any]: + persona_repo = PersonaRepository() + download_svc = DownloadService(config_dir, persona_repo) + return { + "managed_assets_directory": str(download_svc.get_managed_assets_directory()), + "items": [ + { + "locale": loc.code, + "dataset_name": loc.dataset_name, + "size": loc.size, + "installed": download_svc.is_locale_downloaded(loc.code), + } + for loc in sorted(persona_repo.list_all(), key=lambda loc: loc.code) + ], + } + + +def _build_schema_dict(family_name: str, type_name: str, cls: type) -> dict[str, Any]: + return { + "family": family_name, + "type_name": type_name, + "class_name": cls.__name__, + "import_path": get_import_path(cls), + "schema": cls.model_json_schema(), + "schema_text": cls.schema_text(), + } + + +def _normalize_family_name(family: str) -> str: + normalized = family.strip().lower() + if normalized in _FAMILY_SPECS: + return normalized + plural = f"{normalized}s" + if plural in _FAMILY_SPECS: + return plural + return normalized + + +def _extract_literal_value(annotation: Any) -> str: + if get_origin(annotation) is not Literal or not get_args(annotation): + raise AgentIntrospectionError( + code="invalid_discriminator_annotation", + message=f"Expected non-empty Literal annotation, got {annotation!r}.", + ) + value = get_args(annotation)[0] + return str(value.value) if isinstance(value, Enum) else str(value) + + +def _get_builder_methods() -> list[dict[str, Any]]: + methods: list[dict[str, Any]] = [] + for name, attr in inspect.getmembers(DataDesignerConfigBuilder): + if name.startswith("_") and name != "__init__": + continue + if not callable(attr): + continue + try: + sig = inspect.signature(attr) + except (TypeError, ValueError): + continue + + docstring = inspect.getdoc(attr) + methods.append( + { + "name": name, + "signature": _format_signature(name, sig), + "summary": _get_first_line(docstring), + "docstring": docstring, + } + ) + + return methods + + +def _format_signature(method_name: str, sig: inspect.Signature) -> str: + params = [p for p in sig.parameters.values() if p.name not in {"self", "cls"}] + sig_str = str(sig.replace(parameters=params)) + sig_str = re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], sig_str) + # Strip quotes left by `from __future__ import annotations` around type annotations. + sig_str = re.sub(r"(?<=: )'([^']+)'", r"\1", sig_str) + sig_str = re.sub(r"(?<=-> )'([^']+)'", r"\1", sig_str) + return f"{method_name}{sig_str}" + + +def _get_first_line(text: str | None) -> str | None: + return next((line.strip() for line in text.strip().splitlines() if line.strip()), None) if text else None + + +def _load_registry(repo: Any) -> Any: + if not repo.exists(): + return None + registry = repo.load() + if registry is None: + raise AgentIntrospectionError( + code="invalid_registry", + message=f"Failed to load registry from {str(repo.config_file)!r}.", + details={"config_file": str(repo.config_file)}, + ) + return registry diff --git a/packages/data-designer/src/data_designer/cli/utils/agent_text_formatter.py b/packages/data-designer/src/data_designer/cli/utils/agent_text_formatter.py new file mode 100644 index 00000000..b1d10bfe --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/utils/agent_text_formatter.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +from data_designer.cli.utils.agent_introspection import get_library_version + + +def format_context_text(data: dict[str, Any]) -> str: + """Format the full context payload as sectioned text with tables.""" + sections = [ + f"Data Designer v{get_library_version()}", + "", + "import data_designer.config as dd", + "", + 'A "family" is a group of related config types that share a discriminator field.', + "Use dd. to reference any type below.", + "", + "## Families", + "", + format_types_text({"families": data["families"], "items": data["types"]}), + "", + "## Model Aliases", + "", + format_model_aliases_text(data["state"]["model_aliases"]), + "", + "## Persona Datasets", + "", + format_persona_datasets_text(data["state"]["persona_datasets"]), + "", + "## Builder", + "", + format_builder_text(data["builder"]), + "", + "## Commands", + "", + _format_table(data["operations"], ["command_pattern", "description"]), + ] + return "\n".join(sections) + + +def format_types_text(data: dict[str, Any]) -> str: + """Format type listings for one family or all families.""" + if "families" in data: + lines: list[str] = [f"{f['family']}: {f['count']} types" for f in data["families"]] + lines.append("") + for family_name, items in data["items"].items(): + lines.append(_format_table(items, ["type_name", "class_name"], title=family_name)) + lines.append("") + return "\n".join(lines).rstrip() + return _format_table( + data["items"], + ["type_name", "class_name"], + title=data.get("family"), + ) + + +def format_schema_text(data: dict[str, Any]) -> str: + """Format schema data as human-readable field summaries.""" + if "items" in data: + header = f"# {data['family']} schemas ({len(data['items'])} types)" + schemas = "\n\n".join(item["schema_text"] for item in data["items"]) + return f"{header}\n\n{schemas}" + return data["schema_text"] + + +def format_builder_text(data: dict[str, Any]) -> str: + """Format builder methods with signatures.""" + path = data["import_path"] + hint = f"dd.{path.removeprefix('data_designer.config.')}" if path.startswith("data_designer.config.") else path + lines: list[str] = [ + f"{data['class_name']}:", + f" usage: {hint}", + " methods:", + "", + ] + for method in data["methods"]: + lines.append(f" {method['signature']}") + if method.get("summary"): + lines.append(f" {method['summary']}") + lines.append("") + return "\n".join(lines).rstrip() + + +def format_model_aliases_text(state: dict[str, Any]) -> str: + """Format model aliases as a text table with provider summary.""" + lines: list[str] = [f"default_provider: {state.get('default_provider') or '(none)'}", ""] + lines.append( + _format_table( + state.get("items", []), + ["model_alias", "model", "generation_type", "effective_provider", "usable", "reason"], + column_labels={"effective_provider": "provider"}, + title="model aliases", + ) + ) + return "\n".join(lines) + + +def format_persona_datasets_text(state: dict[str, Any]) -> str: + """Format persona datasets as a text table.""" + return _format_table(state.get("items", []), ["locale", "size", "installed"], title="persona datasets") + + +def _format_table( + items: list[dict[str, Any]], + columns: list[str], + *, + title: str | None = None, + column_labels: dict[str, str] | None = None, +) -> str: + labels = {col: (column_labels or {}).get(col, col) for col in columns} + + if not items: + header = f"# {title}" if title else "# table" + return f"{header}\n(no items)" + + col_widths = {col: max(len(labels[col]), max(len(_cell(row.get(col))) for row in items)) for col in columns} + + lines: list[str] = [] + if title: + lines.append(f"# {title}") + lines.append("") + lines.append(" ".join(f"{labels[col]:<{col_widths[col]}}" for col in columns)) + lines.append(" ".join("-" * col_widths[col] for col in columns)) + for row in items: + lines.append(" ".join(f"{_cell(row.get(col)):<{col_widths[col]}}" for col in columns)) + + return "\n".join(lines) + + +def _cell(value: Any) -> str: + """Convert a cell value to a string, rendering None as empty.""" + if value is None: + return "" + return str(value) diff --git a/packages/data-designer/tests/cli/commands/test_agent_command.py b/packages/data-designer/tests/cli/commands/test_agent_command.py new file mode 100644 index 00000000..54594002 --- /dev/null +++ b/packages/data-designer/tests/cli/commands/test_agent_command.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from typer.testing import CliRunner + +from data_designer.cli.main import app + +_PATCH = "data_designer.cli.commands.agent" + + +@pytest.mark.parametrize( + "args,data_fn,format_fn,expected_text", + [ + (["agent", "context"], "get_context", "format_context_text", "Data Designer"), + (["agent", "types", "columns"], "get_types", "format_types_text", "columns"), + (["agent", "builder"], "get_builder_api", "format_builder_text", "Builder:"), + (["agent", "state", "model-aliases"], "get_model_aliases_state", "format_model_aliases_text", "model aliases"), + ( + ["agent", "state", "persona-datasets"], + "get_persona_datasets_state", + "format_persona_datasets_text", + "persona", + ), + ], + ids=["context", "types", "builder", "model-aliases", "persona-datasets"], +) +def test_commands_default_text_mode(args: list[str], data_fn: str, format_fn: str, expected_text: str) -> None: + runner = CliRunner() + with ( + patch(f"{_PATCH}.{data_fn}", return_value={"stub": True}) as mock_get, + patch(f"{_PATCH}.{format_fn}", return_value=expected_text), + ): + result = runner.invoke(app, args) + + assert result.exit_code == 0 + assert expected_text in result.output + mock_get.assert_called_once() + + +def test_schema_command_default_outputs_text() -> None: + runner = CliRunner() + with ( + patch(f"{_PATCH}.get_schema", return_value={"type_name": "llm-text", "schema": {}}), + patch(f"{_PATCH}.format_schema_text", return_value="# llm-text\n{}"), + ): + result = runner.invoke(app, ["agent", "schema", "columns", "llm-text"]) + + assert result.exit_code == 0 + assert "llm-text" in result.output + + +def test_error_outputs_message_to_stderr() -> None: + runner = CliRunner() + with patch(f"{_PATCH}.get_schema", side_effect=ValueError("boom")): + result = runner.invoke(app, ["agent", "schema", "columns", "missing"]) + + assert result.exit_code == 1 + assert result.stdout == "" + assert "internal_error" in result.stderr + assert "boom" in result.stderr diff --git a/packages/data-designer/tests/cli/utils/test_agent_introspection.py b/packages/data-designer/tests/cli/utils/test_agent_introspection.py new file mode 100644 index 00000000..4e366905 --- /dev/null +++ b/packages/data-designer/tests/cli/utils/test_agent_introspection.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.cli.utils.agent_introspection import ( + AgentIntrospectionError, + discover_family_types, + get_builder_api, + get_family_catalog, + get_family_schema, + get_family_spec, + get_operations, + get_types, +) + + +def test_get_family_catalog_accepts_singular_family_names() -> None: + assert get_family_catalog("validator") == get_family_catalog("validators") + + +def test_get_family_catalog_returns_sorted_type_names() -> None: + catalog = get_family_catalog("columns") + assert catalog + assert [item["type_name"] for item in catalog] == sorted(item["type_name"] for item in catalog) + + +def test_get_family_schema_returns_json_schema_payload() -> None: + schema_payload = get_family_schema("validator", "code") + + assert schema_payload["family"] == "validators" + assert schema_payload["type_name"] == "code" + assert schema_payload["class_name"] == "CodeValidatorParams" + assert schema_payload["import_path"] == "data_designer.config.CodeValidatorParams" + assert schema_payload["schema"]["title"] == "CodeValidatorParams" + + +def test_get_family_schema_includes_schema_text() -> None: + schema_payload = get_family_schema("columns", "llm-text") + + text = schema_payload["schema_text"] + assert text.startswith("LLMTextColumnConfig:") + assert "column_type:" in text + assert "name:" in text + assert "Configuration for text generation" in text + assert "Jinja2 template" in text + + +def test_get_family_schema_raises_for_unknown_type() -> None: + with pytest.raises(AgentIntrospectionError) as exc_info: + get_family_schema("validators", "does-not-exist") + + assert exc_info.value.code == "unknown_type" + assert exc_info.value.details["family"] == "validators" + assert "code" in exc_info.value.details["available_types"] + + +def test_discover_family_types_returns_pydantic_classes() -> None: + types_map = discover_family_types("columns") + + assert types_map + assert all(hasattr(cls, "model_fields") for cls in types_map.values()) + + +def test_get_family_spec_returns_discriminator_field() -> None: + spec = get_family_spec("columns") + + assert spec.name == "columns" + assert spec.discriminator_field == "column_type" + + +def test_get_builder_api_includes_docstrings() -> None: + builder_api = get_builder_api() + + assert builder_api["class_name"] == "DataDesignerConfigBuilder" + assert builder_api["import_path"] == "data_designer.config.DataDesignerConfigBuilder" + assert builder_api["methods"] + assert all("docstring" in method for method in builder_api["methods"]) + + +def test_get_types_returns_all_families_when_no_family_given() -> None: + data = get_types(None) + + assert "families" in data + assert "items" in data + assert len(data["families"]) > 0 + assert all(f["family"] in data["items"] for f in data["families"]) + + +def test_get_types_returns_single_family() -> None: + data = get_types("columns") + + assert data["family"] == "columns" + assert isinstance(data["items"], list) + assert len(data["items"]) > 0 + + +def test_get_operations_returns_all_commands() -> None: + ops = get_operations() + + assert len(ops) == 6 + assert all("name" in op and "command_pattern" in op and "description" in op for op in ops) diff --git a/packages/data-designer/tests/cli/utils/test_agent_introspection_integration.py b/packages/data-designer/tests/cli/utils/test_agent_introspection_integration.py new file mode 100644 index 00000000..2e60d477 --- /dev/null +++ b/packages/data-designer/tests/cli/utils/test_agent_introspection_integration.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path + +from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository +from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository +from data_designer.cli.utils.agent_introspection import ( + get_context, + get_model_aliases_state, + get_persona_datasets_state, +) +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider + + +def test_get_model_aliases_state_reports_provider_status(tmp_path: Path) -> None: + provider_repository = ProviderRepository(tmp_path) + provider_repository.save( + ModelProviderRegistry( + providers=[ + ModelProvider( + name="provider-a", + endpoint="https://api.example.com/a", + provider_type="openai", + api_key="test-api-key", + ), + ModelProvider( + name="provider-b", + endpoint="https://api.example.com/b", + provider_type="openai", + api_key="MISSING_PROVIDER_KEY", + ), + ], + default="provider-a", + ) + ) + + model_repository = ModelRepository(tmp_path) + model_repository.save( + ModelConfigRegistry( + model_configs=[ + ModelConfig( + alias="alpha", + model="model-alpha", + provider=None, + inference_parameters=ChatCompletionInferenceParams(), + ), + ModelConfig( + alias="beta", + model="model-beta", + provider="provider-b", + inference_parameters=ChatCompletionInferenceParams(), + ), + ModelConfig( + alias="gamma", + model="model-gamma", + provider="provider-missing", + inference_parameters=ChatCompletionInferenceParams(), + ), + ] + ) + ) + + payload = get_model_aliases_state(tmp_path) + + assert payload["model_config_present"] is True + assert payload["provider_config_present"] is True + assert payload["default_provider"] == "provider-a" + assert payload["items"] == [ + { + "model_alias": "alpha", + "model": "model-alpha", + "generation_type": "chat-completion", + "configured_provider": None, + "effective_provider": "provider-a", + "usable": True, + "reason": None, + }, + { + "model_alias": "beta", + "model": "model-beta", + "generation_type": "chat-completion", + "configured_provider": "provider-b", + "effective_provider": "provider-b", + "usable": False, + "reason": "Provider 'provider-b' is missing an API key.", + }, + { + "model_alias": "gamma", + "model": "model-gamma", + "generation_type": "chat-completion", + "configured_provider": "provider-missing", + "effective_provider": "provider-missing", + "usable": False, + "reason": "Provider 'provider-missing' is not configured.", + }, + ] + + +def test_get_model_aliases_state_handles_missing_local_files(tmp_path: Path) -> None: + payload = get_model_aliases_state(tmp_path) + + assert payload == { + "model_config_present": False, + "provider_config_present": False, + "default_provider": None, + "items": [], + } + + +def test_get_persona_datasets_state_reports_installed_locales(tmp_path: Path) -> None: + managed_assets_dir = tmp_path / "managed-assets" / "datasets" + managed_assets_dir.mkdir(parents=True) + (managed_assets_dir / "en_US.parquet").write_text("stub") + + payload = get_persona_datasets_state(tmp_path) + + assert payload["managed_assets_directory"] == str(managed_assets_dir) + installed_by_locale = {item["locale"]: item["installed"] for item in payload["items"]} + assert installed_by_locale["en_US"] is True + assert any(not item["installed"] for item in payload["items"] if item["locale"] != "en_US") + + +def test_get_context_returns_self_describing_payload(tmp_path: Path) -> None: + payload = get_context(tmp_path) + + operation_names = [operation["name"] for operation in payload["operations"]] + assert operation_names == [ + "context", + "types", + "schema", + "builder", + "state.model-aliases", + "state.persona-datasets", + ] + assert payload["families"] + assert "columns" in payload["types"] + assert payload["builder"]["methods"] diff --git a/packages/data-designer/tests/cli/utils/test_agent_text_formatter.py b/packages/data-designer/tests/cli/utils/test_agent_text_formatter.py new file mode 100644 index 00000000..ca9a31b6 --- /dev/null +++ b/packages/data-designer/tests/cli/utils/test_agent_text_formatter.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +import pytest + +from data_designer.cli.utils.agent_introspection import get_family_schema +from data_designer.cli.utils.agent_text_formatter import ( + format_builder_text, + format_context_text, + format_model_aliases_text, + format_persona_datasets_text, + format_schema_text, + format_types_text, +) + +# --- format_context_text --- + + +def test_format_context_text_includes_builder_section() -> None: + data: dict[str, Any] = { + "families": [{"family": "columns", "count": 1}], + "types": { + "columns": [{"type_name": "a", "class_name": "A", "import_path": "m.A"}], + }, + "state": { + "model_aliases": {"default_provider": None, "items": []}, + "persona_datasets": {"items": []}, + }, + "builder": { + "class_name": "DataDesignerConfigBuilder", + "import_path": "data_designer.config.DataDesignerConfigBuilder", + "methods": [{"name": "add_column", "signature": "add_column(col)", "summary": "Add a column."}], + }, + "operations": [{"command_pattern": "agent context", "description": "Bootstrap payload."}], + } + result = format_context_text(data) + + assert "## Builder" in result + assert "DataDesignerConfigBuilder:" in result + assert "add_column(col)" in result + + +# --- format_types_text --- + + +def test_format_types_text_single_family() -> None: + data: dict[str, Any] = { + "family": "columns", + "items": [ + {"type_name": "alpha", "class_name": "AlphaConfig", "import_path": "mod.AlphaConfig"}, + {"type_name": "beta", "class_name": "BetaConfig", "import_path": "mod.BetaConfig"}, + ], + } + result = format_types_text(data) + + assert "# columns" in result + assert "alpha" in result + assert "AlphaConfig" in result + + +def test_format_types_text_all_families() -> None: + data: dict[str, Any] = { + "families": [{"family": "columns", "count": 2}], + "items": { + "columns": [ + {"type_name": "a", "class_name": "A", "import_path": "m.A"}, + {"type_name": "b", "class_name": "B", "import_path": "m.B"}, + ], + }, + } + result = format_types_text(data) + + assert "columns: 2 types" in result + assert "a" in result + assert "b" in result + + +def test_format_types_text_empty_items() -> None: + data: dict[str, Any] = {"family": "columns", "items": []} + result = format_types_text(data) + + assert "(no items)" in result + + +# --- format_schema_text --- + + +def test_format_schema_text_single_type() -> None: + data: dict[str, Any] = { + "type_name": "llm-text", + "class_name": "LLMTextColumnConfig", + "schema_text": "LLMTextColumnConfig:\n name: str [required]", + } + result = format_schema_text(data) + + assert "LLMTextColumnConfig:" in result + assert "name: str [required]" in result + + +def test_format_schema_text_all_types() -> None: + data: dict[str, Any] = { + "family": "columns", + "items": [ + {"type_name": "a", "class_name": "A", "schema_text": "A:\n x: int [required]"}, + {"type_name": "b", "class_name": "B", "schema_text": "B:\n y: str = 'hi'"}, + ], + } + result = format_schema_text(data) + + assert "# columns schemas (2 types)" in result + assert "A:\n x: int [required]" in result + assert "B:\n y: str = 'hi'" in result + + +def test_format_schema_text_passes_through_schema_text() -> None: + schema_text = "TestModel:\n name: str [required]\n count: int = 0" + data: dict[str, Any] = {"type_name": "test", "class_name": "TestModel", "schema_text": schema_text} + result = format_schema_text(data) + + assert result == schema_text + + +# --- format_builder_text --- + + +def test_format_builder_text_renders_methods() -> None: + data: dict[str, Any] = { + "class_name": "MyBuilder", + "import_path": "data_designer.config.MyBuilder", + "methods": [ + {"name": "add_column", "signature": "add_column(column: ColumnConfig)", "summary": "Add a column."}, + {"name": "build", "signature": "build()", "summary": "Build the config."}, + ], + } + result = format_builder_text(data) + + assert "MyBuilder:" in result + assert "usage: dd.MyBuilder" in result + assert "add_column(column: ColumnConfig)" in result + assert "Add a column." in result + + +def test_format_builder_text_handles_method_without_summary() -> None: + data: dict[str, Any] = { + "class_name": "Builder", + "import_path": "mod.Builder", + "methods": [{"name": "reset", "signature": "reset()", "summary": None}], + } + result = format_builder_text(data) + + assert "reset()" in result + + +# --- format_model_aliases_text --- + + +def test_format_model_aliases_text_with_items() -> None: + state: dict[str, Any] = { + "default_provider": "nvidia", + "items": [ + { + "model_alias": "test", + "model": "meta/llama-3", + "generation_type": "chat", + "effective_provider": "nvidia", + "usable": True, + "reason": None, + }, + ], + } + result = format_model_aliases_text(state) + + assert "default_provider: nvidia" in result + assert "test" in result + assert "meta/llama-3" in result + + +def test_format_model_aliases_text_empty() -> None: + state: dict[str, Any] = {"default_provider": None, "items": []} + result = format_model_aliases_text(state) + + assert "default_provider: (none)" in result + assert "(no items)" in result + + +# --- format_persona_datasets_text --- + + +def test_format_persona_datasets_text() -> None: + state: dict[str, Any] = { + "items": [{"locale": "en_US", "size": "10MB", "installed": True}], + } + result = format_persona_datasets_text(state) + + assert "# persona datasets" in result + assert "en_US" in result + assert "True" in result + + +# --- Real config models --- + + +@pytest.mark.parametrize( + "family,type_name", + [ + ("columns", "llm-text"), + ("columns", "sampler"), + ("samplers", "category"), + ("validators", "code"), + ("constraints", "scalar_inequality"), + ], + ids=["columns-llm-text", "columns-sampler", "samplers-category", "validators-code", "constraints-scalar"], +) +def test_format_schema_text_on_real_config_models(family: str, type_name: str) -> None: + schema_data = get_family_schema(family, type_name) + result = format_schema_text(schema_data) + + assert schema_data["class_name"] in result + assert result == schema_data["schema_text"]