feat: agent CLI introspection (simplified) (#415)

* feat: add agent introspection cli

* refactor: remove agent cli schema version

* refactor: omit missing builder docstrings from context

* refactor: tighten agent cli contract

* feat: add schema_text() to ConfigBase for human-readable field summaries

ConfigBase.schema_text() returns a concise text representation including
the class docstring summary, field names, types, defaults, and
descriptions. Field descriptions added to column config types to
surface through this method.

* refactor: flatten agent CLI into plain functions with text output mode

Delete AgentController class and agent_command_defs module. Move all
logic into agent_introspection (data) and agent_text_formatter (display)
as plain functions. Add --json flag so commands default to human-readable
text using schema_text(), with JSON as opt-in. Unify _emit helper,
remove include_docstrings parameter, deduplicate catalog calls, and fix
N+1 discover_family_types in get_family_schemas.

* fix: port stale controller tests and consolidate command descriptions

Port test_agent_controller.py to use plain functions instead of deleted
AgentController. Extract AGENT_COMMANDS constant as single source for
operation descriptions, syncing with main.py help strings.

* style: fix ruff formatting in agent_introspection

* refactor: centralize agent command definitions

Extract AGENT_COMMANDS into agent_command_defs.py so main.py and
agent_introspection.py share a single source for command names,
help text, and metadata. The new module has no heavy dependencies,
keeping --help latency unaffected.

* fix: handle default_factory and empty providers in schema_text and introspection

- schema_text() now detects default_factory fields and renders e.g. "list()"
  instead of leaking PydanticUndefined
- Guard against IndexError when provider registry has an empty providers list
- Add 15 edge-case tests for schema_text covering default_factory, enum
  defaults, None defaults, scalar defaults, descriptions, and docstrings

* refactor: remove JSON output mode from agent CLI commands

Text-only output simplifies the interface. Structured output can be
added back trivially since the functions already return dicts.

* docs: update schema_text docstring to reflect agent focus

* fix: include builder section and import_path in agent text output

- format_context_text now renders a ## Builder section
- format_types_text now includes import_path column in tables

* refactor: drop import_path from types tables

All config objects are imported via dd.<ClassName>, so the full import
path is redundant noise in agent output.

* docs: add family definition and import hint to context output

* refactor: rename Types section to Families, drop redundant "types" from sub-headers

* fix: coerce None to empty string in table cells

row.get(col, '') returns None when the key exists with value None,
causing str(None) to render "None" in the output. Use `or ''` instead.

* refactor: move agent controller tests to utils as introspection integration tests

There is no controller layer — these tests exercise functions in
agent_introspection.py, so they belong in tests/cli/utils/.

* fix: only coerce None to empty string in table cells, not False

The previous `or ''` pattern treated all falsy values (including False)
as empty. Use an explicit None check so booleans render correctly.

* style: address review nits from nabin

- Add explicit parentheses to and/or precedence in _build_agent_lazy_group
- Rename loop variable l to line in test_schema_text
- Move get_family_schema import to module level in test_agent_text_formatter

* fix: improve schema_text Literal display, builder signature quotes, and docstring parsing

- _format_annotation now renders Literal['value'] instead of bare Literal
- _format_signature strips quotes from stringified annotations caused by
  `from __future__ import annotations`
- _get_docstring_summary stops at any Google-style section header, not
  just Attributes:
This commit is contained in:
Johnny Greco 2026-03-13 18:26:00 -04:00 committed by GitHub
parent 02744d152d
commit 4c19dba74b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1506 additions and 24 deletions

View file

@ -6,7 +6,10 @@
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Literal, get_args, get_origin
from pydantic import BaseModel, ConfigDict, Field
@ -20,6 +23,31 @@ class ConfigBase(BaseModel):
json_schema_mode_override="validation",
)
@classmethod
def schema_text(cls) -> str:
"""Return an agent-friendly text summary of the model's fields and defaults."""
lines: list[str] = [f"{cls.__name__}:"]
docstring = _get_docstring_summary(cls.__doc__)
if docstring:
lines.append(f" {docstring}")
lines.append("")
for name, field_info in cls.model_fields.items():
annotation = _format_annotation(field_info.annotation)
if field_info.is_required():
lines.append(f" {name}: {annotation} [required]")
else:
if field_info.default_factory is not None:
factory_name = getattr(field_info.default_factory, "__name__", repr(field_info.default_factory))
lines.append(f" {name}: {annotation} = {factory_name}()")
else:
default = field_info.default
if isinstance(default, Enum):
default = default.value
lines.append(f" {name}: {annotation} = {default!r}")
if field_info.description:
lines.append(f" {field_info.description}")
return "\n".join(lines)
class SingleColumnConfig(ConfigBase, ABC):
"""Abstract base class for all single-column configuration types.
@ -83,3 +111,52 @@ class ProcessorConfig(ConfigBase, ABC):
description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
)
processor_type: str
def _format_annotation(annotation: Any) -> str:
"""Convert a type annotation to a readable string, stripping module paths."""
if get_origin(annotation) is Literal:
args = get_args(annotation)
if args:
values = ", ".join(repr(a.value) if isinstance(a, Enum) else repr(a) for a in args)
return f"Literal[{values}]"
raw = str(annotation) if not hasattr(annotation, "__name__") else annotation.__name__
return re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], raw)
_GOOGLE_SECTION_HEADERS = frozenset(
{
"args:",
"arguments:",
"attributes:",
"example:",
"examples:",
"keyword args:",
"keyword arguments:",
"note:",
"notes:",
"raises:",
"references:",
"returns:",
"see also:",
"todo:",
"warns:",
"yields:",
}
)
def _get_docstring_summary(docstring: str | None) -> str | None:
"""Extract the first paragraph of a docstring, before any Google-style section header."""
if not docstring:
return None
lines: list[str] = []
for line in docstring.strip().splitlines():
stripped = line.strip()
if stripped.lower() in _GOOGLE_SECTION_HEADERS:
break
if not stripped and lines:
break
if stripped:
lines.append(stripped)
return " ".join(lines) if lines else None

View file

@ -56,10 +56,19 @@ class SamplerColumnConfig(SingleColumnConfig):
```
"""
sampler_type: SamplerType
params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
convert_to: str | None = None
sampler_type: SamplerType = Field(
description="Type of sampler to use (e.g., uuid, category, uniform, gaussian, person, datetime)"
)
params: Annotated[SamplerParamsT, Discriminator("sampler_type")] = Field(
description="Parameters specific to the chosen sampler type"
)
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = Field(
default_factory=dict,
description="Optional dictionary for conditional parameters; keys are conditions, values are params to use when met",
)
convert_to: str | None = Field(
default=None, description="Optional type conversion after sampling: 'float', 'int', or 'str'"
)
column_type: Literal["sampler"] = "sampler"
@staticmethod
@ -136,13 +145,25 @@ class LLMTextColumnConfig(SingleColumnConfig):
column_type: Discriminator field, always "llm-text" for this configuration type.
"""
prompt: str
model_alias: str
system_prompt: str | None = None
multi_modal_context: list[ImageContext] | None = None
tool_alias: str | None = None
with_trace: TraceType = TraceType.NONE
extract_reasoning_content: bool = False
prompt: str = Field(
description="Jinja2 template for the LLM prompt; can reference other columns via {{ column_name }}"
)
model_alias: str = Field(description="Alias of the model configuration to use for generation")
system_prompt: str | None = Field(
default=None, description="Optional system prompt to set model behavior and constraints"
)
multi_modal_context: list[ImageContext] | None = Field(
default=None, description="Optional list of ImageContext for vision model inputs"
)
tool_alias: str | None = Field(
default=None, description="Optional alias of the tool configuration to use for MCP tool calls"
)
with_trace: TraceType = Field(
default=TraceType.NONE, description="Trace capture mode: NONE, LAST_MESSAGE, or ALL_MESSAGES"
)
extract_reasoning_content: bool = Field(
default=False, description="If True, capture chain-of-thought in {name}__reasoning_content column"
)
column_type: Literal["llm-text"] = "llm-text"
@staticmethod
@ -219,7 +240,9 @@ class LLMCodeColumnConfig(LLMTextColumnConfig):
column containing the reasoning content from the final assistant response.
"""
code_lang: CodeLang
code_lang: CodeLang = Field(
description="Target programming language or SQL dialect for code extraction from LLM response"
)
column_type: Literal["llm-code"] = "llm-code"
@staticmethod
@ -252,7 +275,9 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
column containing the reasoning content from the final assistant response.
"""
output_format: dict | type[BaseModel]
output_format: dict | type[BaseModel] = Field(
description="Pydantic model or JSON schema dict defining the expected structured output shape"
)
column_type: Literal["llm-structured"] = "llm-structured"
@staticmethod
@ -317,7 +342,9 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig):
column containing the reasoning content from the final assistant response.
"""
scores: list[Score] = Field(..., min_length=1)
scores: list[Score] = Field(
..., min_length=1, description="List of Score objects defining rubric criteria for LLM judge evaluation"
)
column_type: Literal["llm-judge"] = "llm-judge"
@staticmethod
@ -342,8 +369,10 @@ class ExpressionColumnConfig(SingleColumnConfig):
"""
name: str
expr: str
dtype: Literal["int", "float", "str", "bool"] = "str"
expr: str = Field(description="Jinja2 expression to compute the column value from other columns")
dtype: Literal["int", "float", "str", "bool"] = Field(
default="str", description="Data type for expression result: 'int', 'float', 'str', or 'bool'"
)
column_type: Literal["expression"] = "expression"
@staticmethod
@ -410,9 +439,11 @@ class ValidationColumnConfig(SingleColumnConfig):
column_type: Discriminator field, always "validation" for this configuration type.
"""
target_columns: list[str]
validator_type: ValidatorType
validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")]
target_columns: list[str] = Field(description="List of column names to validate")
validator_type: ValidatorType = Field(description="Validation method: 'code', 'local_callable', or 'remote'")
validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")] = Field(
description="Validator-specific parameters (e.g., CodeValidatorParams)"
)
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
column_type: Literal["validation"] = "validation"
@ -479,8 +510,8 @@ class EmbeddingColumnConfig(SingleColumnConfig):
column_type: Discriminator field, always "embedding" for this configuration type.
"""
target_column: str
model_alias: str
target_column: str = Field(description="Name of the text column to generate embeddings for")
model_alias: str = Field(description="Alias of the model to use for embedding generation")
column_type: Literal["embedding"] = "embedding"
@staticmethod
@ -513,9 +544,13 @@ class ImageColumnConfig(SingleColumnConfig):
column_type: Discriminator field, always "image" for this configuration type.
"""
prompt: str
model_alias: str
multi_modal_context: list[ImageContext] | None = None
prompt: str = Field(
description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}"
)
model_alias: str = Field(description="Alias of the model to use for image generation")
multi_modal_context: list[ImageContext] | None = Field(
default=None, description="Optional list of ImageContext for multi-modal image-to-image generation"
)
column_type: Literal["image"] = "image"
@staticmethod

View file

@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from enum import Enum
from typing import Literal
import pytest
from pydantic import Field
from data_designer.config.base import ConfigBase
class Color(Enum):
RED = "red"
GREEN = "green"
class RequiredOnlyModel(ConfigBase):
name: str
count: int
class DefaultFactoryModel(ConfigBase):
tags: list[str] = Field(default_factory=list)
metadata: dict[str, int] = Field(default_factory=dict)
class DefaultFactoryWithDescriptionModel(ConfigBase):
items: list[str] = Field(default_factory=list, description="Collection of items")
class EnumDefaultModel(ConfigBase):
color: Color = Color.RED
class NoneDefaultModel(ConfigBase):
label: str | None = None
class ScalarDefaultModel(ConfigBase):
threshold: float = 0.5
enabled: bool = True
tag: str = "default"
class DescribedFieldsModel(ConfigBase):
name: str = Field(description="The name of the thing")
count: int = Field(default=0, description="How many things")
class DocstringModel(ConfigBase):
"""A model with a docstring summary."""
value: int
class NoDocstringModel(ConfigBase):
value: int
class LiteralModel(ConfigBase):
tag: Literal["fixed"] = "fixed"
name: str
class MixedModel(ConfigBase):
"""Model exercising every field variant."""
required_field: str
optional_none: int | None = None
with_default: float = 3.14
enum_field: Color = Color.GREEN
factory_list: list[str] = Field(default_factory=list)
factory_dict: dict[str, int] = Field(default_factory=dict, description="A mapping")
# --- Required fields ---
def test_required_fields_marked() -> None:
text = RequiredOnlyModel.schema_text()
assert "name: str [required]" in text
assert "count: int [required]" in text
# --- default_factory fields ---
def test_default_factory_list_shows_factory_call() -> None:
text = DefaultFactoryModel.schema_text()
assert "tags:" in text
assert "= list()" in text
def test_default_factory_dict_shows_factory_call() -> None:
text = DefaultFactoryModel.schema_text()
assert "metadata:" in text
assert "= dict()" in text
def test_default_factory_does_not_show_pydantic_undefined() -> None:
text = DefaultFactoryModel.schema_text()
assert "PydanticUndefined" not in text
def test_default_factory_with_description() -> None:
text = DefaultFactoryWithDescriptionModel.schema_text()
assert "items:" in text
assert "= list()" in text
assert "Collection of items" in text
# --- Enum defaults ---
def test_enum_default_shows_value() -> None:
text = EnumDefaultModel.schema_text()
assert "color: Color = 'red'" in text
assert "Color.RED" not in text
# --- None defaults ---
def test_none_default() -> None:
text = NoneDefaultModel.schema_text()
assert "label: str | None = None" in text
# --- Scalar defaults ---
@pytest.mark.parametrize(
("field_name", "expected"),
[
("threshold", "threshold: float = 0.5"),
("enabled", "enabled: bool = True"),
("tag", "tag: str = 'default'"),
],
ids=["float", "bool", "str"],
)
def test_scalar_defaults(field_name: str, expected: str) -> None:
text = ScalarDefaultModel.schema_text()
assert expected in text
# --- Field descriptions ---
def test_description_appears_below_field() -> None:
text = DescribedFieldsModel.schema_text()
lines = text.splitlines()
name_idx = next(i for i, line in enumerate(lines) if "name: str" in line)
assert "The name of the thing" in lines[name_idx + 1]
# --- Docstrings ---
def test_docstring_included() -> None:
text = DocstringModel.schema_text()
assert "A model with a docstring summary." in text
def test_no_docstring_still_works() -> None:
text = NoDocstringModel.schema_text()
assert text.startswith("NoDocstringModel:")
assert "value: int [required]" in text
# --- Header format ---
def test_header_is_class_name() -> None:
text = RequiredOnlyModel.schema_text()
assert text.startswith("RequiredOnlyModel:")
# --- Single-value Literal fields ---
def test_literal_field_includes_value_and_default() -> None:
text = LiteralModel.schema_text()
assert "tag: Literal['fixed'] = 'fixed'" in text
# --- Mixed model exercises all variants together ---
def test_mixed_model_all_variants() -> None:
text = MixedModel.schema_text()
assert "Model exercising every field variant." in text
assert "required_field: str [required]" in text
assert "optional_none: int | None = None" in text
assert "with_default: float = 3.14" in text
assert "enum_field: Color = 'green'" in text
assert "factory_list:" in text
assert "= list()" in text
assert "factory_dict:" in text
assert "= dict()" in text
assert "A mapping" in text
assert "PydanticUndefined" not in text

View file

@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class AgentCommandDef:
name: str
attr: str
help: str
command_pattern: str
returns: str
AGENT_COMMANDS: tuple[AgentCommandDef, ...] = (
AgentCommandDef(
name="context",
attr="context_command",
help="Bootstrap payload with types, state, and builder.",
command_pattern="data-designer agent context",
returns="agent_context",
),
AgentCommandDef(
name="types",
attr="types_command",
help="Type names and import paths for one or all families.",
command_pattern="data-designer agent types [family]",
returns="agent_types",
),
AgentCommandDef(
name="schema",
attr="schema_command",
help="Schema for a type or entire family.",
command_pattern="data-designer agent schema <family> <type> | --all",
returns="agent_schema",
),
AgentCommandDef(
name="builder",
attr="builder_command",
help="ConfigBuilder method surface with signatures.",
command_pattern="data-designer agent builder",
returns="agent_builder",
),
AgentCommandDef(
name="state.model-aliases",
attr="state_model_aliases_command",
help="Model aliases and usability status.",
command_pattern="data-designer agent state model-aliases",
returns="agent_state_model_aliases",
),
AgentCommandDef(
name="state.persona-datasets",
attr="state_persona_datasets_command",
help="Persona locales and install status.",
command_pattern="data-designer agent state persona-datasets",
returns="agent_state_persona_datasets",
),
)

View file

@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections.abc import Callable
from typing import Any
import typer
from data_designer.cli.utils.agent_introspection import (
AgentIntrospectionError,
get_builder_api,
get_context,
get_model_aliases_state,
get_persona_datasets_state,
get_schema,
get_types,
)
from data_designer.cli.utils.agent_text_formatter import (
format_builder_text,
format_context_text,
format_model_aliases_text,
format_persona_datasets_text,
format_schema_text,
format_types_text,
)
from data_designer.config.utils.constants import DATA_DESIGNER_HOME
def context_command() -> None:
"""Return a bootstrap payload with types, local state, and builder summary."""
_run(lambda: get_context(DATA_DESIGNER_HOME), format_context_text)
def types_command(
family: str | None = typer.Argument(None, help="Optional schema family name."),
) -> None:
"""Return available type names and import paths for one family or all families."""
_run(lambda: get_types(family), format_types_text)
def schema_command(
family: str = typer.Argument(..., help="Schema family name."),
type_name: str | None = typer.Argument(None, help="Type name within the selected family."),
all_types: bool = typer.Option(False, "--all", help="Return every schema in the selected family."),
) -> None:
"""Return schema for a specific type or every type in a family."""
_run(lambda: get_schema(family, type_name, all_types=all_types), format_schema_text)
def builder_command() -> None:
"""Return the DataDesignerConfigBuilder method surface with signatures and docstrings."""
_run(get_builder_api, format_builder_text)
def state_model_aliases_command() -> None:
"""Return configured model aliases and whether each one is currently usable."""
_run(lambda: get_model_aliases_state(DATA_DESIGNER_HOME), format_model_aliases_text)
def state_persona_datasets_command() -> None:
"""Return built-in persona locales and whether each dataset is installed locally."""
_run(lambda: get_persona_datasets_state(DATA_DESIGNER_HOME), format_persona_datasets_text)
def _run(get_data: Callable[[], Any], format_text: Callable[[Any], str]) -> None:
try:
data = get_data()
typer.echo(format_text(data))
except AgentIntrospectionError as exc:
typer.echo(f"Error [{exc.code}]: {exc.message}", err=True)
raise typer.Exit(code=1)
except Exception as exc:
typer.echo(f"Error [internal_error]: {exc}", err=True)
raise typer.Exit(code=1)

View file

@ -5,6 +5,7 @@ from __future__ import annotations
import typer
from data_designer.cli.agent_command_defs import AGENT_COMMANDS
from data_designer.cli.lazy_group import create_lazy_typer_group
from data_designer.cli.runtime import ensure_cli_default_model_settings
@ -99,9 +100,37 @@ download_app = typer.Typer(
no_args_is_help=True,
)
_AGENT_CMD = f"{_CMD}.agent"
def _build_agent_lazy_group(prefix: str) -> dict[str, dict[str, str]]:
return {
cmd.name.removeprefix(f"{prefix}."): {"module": _AGENT_CMD, "attr": cmd.attr, "help": cmd.help}
for cmd in AGENT_COMMANDS
if (prefix == "" and "." not in cmd.name) or cmd.name.startswith(f"{prefix}.")
}
agent_app = typer.Typer(
name="agent",
help="Agent-only interface for dynamic Data Designer introspection",
cls=create_lazy_typer_group(_build_agent_lazy_group("")),
no_args_is_help=True,
)
agent_state_app = typer.Typer(
name="state",
help="Return current local state relevant to agents",
cls=create_lazy_typer_group(_build_agent_lazy_group("state")),
no_args_is_help=True,
)
agent_app.add_typer(agent_state_app, name="state")
# Add setup command groups
app.add_typer(config_app, name="config", rich_help_panel="Setup")
app.add_typer(download_app, name="download", rich_help_panel="Setup")
app.add_typer(agent_app, name="agent", rich_help_panel="Agent")
def main() -> None:

View file

@ -0,0 +1,331 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import importlib.metadata
import inspect
import re
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Literal, get_args, get_origin
import data_designer.config as dd
from data_designer.cli.agent_command_defs import AGENT_COMMANDS
from data_designer.cli.repositories.model_repository import ModelRepository
from data_designer.cli.repositories.persona_repository import PersonaRepository
from data_designer.cli.repositories.provider_repository import ProviderRepository
from data_designer.cli.services.download_service import DownloadService
from data_designer.config.column_types import ColumnConfigT
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.default_model_settings import get_providers_with_missing_api_keys
from data_designer.config.processor_types import ProcessorConfigT
from data_designer.config.sampler_constraints import ColumnConstraintT
from data_designer.config.sampler_params import SamplerParamsT
from data_designer.config.validator_params import ValidatorParamsT
@dataclass(frozen=True)
class FamilySpec:
name: str
type_union: Any
discriminator_field: str
class AgentIntrospectionError(Exception):
def __init__(self, code: str, message: str, details: dict[str, Any] | None = None) -> None:
super().__init__(message)
self.code = code
self.message = message
self.details = details or {}
_FAMILY_SPECS: dict[str, FamilySpec] = {
"columns": FamilySpec(name="columns", type_union=ColumnConfigT, discriminator_field="column_type"),
"samplers": FamilySpec(name="samplers", type_union=SamplerParamsT, discriminator_field="sampler_type"),
"validators": FamilySpec(name="validators", type_union=ValidatorParamsT, discriminator_field="validator_type"),
"processors": FamilySpec(name="processors", type_union=ProcessorConfigT, discriminator_field="processor_type"),
"constraints": FamilySpec(name="constraints", type_union=ColumnConstraintT, discriminator_field="constraint_type"),
}
def get_family_names() -> list[str]:
return sorted(_FAMILY_SPECS)
def get_library_version() -> str:
try:
return importlib.metadata.version("data-designer")
except importlib.metadata.PackageNotFoundError:
return "unknown"
def get_family_spec(family: str) -> FamilySpec:
spec = _FAMILY_SPECS.get(_normalize_family_name(family))
if spec is None:
raise AgentIntrospectionError(
code="unknown_family",
message=f"Unknown family {family!r}.",
details={"available_families": get_family_names()},
)
return spec
def discover_family_types(family: str) -> dict[str, type]:
spec = get_family_spec(family)
discovered: dict[str, type] = {}
for model in get_args(spec.type_union):
type_name = _extract_literal_value(model.model_fields[spec.discriminator_field].annotation)
if type_name in discovered and discovered[type_name] is not model:
raise AgentIntrospectionError(
code="duplicate_discriminator_value",
message=f"Duplicate discriminator {type_name!r} in family {family!r}.",
details={"family": family, "type_name": type_name},
)
discovered[type_name] = model
return dict(sorted(discovered.items()))
def get_import_path(cls: type) -> str:
exported = getattr(dd, cls.__name__, None)
if exported is cls:
return f"data_designer.config.{cls.__name__}"
return f"{cls.__module__}.{cls.__qualname__}"
def get_family_catalog(family: str) -> list[dict[str, str]]:
return [
{"type_name": type_name, "class_name": cls.__name__, "import_path": get_import_path(cls)}
for type_name, cls in discover_family_types(family).items()
]
def get_family_schema(family: str, type_name: str) -> dict[str, Any]:
types_map = discover_family_types(family)
cls = types_map.get(type_name)
if cls is None:
raise AgentIntrospectionError(
code="unknown_type",
message=f"Unknown type {type_name!r} for family {family!r}.",
details={"family": family, "available_types": list(types_map)},
)
return _build_schema_dict(get_family_spec(family).name, type_name, cls)
def get_family_schemas(family: str) -> dict[str, Any]:
spec = get_family_spec(family)
types_map = discover_family_types(family)
items = [_build_schema_dict(spec.name, tn, cls) for tn, cls in types_map.items()]
return {"family": spec.name, "items": items}
def get_builder_api() -> dict[str, Any]:
return {
"class_name": DataDesignerConfigBuilder.__name__,
"import_path": get_import_path(DataDesignerConfigBuilder),
"methods": _get_builder_methods(),
}
def get_operations() -> list[dict[str, str]]:
return [
{"name": c.name, "command_pattern": c.command_pattern, "description": c.help, "returns": c.returns}
for c in AGENT_COMMANDS
]
def get_context(config_dir: Path) -> dict[str, Any]:
catalogs = {f: get_family_catalog(f) for f in get_family_names()}
return {
"operations": get_operations(),
"families": [{"family": f, "count": len(items)} for f, items in catalogs.items()],
"types": catalogs,
"state": {
"model_aliases": get_model_aliases_state(config_dir),
"persona_datasets": get_persona_datasets_state(config_dir),
},
"builder": get_builder_api(),
}
def get_types(family: str | None) -> dict[str, Any]:
if family is None:
catalogs = {f: get_family_catalog(f) for f in get_family_names()}
return {
"families": [{"family": f, "count": len(items)} for f, items in catalogs.items()],
"items": catalogs,
}
return {"family": get_family_spec(family).name, "items": get_family_catalog(family)}
def get_schema(family: str, type_name: str | None, *, all_types: bool) -> dict[str, Any]:
if all_types and type_name is not None:
raise AgentIntrospectionError(
code="invalid_schema_request",
message="Provide either a type name or --all, but not both.",
details={"family": family, "type_name": type_name, "all": all_types},
)
if all_types:
return get_family_schemas(family)
if type_name is None:
raise AgentIntrospectionError(
code="missing_type_name",
message="A type name is required unless --all is provided.",
details={"family": family},
)
return get_family_schema(family, type_name)
def get_model_aliases_state(config_dir: Path) -> dict[str, Any]:
model_registry = _load_registry(ModelRepository(config_dir))
provider_registry = _load_registry(ProviderRepository(config_dir))
items: list[dict[str, Any]] = []
if model_registry is None:
return {
"model_config_present": False,
"provider_config_present": provider_registry is not None,
"default_provider": None if provider_registry is None else provider_registry.default,
"items": items,
}
providers_by_name: dict[str, Any] = {}
missing_key_names: set[str] = set()
default_provider: str | None = None
if provider_registry is not None:
providers_by_name = {p.name: p for p in provider_registry.providers}
default_provider = provider_registry.default or (
provider_registry.providers[0].name if provider_registry.providers else None
)
missing_key_names = {p.name for p in get_providers_with_missing_api_keys(provider_registry.providers)}
for mc in sorted(model_registry.model_configs, key=lambda m: m.alias):
effective = mc.provider or default_provider
usable = True
reason: str | None = None
if effective is None:
usable, reason = False, "No model provider is configured."
elif effective not in providers_by_name:
usable, reason = False, f"Provider {effective!r} is not configured."
elif effective in missing_key_names:
usable, reason = False, f"Provider {effective!r} is missing an API key."
items.append(
{
"model_alias": mc.alias,
"model": mc.model,
"generation_type": getattr(mc.generation_type, "value", str(mc.generation_type)),
"configured_provider": mc.provider,
"effective_provider": effective,
"usable": usable,
"reason": reason,
}
)
return {
"model_config_present": True,
"provider_config_present": provider_registry is not None,
"default_provider": default_provider,
"items": items,
}
def get_persona_datasets_state(config_dir: Path) -> dict[str, Any]:
persona_repo = PersonaRepository()
download_svc = DownloadService(config_dir, persona_repo)
return {
"managed_assets_directory": str(download_svc.get_managed_assets_directory()),
"items": [
{
"locale": loc.code,
"dataset_name": loc.dataset_name,
"size": loc.size,
"installed": download_svc.is_locale_downloaded(loc.code),
}
for loc in sorted(persona_repo.list_all(), key=lambda loc: loc.code)
],
}
def _build_schema_dict(family_name: str, type_name: str, cls: type) -> dict[str, Any]:
return {
"family": family_name,
"type_name": type_name,
"class_name": cls.__name__,
"import_path": get_import_path(cls),
"schema": cls.model_json_schema(),
"schema_text": cls.schema_text(),
}
def _normalize_family_name(family: str) -> str:
normalized = family.strip().lower()
if normalized in _FAMILY_SPECS:
return normalized
plural = f"{normalized}s"
if plural in _FAMILY_SPECS:
return plural
return normalized
def _extract_literal_value(annotation: Any) -> str:
if get_origin(annotation) is not Literal or not get_args(annotation):
raise AgentIntrospectionError(
code="invalid_discriminator_annotation",
message=f"Expected non-empty Literal annotation, got {annotation!r}.",
)
value = get_args(annotation)[0]
return str(value.value) if isinstance(value, Enum) else str(value)
def _get_builder_methods() -> list[dict[str, Any]]:
methods: list[dict[str, Any]] = []
for name, attr in inspect.getmembers(DataDesignerConfigBuilder):
if name.startswith("_") and name != "__init__":
continue
if not callable(attr):
continue
try:
sig = inspect.signature(attr)
except (TypeError, ValueError):
continue
docstring = inspect.getdoc(attr)
methods.append(
{
"name": name,
"signature": _format_signature(name, sig),
"summary": _get_first_line(docstring),
"docstring": docstring,
}
)
return methods
def _format_signature(method_name: str, sig: inspect.Signature) -> str:
params = [p for p in sig.parameters.values() if p.name not in {"self", "cls"}]
sig_str = str(sig.replace(parameters=params))
sig_str = re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], sig_str)
# Strip quotes left by `from __future__ import annotations` around type annotations.
sig_str = re.sub(r"(?<=: )'([^']+)'", r"\1", sig_str)
sig_str = re.sub(r"(?<=-> )'([^']+)'", r"\1", sig_str)
return f"{method_name}{sig_str}"
def _get_first_line(text: str | None) -> str | None:
return next((line.strip() for line in text.strip().splitlines() if line.strip()), None) if text else None
def _load_registry(repo: Any) -> Any:
if not repo.exists():
return None
registry = repo.load()
if registry is None:
raise AgentIntrospectionError(
code="invalid_registry",
message=f"Failed to load registry from {str(repo.config_file)!r}.",
details={"config_file": str(repo.config_file)},
)
return registry

View file

@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
from data_designer.cli.utils.agent_introspection import get_library_version
def format_context_text(data: dict[str, Any]) -> str:
"""Format the full context payload as sectioned text with tables."""
sections = [
f"Data Designer v{get_library_version()}",
"",
"import data_designer.config as dd",
"",
'A "family" is a group of related config types that share a discriminator field.',
"Use dd.<ClassName> to reference any type below.",
"",
"## Families",
"",
format_types_text({"families": data["families"], "items": data["types"]}),
"",
"## Model Aliases",
"",
format_model_aliases_text(data["state"]["model_aliases"]),
"",
"## Persona Datasets",
"",
format_persona_datasets_text(data["state"]["persona_datasets"]),
"",
"## Builder",
"",
format_builder_text(data["builder"]),
"",
"## Commands",
"",
_format_table(data["operations"], ["command_pattern", "description"]),
]
return "\n".join(sections)
def format_types_text(data: dict[str, Any]) -> str:
"""Format type listings for one family or all families."""
if "families" in data:
lines: list[str] = [f"{f['family']}: {f['count']} types" for f in data["families"]]
lines.append("")
for family_name, items in data["items"].items():
lines.append(_format_table(items, ["type_name", "class_name"], title=family_name))
lines.append("")
return "\n".join(lines).rstrip()
return _format_table(
data["items"],
["type_name", "class_name"],
title=data.get("family"),
)
def format_schema_text(data: dict[str, Any]) -> str:
"""Format schema data as human-readable field summaries."""
if "items" in data:
header = f"# {data['family']} schemas ({len(data['items'])} types)"
schemas = "\n\n".join(item["schema_text"] for item in data["items"])
return f"{header}\n\n{schemas}"
return data["schema_text"]
def format_builder_text(data: dict[str, Any]) -> str:
"""Format builder methods with signatures."""
path = data["import_path"]
hint = f"dd.{path.removeprefix('data_designer.config.')}" if path.startswith("data_designer.config.") else path
lines: list[str] = [
f"{data['class_name']}:",
f" usage: {hint}",
" methods:",
"",
]
for method in data["methods"]:
lines.append(f" {method['signature']}")
if method.get("summary"):
lines.append(f" {method['summary']}")
lines.append("")
return "\n".join(lines).rstrip()
def format_model_aliases_text(state: dict[str, Any]) -> str:
"""Format model aliases as a text table with provider summary."""
lines: list[str] = [f"default_provider: {state.get('default_provider') or '(none)'}", ""]
lines.append(
_format_table(
state.get("items", []),
["model_alias", "model", "generation_type", "effective_provider", "usable", "reason"],
column_labels={"effective_provider": "provider"},
title="model aliases",
)
)
return "\n".join(lines)
def format_persona_datasets_text(state: dict[str, Any]) -> str:
"""Format persona datasets as a text table."""
return _format_table(state.get("items", []), ["locale", "size", "installed"], title="persona datasets")
def _format_table(
items: list[dict[str, Any]],
columns: list[str],
*,
title: str | None = None,
column_labels: dict[str, str] | None = None,
) -> str:
labels = {col: (column_labels or {}).get(col, col) for col in columns}
if not items:
header = f"# {title}" if title else "# table"
return f"{header}\n(no items)"
col_widths = {col: max(len(labels[col]), max(len(_cell(row.get(col))) for row in items)) for col in columns}
lines: list[str] = []
if title:
lines.append(f"# {title}")
lines.append("")
lines.append(" ".join(f"{labels[col]:<{col_widths[col]}}" for col in columns))
lines.append(" ".join("-" * col_widths[col] for col in columns))
for row in items:
lines.append(" ".join(f"{_cell(row.get(col)):<{col_widths[col]}}" for col in columns))
return "\n".join(lines)
def _cell(value: Any) -> str:
"""Convert a cell value to a string, rendering None as empty."""
if value is None:
return ""
return str(value)

View file

@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from unittest.mock import patch
import pytest
from typer.testing import CliRunner
from data_designer.cli.main import app
_PATCH = "data_designer.cli.commands.agent"
@pytest.mark.parametrize(
"args,data_fn,format_fn,expected_text",
[
(["agent", "context"], "get_context", "format_context_text", "Data Designer"),
(["agent", "types", "columns"], "get_types", "format_types_text", "columns"),
(["agent", "builder"], "get_builder_api", "format_builder_text", "Builder:"),
(["agent", "state", "model-aliases"], "get_model_aliases_state", "format_model_aliases_text", "model aliases"),
(
["agent", "state", "persona-datasets"],
"get_persona_datasets_state",
"format_persona_datasets_text",
"persona",
),
],
ids=["context", "types", "builder", "model-aliases", "persona-datasets"],
)
def test_commands_default_text_mode(args: list[str], data_fn: str, format_fn: str, expected_text: str) -> None:
runner = CliRunner()
with (
patch(f"{_PATCH}.{data_fn}", return_value={"stub": True}) as mock_get,
patch(f"{_PATCH}.{format_fn}", return_value=expected_text),
):
result = runner.invoke(app, args)
assert result.exit_code == 0
assert expected_text in result.output
mock_get.assert_called_once()
def test_schema_command_default_outputs_text() -> None:
runner = CliRunner()
with (
patch(f"{_PATCH}.get_schema", return_value={"type_name": "llm-text", "schema": {}}),
patch(f"{_PATCH}.format_schema_text", return_value="# llm-text\n{}"),
):
result = runner.invoke(app, ["agent", "schema", "columns", "llm-text"])
assert result.exit_code == 0
assert "llm-text" in result.output
def test_error_outputs_message_to_stderr() -> None:
runner = CliRunner()
with patch(f"{_PATCH}.get_schema", side_effect=ValueError("boom")):
result = runner.invoke(app, ["agent", "schema", "columns", "missing"])
assert result.exit_code == 1
assert result.stdout == ""
assert "internal_error" in result.stderr
assert "boom" in result.stderr

View file

@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import pytest
from data_designer.cli.utils.agent_introspection import (
AgentIntrospectionError,
discover_family_types,
get_builder_api,
get_family_catalog,
get_family_schema,
get_family_spec,
get_operations,
get_types,
)
def test_get_family_catalog_accepts_singular_family_names() -> None:
assert get_family_catalog("validator") == get_family_catalog("validators")
def test_get_family_catalog_returns_sorted_type_names() -> None:
catalog = get_family_catalog("columns")
assert catalog
assert [item["type_name"] for item in catalog] == sorted(item["type_name"] for item in catalog)
def test_get_family_schema_returns_json_schema_payload() -> None:
schema_payload = get_family_schema("validator", "code")
assert schema_payload["family"] == "validators"
assert schema_payload["type_name"] == "code"
assert schema_payload["class_name"] == "CodeValidatorParams"
assert schema_payload["import_path"] == "data_designer.config.CodeValidatorParams"
assert schema_payload["schema"]["title"] == "CodeValidatorParams"
def test_get_family_schema_includes_schema_text() -> None:
schema_payload = get_family_schema("columns", "llm-text")
text = schema_payload["schema_text"]
assert text.startswith("LLMTextColumnConfig:")
assert "column_type:" in text
assert "name:" in text
assert "Configuration for text generation" in text
assert "Jinja2 template" in text
def test_get_family_schema_raises_for_unknown_type() -> None:
with pytest.raises(AgentIntrospectionError) as exc_info:
get_family_schema("validators", "does-not-exist")
assert exc_info.value.code == "unknown_type"
assert exc_info.value.details["family"] == "validators"
assert "code" in exc_info.value.details["available_types"]
def test_discover_family_types_returns_pydantic_classes() -> None:
types_map = discover_family_types("columns")
assert types_map
assert all(hasattr(cls, "model_fields") for cls in types_map.values())
def test_get_family_spec_returns_discriminator_field() -> None:
spec = get_family_spec("columns")
assert spec.name == "columns"
assert spec.discriminator_field == "column_type"
def test_get_builder_api_includes_docstrings() -> None:
builder_api = get_builder_api()
assert builder_api["class_name"] == "DataDesignerConfigBuilder"
assert builder_api["import_path"] == "data_designer.config.DataDesignerConfigBuilder"
assert builder_api["methods"]
assert all("docstring" in method for method in builder_api["methods"])
def test_get_types_returns_all_families_when_no_family_given() -> None:
data = get_types(None)
assert "families" in data
assert "items" in data
assert len(data["families"]) > 0
assert all(f["family"] in data["items"] for f in data["families"])
def test_get_types_returns_single_family() -> None:
data = get_types("columns")
assert data["family"] == "columns"
assert isinstance(data["items"], list)
assert len(data["items"]) > 0
def test_get_operations_returns_all_commands() -> None:
ops = get_operations()
assert len(ops) == 6
assert all("name" in op and "command_pattern" in op and "description" in op for op in ops)

View file

@ -0,0 +1,140 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from pathlib import Path
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
from data_designer.cli.utils.agent_introspection import (
get_context,
get_model_aliases_state,
get_persona_datasets_state,
)
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
def test_get_model_aliases_state_reports_provider_status(tmp_path: Path) -> None:
provider_repository = ProviderRepository(tmp_path)
provider_repository.save(
ModelProviderRegistry(
providers=[
ModelProvider(
name="provider-a",
endpoint="https://api.example.com/a",
provider_type="openai",
api_key="test-api-key",
),
ModelProvider(
name="provider-b",
endpoint="https://api.example.com/b",
provider_type="openai",
api_key="MISSING_PROVIDER_KEY",
),
],
default="provider-a",
)
)
model_repository = ModelRepository(tmp_path)
model_repository.save(
ModelConfigRegistry(
model_configs=[
ModelConfig(
alias="alpha",
model="model-alpha",
provider=None,
inference_parameters=ChatCompletionInferenceParams(),
),
ModelConfig(
alias="beta",
model="model-beta",
provider="provider-b",
inference_parameters=ChatCompletionInferenceParams(),
),
ModelConfig(
alias="gamma",
model="model-gamma",
provider="provider-missing",
inference_parameters=ChatCompletionInferenceParams(),
),
]
)
)
payload = get_model_aliases_state(tmp_path)
assert payload["model_config_present"] is True
assert payload["provider_config_present"] is True
assert payload["default_provider"] == "provider-a"
assert payload["items"] == [
{
"model_alias": "alpha",
"model": "model-alpha",
"generation_type": "chat-completion",
"configured_provider": None,
"effective_provider": "provider-a",
"usable": True,
"reason": None,
},
{
"model_alias": "beta",
"model": "model-beta",
"generation_type": "chat-completion",
"configured_provider": "provider-b",
"effective_provider": "provider-b",
"usable": False,
"reason": "Provider 'provider-b' is missing an API key.",
},
{
"model_alias": "gamma",
"model": "model-gamma",
"generation_type": "chat-completion",
"configured_provider": "provider-missing",
"effective_provider": "provider-missing",
"usable": False,
"reason": "Provider 'provider-missing' is not configured.",
},
]
def test_get_model_aliases_state_handles_missing_local_files(tmp_path: Path) -> None:
payload = get_model_aliases_state(tmp_path)
assert payload == {
"model_config_present": False,
"provider_config_present": False,
"default_provider": None,
"items": [],
}
def test_get_persona_datasets_state_reports_installed_locales(tmp_path: Path) -> None:
managed_assets_dir = tmp_path / "managed-assets" / "datasets"
managed_assets_dir.mkdir(parents=True)
(managed_assets_dir / "en_US.parquet").write_text("stub")
payload = get_persona_datasets_state(tmp_path)
assert payload["managed_assets_directory"] == str(managed_assets_dir)
installed_by_locale = {item["locale"]: item["installed"] for item in payload["items"]}
assert installed_by_locale["en_US"] is True
assert any(not item["installed"] for item in payload["items"] if item["locale"] != "en_US")
def test_get_context_returns_self_describing_payload(tmp_path: Path) -> None:
payload = get_context(tmp_path)
operation_names = [operation["name"] for operation in payload["operations"]]
assert operation_names == [
"context",
"types",
"schema",
"builder",
"state.model-aliases",
"state.persona-datasets",
]
assert payload["families"]
assert "columns" in payload["types"]
assert payload["builder"]["methods"]

View file

@ -0,0 +1,223 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
import pytest
from data_designer.cli.utils.agent_introspection import get_family_schema
from data_designer.cli.utils.agent_text_formatter import (
format_builder_text,
format_context_text,
format_model_aliases_text,
format_persona_datasets_text,
format_schema_text,
format_types_text,
)
# --- format_context_text ---
def test_format_context_text_includes_builder_section() -> None:
data: dict[str, Any] = {
"families": [{"family": "columns", "count": 1}],
"types": {
"columns": [{"type_name": "a", "class_name": "A", "import_path": "m.A"}],
},
"state": {
"model_aliases": {"default_provider": None, "items": []},
"persona_datasets": {"items": []},
},
"builder": {
"class_name": "DataDesignerConfigBuilder",
"import_path": "data_designer.config.DataDesignerConfigBuilder",
"methods": [{"name": "add_column", "signature": "add_column(col)", "summary": "Add a column."}],
},
"operations": [{"command_pattern": "agent context", "description": "Bootstrap payload."}],
}
result = format_context_text(data)
assert "## Builder" in result
assert "DataDesignerConfigBuilder:" in result
assert "add_column(col)" in result
# --- format_types_text ---
def test_format_types_text_single_family() -> None:
data: dict[str, Any] = {
"family": "columns",
"items": [
{"type_name": "alpha", "class_name": "AlphaConfig", "import_path": "mod.AlphaConfig"},
{"type_name": "beta", "class_name": "BetaConfig", "import_path": "mod.BetaConfig"},
],
}
result = format_types_text(data)
assert "# columns" in result
assert "alpha" in result
assert "AlphaConfig" in result
def test_format_types_text_all_families() -> None:
data: dict[str, Any] = {
"families": [{"family": "columns", "count": 2}],
"items": {
"columns": [
{"type_name": "a", "class_name": "A", "import_path": "m.A"},
{"type_name": "b", "class_name": "B", "import_path": "m.B"},
],
},
}
result = format_types_text(data)
assert "columns: 2 types" in result
assert "a" in result
assert "b" in result
def test_format_types_text_empty_items() -> None:
data: dict[str, Any] = {"family": "columns", "items": []}
result = format_types_text(data)
assert "(no items)" in result
# --- format_schema_text ---
def test_format_schema_text_single_type() -> None:
data: dict[str, Any] = {
"type_name": "llm-text",
"class_name": "LLMTextColumnConfig",
"schema_text": "LLMTextColumnConfig:\n name: str [required]",
}
result = format_schema_text(data)
assert "LLMTextColumnConfig:" in result
assert "name: str [required]" in result
def test_format_schema_text_all_types() -> None:
data: dict[str, Any] = {
"family": "columns",
"items": [
{"type_name": "a", "class_name": "A", "schema_text": "A:\n x: int [required]"},
{"type_name": "b", "class_name": "B", "schema_text": "B:\n y: str = 'hi'"},
],
}
result = format_schema_text(data)
assert "# columns schemas (2 types)" in result
assert "A:\n x: int [required]" in result
assert "B:\n y: str = 'hi'" in result
def test_format_schema_text_passes_through_schema_text() -> None:
schema_text = "TestModel:\n name: str [required]\n count: int = 0"
data: dict[str, Any] = {"type_name": "test", "class_name": "TestModel", "schema_text": schema_text}
result = format_schema_text(data)
assert result == schema_text
# --- format_builder_text ---
def test_format_builder_text_renders_methods() -> None:
data: dict[str, Any] = {
"class_name": "MyBuilder",
"import_path": "data_designer.config.MyBuilder",
"methods": [
{"name": "add_column", "signature": "add_column(column: ColumnConfig)", "summary": "Add a column."},
{"name": "build", "signature": "build()", "summary": "Build the config."},
],
}
result = format_builder_text(data)
assert "MyBuilder:" in result
assert "usage: dd.MyBuilder" in result
assert "add_column(column: ColumnConfig)" in result
assert "Add a column." in result
def test_format_builder_text_handles_method_without_summary() -> None:
data: dict[str, Any] = {
"class_name": "Builder",
"import_path": "mod.Builder",
"methods": [{"name": "reset", "signature": "reset()", "summary": None}],
}
result = format_builder_text(data)
assert "reset()" in result
# --- format_model_aliases_text ---
def test_format_model_aliases_text_with_items() -> None:
state: dict[str, Any] = {
"default_provider": "nvidia",
"items": [
{
"model_alias": "test",
"model": "meta/llama-3",
"generation_type": "chat",
"effective_provider": "nvidia",
"usable": True,
"reason": None,
},
],
}
result = format_model_aliases_text(state)
assert "default_provider: nvidia" in result
assert "test" in result
assert "meta/llama-3" in result
def test_format_model_aliases_text_empty() -> None:
state: dict[str, Any] = {"default_provider": None, "items": []}
result = format_model_aliases_text(state)
assert "default_provider: (none)" in result
assert "(no items)" in result
# --- format_persona_datasets_text ---
def test_format_persona_datasets_text() -> None:
state: dict[str, Any] = {
"items": [{"locale": "en_US", "size": "10MB", "installed": True}],
}
result = format_persona_datasets_text(state)
assert "# persona datasets" in result
assert "en_US" in result
assert "True" in result
# --- Real config models ---
@pytest.mark.parametrize(
"family,type_name",
[
("columns", "llm-text"),
("columns", "sampler"),
("samplers", "category"),
("validators", "code"),
("constraints", "scalar_inequality"),
],
ids=["columns-llm-text", "columns-sampler", "samplers-category", "validators-code", "constraints-scalar"],
)
def test_format_schema_text_on_real_config_models(family: str, type_name: str) -> None:
schema_data = get_family_schema(family, type_name)
result = format_schema_text(schema_data)
assert schema_data["class_name"] in result
assert result == schema_data["schema_text"]