mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
feat: agent CLI introspection (simplified) (#415)
* feat: add agent introspection cli * refactor: remove agent cli schema version * refactor: omit missing builder docstrings from context * refactor: tighten agent cli contract * feat: add schema_text() to ConfigBase for human-readable field summaries ConfigBase.schema_text() returns a concise text representation including the class docstring summary, field names, types, defaults, and descriptions. Field descriptions added to column config types to surface through this method. * refactor: flatten agent CLI into plain functions with text output mode Delete AgentController class and agent_command_defs module. Move all logic into agent_introspection (data) and agent_text_formatter (display) as plain functions. Add --json flag so commands default to human-readable text using schema_text(), with JSON as opt-in. Unify _emit helper, remove include_docstrings parameter, deduplicate catalog calls, and fix N+1 discover_family_types in get_family_schemas. * fix: port stale controller tests and consolidate command descriptions Port test_agent_controller.py to use plain functions instead of deleted AgentController. Extract AGENT_COMMANDS constant as single source for operation descriptions, syncing with main.py help strings. * style: fix ruff formatting in agent_introspection * refactor: centralize agent command definitions Extract AGENT_COMMANDS into agent_command_defs.py so main.py and agent_introspection.py share a single source for command names, help text, and metadata. The new module has no heavy dependencies, keeping --help latency unaffected. * fix: handle default_factory and empty providers in schema_text and introspection - schema_text() now detects default_factory fields and renders e.g. 
"list()" instead of leaking PydanticUndefined - Guard against IndexError when provider registry has an empty providers list - Add 15 edge-case tests for schema_text covering default_factory, enum defaults, None defaults, scalar defaults, descriptions, and docstrings * refactor: remove JSON output mode from agent CLI commands Text-only output simplifies the interface. Structured output can be added back trivially since the functions already return dicts. * docs: update schema_text docstring to reflect agent focus * fix: include builder section and import_path in agent text output - format_context_text now renders a ## Builder section - format_types_text now includes import_path column in tables * refactor: drop import_path from types tables All config objects are imported via dd.<ClassName>, so the full import path is redundant noise in agent output. * docs: add family definition and import hint to context output * refactor: rename Types section to Families, drop redundant "types" from sub-headers * fix: coerce None to empty string in table cells row.get(col, '') returns None when the key exists with value None, causing str(None) to render "None" in the output. Use `or ''` instead. * refactor: move agent controller tests to utils as introspection integration tests There is no controller layer — these tests exercise functions in agent_introspection.py, so they belong in tests/cli/utils/. * fix: only coerce None to empty string in table cells, not False The previous `or ''` pattern treated all falsy values (including False) as empty. Use an explicit None check so booleans render correctly. 
* style: address review nits from nabin - Add explicit parentheses to and/or precedence in _build_agent_lazy_group - Rename loop variable l to line in test_schema_text - Move get_family_schema import to module level in test_agent_text_formatter * fix: improve schema_text Literal display, builder signature quotes, and docstring parsing - _format_annotation now renders Literal['value'] instead of bare Literal - _format_signature strips quotes from stringified annotations caused by `from __future__ import annotations` - _get_docstring_summary stops at any Google-style section header, not just Attributes:
This commit is contained in:
parent
02744d152d
commit
4c19dba74b
12 changed files with 1506 additions and 24 deletions
|
|
@ -6,7 +6,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -20,6 +23,31 @@ class ConfigBase(BaseModel):
|
|||
json_schema_mode_override="validation",
|
||||
)
|
||||
|
||||
@classmethod
def schema_text(cls) -> str:
    """Render the model as a short, agent-readable outline.

    The first line is the class name followed by a colon. When the class
    docstring yields a summary, that summary and a blank line come next.
    Every field then contributes one line holding its name and formatted
    annotation, followed by either ``[required]`` or its default: a
    ``default_factory`` is shown as ``<factory name>()`` and an ``Enum``
    default as its ``.value``. A field description, when set, is emitted
    on the line right after the field.

    Returns:
        The outline lines joined with newlines.
    """
    out: list[str] = [f"{cls.__name__}:"]

    summary = _get_docstring_summary(cls.__doc__)
    if summary:
        out.append(f" {summary}")
        out.append("")

    for field_name, info in cls.model_fields.items():
        type_text = _format_annotation(info.annotation)
        if info.is_required():
            out.append(f" {field_name}: {type_text} [required]")
        elif info.default_factory is not None:
            factory = info.default_factory
            # Callables without a __name__ fall back to their repr.
            factory_label = getattr(factory, "__name__", repr(factory))
            out.append(f" {field_name}: {type_text} = {factory_label}()")
        else:
            shown = info.default
            if isinstance(shown, Enum):
                # Show the underlying value rather than the enum member.
                shown = shown.value
            out.append(f" {field_name}: {type_text} = {shown!r}")
        if info.description:
            out.append(f" {info.description}")

    return "\n".join(out)
|
||||
|
||||
|
||||
class SingleColumnConfig(ConfigBase, ABC):
|
||||
"""Abstract base class for all single-column configuration types.
|
||||
|
|
@ -83,3 +111,52 @@ class ProcessorConfig(ConfigBase, ABC):
|
|||
description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
|
||||
)
|
||||
processor_type: str
|
||||
|
||||
|
||||
def _format_annotation(annotation: Any) -> str:
|
||||
"""Convert a type annotation to a readable string, stripping module paths."""
|
||||
if get_origin(annotation) is Literal:
|
||||
args = get_args(annotation)
|
||||
if args:
|
||||
values = ", ".join(repr(a.value) if isinstance(a, Enum) else repr(a) for a in args)
|
||||
return f"Literal[{values}]"
|
||||
raw = str(annotation) if not hasattr(annotation, "__name__") else annotation.__name__
|
||||
return re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], raw)
|
||||
|
||||
|
||||
_GOOGLE_SECTION_HEADERS = frozenset(
|
||||
{
|
||||
"args:",
|
||||
"arguments:",
|
||||
"attributes:",
|
||||
"example:",
|
||||
"examples:",
|
||||
"keyword args:",
|
||||
"keyword arguments:",
|
||||
"note:",
|
||||
"notes:",
|
||||
"raises:",
|
||||
"references:",
|
||||
"returns:",
|
||||
"see also:",
|
||||
"todo:",
|
||||
"warns:",
|
||||
"yields:",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_docstring_summary(docstring: str | None) -> str | None:
|
||||
"""Extract the first paragraph of a docstring, before any Google-style section header."""
|
||||
if not docstring:
|
||||
return None
|
||||
lines: list[str] = []
|
||||
for line in docstring.strip().splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.lower() in _GOOGLE_SECTION_HEADERS:
|
||||
break
|
||||
if not stripped and lines:
|
||||
break
|
||||
if stripped:
|
||||
lines.append(stripped)
|
||||
return " ".join(lines) if lines else None
|
||||
|
|
|
|||
|
|
@ -56,10 +56,19 @@ class SamplerColumnConfig(SingleColumnConfig):
|
|||
```
|
||||
"""
|
||||
|
||||
sampler_type: SamplerType
|
||||
params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
|
||||
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
|
||||
convert_to: str | None = None
|
||||
sampler_type: SamplerType = Field(
|
||||
description="Type of sampler to use (e.g., uuid, category, uniform, gaussian, person, datetime)"
|
||||
)
|
||||
params: Annotated[SamplerParamsT, Discriminator("sampler_type")] = Field(
|
||||
description="Parameters specific to the chosen sampler type"
|
||||
)
|
||||
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = Field(
|
||||
default_factory=dict,
|
||||
description="Optional dictionary for conditional parameters; keys are conditions, values are params to use when met",
|
||||
)
|
||||
convert_to: str | None = Field(
|
||||
default=None, description="Optional type conversion after sampling: 'float', 'int', or 'str'"
|
||||
)
|
||||
column_type: Literal["sampler"] = "sampler"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -136,13 +145,25 @@ class LLMTextColumnConfig(SingleColumnConfig):
|
|||
column_type: Discriminator field, always "llm-text" for this configuration type.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
model_alias: str
|
||||
system_prompt: str | None = None
|
||||
multi_modal_context: list[ImageContext] | None = None
|
||||
tool_alias: str | None = None
|
||||
with_trace: TraceType = TraceType.NONE
|
||||
extract_reasoning_content: bool = False
|
||||
prompt: str = Field(
|
||||
description="Jinja2 template for the LLM prompt; can reference other columns via {{ column_name }}"
|
||||
)
|
||||
model_alias: str = Field(description="Alias of the model configuration to use for generation")
|
||||
system_prompt: str | None = Field(
|
||||
default=None, description="Optional system prompt to set model behavior and constraints"
|
||||
)
|
||||
multi_modal_context: list[ImageContext] | None = Field(
|
||||
default=None, description="Optional list of ImageContext for vision model inputs"
|
||||
)
|
||||
tool_alias: str | None = Field(
|
||||
default=None, description="Optional alias of the tool configuration to use for MCP tool calls"
|
||||
)
|
||||
with_trace: TraceType = Field(
|
||||
default=TraceType.NONE, description="Trace capture mode: NONE, LAST_MESSAGE, or ALL_MESSAGES"
|
||||
)
|
||||
extract_reasoning_content: bool = Field(
|
||||
default=False, description="If True, capture chain-of-thought in {name}__reasoning_content column"
|
||||
)
|
||||
column_type: Literal["llm-text"] = "llm-text"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -219,7 +240,9 @@ class LLMCodeColumnConfig(LLMTextColumnConfig):
|
|||
column containing the reasoning content from the final assistant response.
|
||||
"""
|
||||
|
||||
code_lang: CodeLang
|
||||
code_lang: CodeLang = Field(
|
||||
description="Target programming language or SQL dialect for code extraction from LLM response"
|
||||
)
|
||||
column_type: Literal["llm-code"] = "llm-code"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -252,7 +275,9 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
|
|||
column containing the reasoning content from the final assistant response.
|
||||
"""
|
||||
|
||||
output_format: dict | type[BaseModel]
|
||||
output_format: dict | type[BaseModel] = Field(
|
||||
description="Pydantic model or JSON schema dict defining the expected structured output shape"
|
||||
)
|
||||
column_type: Literal["llm-structured"] = "llm-structured"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -317,7 +342,9 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig):
|
|||
column containing the reasoning content from the final assistant response.
|
||||
"""
|
||||
|
||||
scores: list[Score] = Field(..., min_length=1)
|
||||
scores: list[Score] = Field(
|
||||
..., min_length=1, description="List of Score objects defining rubric criteria for LLM judge evaluation"
|
||||
)
|
||||
column_type: Literal["llm-judge"] = "llm-judge"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -342,8 +369,10 @@ class ExpressionColumnConfig(SingleColumnConfig):
|
|||
"""
|
||||
|
||||
name: str
|
||||
expr: str
|
||||
dtype: Literal["int", "float", "str", "bool"] = "str"
|
||||
expr: str = Field(description="Jinja2 expression to compute the column value from other columns")
|
||||
dtype: Literal["int", "float", "str", "bool"] = Field(
|
||||
default="str", description="Data type for expression result: 'int', 'float', 'str', or 'bool'"
|
||||
)
|
||||
column_type: Literal["expression"] = "expression"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -410,9 +439,11 @@ class ValidationColumnConfig(SingleColumnConfig):
|
|||
column_type: Discriminator field, always "validation" for this configuration type.
|
||||
"""
|
||||
|
||||
target_columns: list[str]
|
||||
validator_type: ValidatorType
|
||||
validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")]
|
||||
target_columns: list[str] = Field(description="List of column names to validate")
|
||||
validator_type: ValidatorType = Field(description="Validation method: 'code', 'local_callable', or 'remote'")
|
||||
validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")] = Field(
|
||||
description="Validator-specific parameters (e.g., CodeValidatorParams)"
|
||||
)
|
||||
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
|
||||
column_type: Literal["validation"] = "validation"
|
||||
|
||||
|
|
@ -479,8 +510,8 @@ class EmbeddingColumnConfig(SingleColumnConfig):
|
|||
column_type: Discriminator field, always "embedding" for this configuration type.
|
||||
"""
|
||||
|
||||
target_column: str
|
||||
model_alias: str
|
||||
target_column: str = Field(description="Name of the text column to generate embeddings for")
|
||||
model_alias: str = Field(description="Alias of the model to use for embedding generation")
|
||||
column_type: Literal["embedding"] = "embedding"
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -513,9 +544,13 @@ class ImageColumnConfig(SingleColumnConfig):
|
|||
column_type: Discriminator field, always "image" for this configuration type.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
model_alias: str
|
||||
multi_modal_context: list[ImageContext] | None = None
|
||||
prompt: str = Field(
|
||||
description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}"
|
||||
)
|
||||
model_alias: str = Field(description="Alias of the model to use for image generation")
|
||||
multi_modal_context: list[ImageContext] | None = Field(
|
||||
default=None, description="Optional list of ImageContext for multi-modal image-to-image generation"
|
||||
)
|
||||
column_type: Literal["image"] = "image"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
204
packages/data-designer-config/tests/config/test_schema_text.py
Normal file
204
packages/data-designer-config/tests/config/test_schema_text.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from data_designer.config.base import ConfigBase
|
||||
|
||||
|
||||
# Two-member enum used as a field type and default by EnumDefaultModel and MixedModel.
class Color(Enum):
    RED = "red"
    GREEN = "green"
|
||||
|
||||
|
||||
# Fixture: every field is required (no defaults). Deliberately has no class
# docstring, since a docstring would appear in schema_text() output.
class RequiredOnlyModel(ConfigBase):
    name: str
    count: int
|
||||
|
||||
|
||||
# Fixture: both fields get their defaults from a default_factory (list / dict).
class DefaultFactoryModel(ConfigBase):
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Fixture: a default_factory field that also carries a description.
class DefaultFactoryWithDescriptionModel(ConfigBase):
    items: list[str] = Field(default_factory=list, description="Collection of items")
|
||||
|
||||
|
||||
# Fixture: a field whose default is an Enum member.
class EnumDefaultModel(ConfigBase):
    color: Color = Color.RED
|
||||
|
||||
|
||||
# Fixture: an optional field that defaults to None.
class NoneDefaultModel(ConfigBase):
    label: str | None = None
|
||||
|
||||
|
||||
# Fixture: plain scalar defaults (float, bool, str).
class ScalarDefaultModel(ConfigBase):
    threshold: float = 0.5
    enabled: bool = True
    tag: str = "default"
|
||||
|
||||
|
||||
# Fixture: one required and one defaulted field, both with descriptions.
class DescribedFieldsModel(ConfigBase):
    name: str = Field(description="The name of the thing")
    count: int = Field(default=0, description="How many things")
|
||||
|
||||
|
||||
# Fixture: the class docstring below is asserted verbatim by
# test_docstring_included, so its text must not change.
class DocstringModel(ConfigBase):
    """A model with a docstring summary."""

    value: int
|
||||
|
||||
|
||||
# Fixture: intentionally has no class docstring (do not add one).
class NoDocstringModel(ConfigBase):
    value: int
|
||||
|
||||
|
||||
# Fixture: a single-value Literal field with a matching default, followed by
# a required field.
class LiteralModel(ConfigBase):
    tag: Literal["fixed"] = "fixed"
    name: str
|
||||
|
||||
|
||||
# Fixture combining the field variants covered individually above; the
# docstring text is asserted verbatim by test_mixed_model_all_variants.
class MixedModel(ConfigBase):
    """Model exercising every field variant."""

    required_field: str  # no default
    optional_none: int | None = None  # None default
    with_default: float = 3.14  # scalar default
    enum_field: Color = Color.GREEN  # Enum member default
    factory_list: list[str] = Field(default_factory=list)  # factory, no description
    factory_dict: dict[str, int] = Field(default_factory=dict, description="A mapping")  # factory + description
|
||||
|
||||
|
||||
# --- Required fields ---
|
||||
|
||||
|
||||
def test_required_fields_marked() -> None:
    """Fields without a default are rendered with the ``[required]`` marker."""
    rendered = RequiredOnlyModel.schema_text()
    for expected_line in ("name: str [required]", "count: int [required]"):
        assert expected_line in rendered
|
||||
|
||||
|
||||
# --- default_factory fields ---
|
||||
|
||||
|
||||
def test_default_factory_list_shows_factory_call() -> None:
    """A ``default_factory=list`` field is rendered with ``= list()``."""
    rendered = DefaultFactoryModel.schema_text()
    for fragment in ("tags:", "= list()"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_default_factory_dict_shows_factory_call() -> None:
    """A ``default_factory=dict`` field is rendered with ``= dict()``."""
    rendered = DefaultFactoryModel.schema_text()
    for fragment in ("metadata:", "= dict()"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_default_factory_does_not_show_pydantic_undefined() -> None:
    """The pydantic "undefined" sentinel must not appear in the output."""
    rendered = DefaultFactoryModel.schema_text()
    leaked = "PydanticUndefined" in rendered
    assert not leaked
|
||||
|
||||
|
||||
def test_default_factory_with_description() -> None:
    """A factory-default field shows both the factory call and its description."""
    rendered = DefaultFactoryWithDescriptionModel.schema_text()
    for fragment in ("items:", "= list()", "Collection of items"):
        assert fragment in rendered
|
||||
|
||||
|
||||
# --- Enum defaults ---
|
||||
|
||||
|
||||
def test_enum_default_shows_value() -> None:
    """An Enum default is shown by its value, not as ``Color.RED``."""
    rendered = EnumDefaultModel.schema_text()
    assert "color: Color = 'red'" in rendered
    assert "Color.RED" not in rendered
|
||||
|
||||
|
||||
# --- None defaults ---
|
||||
|
||||
|
||||
def test_none_default() -> None:
    """A ``None`` default is rendered literally after the union annotation."""
    expected_line = "label: str | None = None"
    assert expected_line in NoneDefaultModel.schema_text()
|
||||
|
||||
|
||||
# --- Scalar defaults ---
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("field_name", "expected"),
    [
        pytest.param("threshold", "threshold: float = 0.5", id="float"),
        pytest.param("enabled", "enabled: bool = True", id="bool"),
        pytest.param("tag", "tag: str = 'default'", id="str"),
    ],
)
def test_scalar_defaults(field_name: str, expected: str) -> None:
    """Scalar defaults are rendered with their ``repr``."""
    # ``field_name`` only labels the case; the assertion relies on ``expected``.
    rendered = ScalarDefaultModel.schema_text()
    assert expected in rendered
|
||||
|
||||
|
||||
# --- Field descriptions ---
|
||||
|
||||
|
||||
def test_description_appears_below_field() -> None:
    """The description is emitted on the line right after its field."""
    rendered_lines = DescribedFieldsModel.schema_text().splitlines()
    field_row = next(idx for idx, row in enumerate(rendered_lines) if "name: str" in row)
    assert "The name of the thing" in rendered_lines[field_row + 1]
|
||||
|
||||
|
||||
# --- Docstrings ---
|
||||
|
||||
|
||||
def test_docstring_included() -> None:
    """The class docstring summary is part of the output."""
    rendered = DocstringModel.schema_text()
    assert "A model with a docstring summary." in rendered
|
||||
|
||||
|
||||
def test_no_docstring_still_works() -> None:
    """A model without a docstring still yields a header and its fields."""
    rendered = NoDocstringModel.schema_text()
    assert rendered.startswith("NoDocstringModel:")
    assert "value: int [required]" in rendered
|
||||
|
||||
|
||||
# --- Header format ---
|
||||
|
||||
|
||||
def test_header_is_class_name() -> None:
    """The output begins with the class name followed by a colon."""
    header = "RequiredOnlyModel:"
    assert RequiredOnlyModel.schema_text().startswith(header)
|
||||
|
||||
|
||||
# --- Single-value Literal fields ---
|
||||
|
||||
|
||||
def test_literal_field_includes_value_and_default() -> None:
    """A single-value Literal shows both the allowed value and the default."""
    rendered = LiteralModel.schema_text()
    assert "tag: Literal['fixed'] = 'fixed'" in rendered
|
||||
|
||||
|
||||
# --- Mixed model exercises all variants together ---
|
||||
|
||||
|
||||
def test_mixed_model_all_variants() -> None:
    """All field variants render correctly when combined in one model."""
    rendered = MixedModel.schema_text()
    expected_fragments = (
        "Model exercising every field variant.",
        "required_field: str [required]",
        "optional_none: int | None = None",
        "with_default: float = 3.14",
        "enum_field: Color = 'green'",
        "factory_list:",
        "= list()",
        "factory_dict:",
        "= dict()",
        "A mapping",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
    assert "PydanticUndefined" not in rendered
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AgentCommandDef:
|
||||
name: str
|
||||
attr: str
|
||||
help: str
|
||||
command_pattern: str
|
||||
returns: str
|
||||
|
||||
|
||||
# Definitions of every agent command, in declaration order.
AGENT_COMMANDS: tuple[AgentCommandDef, ...] = (
    # --- top-level commands ---
    AgentCommandDef(
        name="context",
        attr="context_command",
        help="Bootstrap payload with types, state, and builder.",
        command_pattern="data-designer agent context",
        returns="agent_context",
    ),
    AgentCommandDef(
        name="types",
        attr="types_command",
        help="Type names and import paths for one or all families.",
        command_pattern="data-designer agent types [family]",
        returns="agent_types",
    ),
    AgentCommandDef(
        name="schema",
        attr="schema_command",
        help="Schema for a type or entire family.",
        command_pattern="data-designer agent schema <family> <type> | --all",
        returns="agent_schema",
    ),
    AgentCommandDef(
        name="builder",
        attr="builder_command",
        help="ConfigBuilder method surface with signatures.",
        command_pattern="data-designer agent builder",
        returns="agent_builder",
    ),
    # --- "state" sub-commands (dotted names) ---
    AgentCommandDef(
        name="state.model-aliases",
        attr="state_model_aliases_command",
        help="Model aliases and usability status.",
        command_pattern="data-designer agent state model-aliases",
        returns="agent_state_model_aliases",
    ),
    AgentCommandDef(
        name="state.persona-datasets",
        attr="state_persona_datasets_command",
        help="Persona locales and install status.",
        command_pattern="data-designer agent state persona-datasets",
        returns="agent_state_persona_datasets",
    ),
)
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import typer
|
||||
|
||||
from data_designer.cli.utils.agent_introspection import (
|
||||
AgentIntrospectionError,
|
||||
get_builder_api,
|
||||
get_context,
|
||||
get_model_aliases_state,
|
||||
get_persona_datasets_state,
|
||||
get_schema,
|
||||
get_types,
|
||||
)
|
||||
from data_designer.cli.utils.agent_text_formatter import (
|
||||
format_builder_text,
|
||||
format_context_text,
|
||||
format_model_aliases_text,
|
||||
format_persona_datasets_text,
|
||||
format_schema_text,
|
||||
format_types_text,
|
||||
)
|
||||
from data_designer.config.utils.constants import DATA_DESIGNER_HOME
|
||||
|
||||
|
||||
def context_command() -> None:
    """Return a bootstrap payload with types, local state, and builder summary."""

    def load() -> Any:
        # Deferred so that any failure is reported through _run's error handling.
        return get_context(DATA_DESIGNER_HOME)

    _run(get_data=load, format_text=format_context_text)
|
||||
|
||||
|
||||
def types_command(
    family: str | None = typer.Argument(None, help="Optional schema family name."),
) -> None:
    """Return available type names and import paths for one family or all families."""

    def load() -> Any:
        # Deferred so that any failure is reported through _run's error handling.
        return get_types(family)

    _run(get_data=load, format_text=format_types_text)
|
||||
|
||||
|
||||
def schema_command(
    family: str = typer.Argument(..., help="Schema family name."),
    type_name: str | None = typer.Argument(None, help="Type name within the selected family."),
    all_types: bool = typer.Option(False, "--all", help="Return every schema in the selected family."),
) -> None:
    """Return schema for a specific type or every type in a family."""

    def load() -> Any:
        # Deferred so that any failure is reported through _run's error handling.
        return get_schema(family, type_name, all_types=all_types)

    _run(get_data=load, format_text=format_schema_text)
|
||||
|
||||
|
||||
def builder_command() -> None:
    """Return the DataDesignerConfigBuilder method surface with signatures and docstrings."""
    # get_builder_api takes no arguments, so it is passed as the loader directly.
    _run(get_data=get_builder_api, format_text=format_builder_text)
|
||||
|
||||
|
||||
def state_model_aliases_command() -> None:
    """Return configured model aliases and whether each one is currently usable."""

    def load() -> Any:
        # Deferred so that any failure is reported through _run's error handling.
        return get_model_aliases_state(DATA_DESIGNER_HOME)

    _run(get_data=load, format_text=format_model_aliases_text)
|
||||
|
||||
|
||||
def state_persona_datasets_command() -> None:
    """Return built-in persona locales and whether each dataset is installed locally."""

    def load() -> Any:
        # Deferred so that any failure is reported through _run's error handling.
        return get_persona_datasets_state(DATA_DESIGNER_HOME)

    _run(get_data=load, format_text=format_persona_datasets_text)
|
||||
|
||||
|
||||
def _run(get_data: Callable[[], Any], format_text: Callable[[Any], str]) -> None:
    """Load data, print its text rendering, and turn failures into exit code 1.

    Args:
        get_data: Zero-argument callable producing the payload.
        format_text: Callable turning that payload into the text to print.

    Raises:
        typer.Exit: With code 1 after writing an ``Error [...]`` line to
            stderr, when loading, formatting, or printing raises.
    """
    try:
        payload = get_data()
        rendered = format_text(payload)
        typer.echo(rendered)
    except AgentIntrospectionError as known:
        # Known failures carry their own error code.
        typer.echo(f"Error [{known.code}]: {known.message}", err=True)
        raise typer.Exit(code=1)
    except Exception as unexpected:
        # Anything else is reported under a generic code instead of a traceback.
        typer.echo(f"Error [internal_error]: {unexpected}", err=True)
        raise typer.Exit(code=1)
|
||||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
|
||||
import typer
|
||||
|
||||
from data_designer.cli.agent_command_defs import AGENT_COMMANDS
|
||||
from data_designer.cli.lazy_group import create_lazy_typer_group
|
||||
from data_designer.cli.runtime import ensure_cli_default_model_settings
|
||||
|
||||
|
|
@ -99,9 +100,37 @@ download_app = typer.Typer(
|
|||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
_AGENT_CMD = f"{_CMD}.agent"
|
||||
|
||||
|
||||
def _build_agent_lazy_group(prefix: str) -> dict[str, dict[str, str]]:
    """Build the lazy-group command table for one level of the agent CLI.

    A command is selected when its name starts with ``f"{prefix}."``; with an
    empty prefix, commands whose name contains no dot are selected as well.

    Args:
        prefix: Dotted-name prefix of the group ("" for the top level).

    Returns:
        Mapping from the command name with the ``f"{prefix}."`` prefix removed
        to a dict with ``module``, ``attr`` and ``help`` entries.
    """
    dotted_prefix = f"{prefix}."
    table: dict[str, dict[str, str]] = {}
    for cmd in AGENT_COMMANDS:
        is_top_level = prefix == "" and "." not in cmd.name
        if not (is_top_level or cmd.name.startswith(dotted_prefix)):
            continue
        table[cmd.name.removeprefix(dotted_prefix)] = {
            "module": _AGENT_CMD,
            "attr": cmd.attr,
            "help": cmd.help,
        }
    return table
|
||||
|
||||
|
||||
agent_app = typer.Typer(
|
||||
name="agent",
|
||||
help="Agent-only interface for dynamic Data Designer introspection",
|
||||
cls=create_lazy_typer_group(_build_agent_lazy_group("")),
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
agent_state_app = typer.Typer(
|
||||
name="state",
|
||||
help="Return current local state relevant to agents",
|
||||
cls=create_lazy_typer_group(_build_agent_lazy_group("state")),
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
agent_app.add_typer(agent_state_app, name="state")
|
||||
|
||||
# Add setup command groups
|
||||
app.add_typer(config_app, name="config", rich_help_panel="Setup")
|
||||
app.add_typer(download_app, name="download", rich_help_panel="Setup")
|
||||
app.add_typer(agent_app, name="agent", rich_help_panel="Agent")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,331 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args, get_origin
|
||||
|
||||
import data_designer.config as dd
|
||||
from data_designer.cli.agent_command_defs import AGENT_COMMANDS
|
||||
from data_designer.cli.repositories.model_repository import ModelRepository
|
||||
from data_designer.cli.repositories.persona_repository import PersonaRepository
|
||||
from data_designer.cli.repositories.provider_repository import ProviderRepository
|
||||
from data_designer.cli.services.download_service import DownloadService
|
||||
from data_designer.config.column_types import ColumnConfigT
|
||||
from data_designer.config.config_builder import DataDesignerConfigBuilder
|
||||
from data_designer.config.default_model_settings import get_providers_with_missing_api_keys
|
||||
from data_designer.config.processor_types import ProcessorConfigT
|
||||
from data_designer.config.sampler_constraints import ColumnConstraintT
|
||||
from data_designer.config.sampler_params import SamplerParamsT
|
||||
from data_designer.config.validator_params import ValidatorParamsT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FamilySpec:
|
||||
name: str
|
||||
type_union: Any
|
||||
discriminator_field: str
|
||||
|
||||
|
||||
class AgentIntrospectionError(Exception):
|
||||
def __init__(self, code: str, message: str, details: dict[str, Any] | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
# Registry of introspectable families, keyed by each spec's own name
# (insertion order follows the tuple below).
_FAMILY_SPECS: dict[str, FamilySpec] = {
    spec.name: spec
    for spec in (
        FamilySpec(name="columns", type_union=ColumnConfigT, discriminator_field="column_type"),
        FamilySpec(name="samplers", type_union=SamplerParamsT, discriminator_field="sampler_type"),
        FamilySpec(name="validators", type_union=ValidatorParamsT, discriminator_field="validator_type"),
        FamilySpec(name="processors", type_union=ProcessorConfigT, discriminator_field="processor_type"),
        FamilySpec(name="constraints", type_union=ColumnConstraintT, discriminator_field="constraint_type"),
    )
}
|
||||
|
||||
|
||||
def get_family_names() -> list[str]:
    """Return the registered family names in sorted order."""
    names = list(_FAMILY_SPECS)
    names.sort()
    return names
|
||||
|
||||
|
||||
def get_library_version() -> str:
    """Return the installed ``data-designer`` version, or ``"unknown"``.

    ``"unknown"`` is returned when no distribution named ``data-designer``
    is found.
    """
    try:
        found = importlib.metadata.version("data-designer")
    except importlib.metadata.PackageNotFoundError:
        found = "unknown"
    return found
|
||||
|
||||
|
||||
def get_family_spec(family: str) -> FamilySpec:
    """Look up the spec for a family name.

    Args:
        family: Family name; it is passed through ``_normalize_family_name``
            before the lookup.

    Returns:
        The matching ``FamilySpec``.

    Raises:
        AgentIntrospectionError: With code ``unknown_family`` (and the list of
            available families in ``details``) when no spec matches.
    """
    key = _normalize_family_name(family)
    if key not in _FAMILY_SPECS:
        raise AgentIntrospectionError(
            code="unknown_family",
            message=f"Unknown family {family!r}.",
            details={"available_families": get_family_names()},
        )
    return _FAMILY_SPECS[key]
|
||||
|
||||
|
||||
def discover_family_types(family: str) -> dict[str, type]:
    """Map each type name in a family to its model class.

    The type name of a model is extracted from the annotation of its
    discriminator field.

    Args:
        family: Family name, resolved through ``get_family_spec``.

    Returns:
        Mapping of type name to model class, ordered by type name.

    Raises:
        AgentIntrospectionError: With code ``duplicate_discriminator_value``
            when two different models yield the same type name; also
            whatever ``get_family_spec`` raises for an unknown family.
    """
    spec = get_family_spec(family)
    by_name: dict[str, type] = {}
    for model in get_args(spec.type_union):
        annotation = model.model_fields[spec.discriminator_field].annotation
        type_name = _extract_literal_value(annotation)
        previous = by_name.get(type_name)
        # The same class listed twice is tolerated; two different classes are not.
        if previous is not None and previous is not model:
            raise AgentIntrospectionError(
                code="duplicate_discriminator_value",
                message=f"Duplicate discriminator {type_name!r} in family {family!r}.",
                details={"family": family, "type_name": type_name},
            )
        by_name[type_name] = model
    return {type_name: by_name[type_name] for type_name in sorted(by_name)}
|
||||
|
||||
|
||||
def get_import_path(cls: type) -> str:
    """Return the dotted import path to show for a class.

    When ``data_designer.config`` exposes this exact class under its own
    name, the short ``data_designer.config.<Name>`` path is returned;
    otherwise the defining module and qualified name are used.
    """
    is_reexported = getattr(dd, cls.__name__, None) is cls
    return f"data_designer.config.{cls.__name__}" if is_reexported else f"{cls.__module__}.{cls.__qualname__}"
|
||||
|
||||
|
||||
def get_family_catalog(family: str) -> list[dict[str, str]]:
    """List the types of a family with their class name and import path.

    Returns:
        One dict per type, in the order given by ``discover_family_types``,
        with the keys ``type_name``, ``class_name`` and ``import_path``.
    """
    catalog: list[dict[str, str]] = []
    for type_name, model in discover_family_types(family).items():
        entry = {
            "type_name": type_name,
            "class_name": model.__name__,
            "import_path": get_import_path(model),
        }
        catalog.append(entry)
    return catalog
|
||||
|
||||
|
||||
def get_family_schema(family: str, type_name: str) -> dict[str, Any]:
    """Build the schema payload for one type of a family.

    Raises:
        AgentIntrospectionError: With code ``unknown_type`` when *type_name*
            is not among the family's discovered types; ``details`` carries
            the family as given and the available type names.
    """
    available = discover_family_types(family)
    model_cls = available.get(type_name)
    if model_cls is not None:
        canonical_family = get_family_spec(family).name
        return _build_schema_dict(canonical_family, type_name, model_cls)
    raise AgentIntrospectionError(
        code="unknown_type",
        message=f"Unknown type {type_name!r} for family {family!r}.",
        details={"family": family, "available_types": list(available)},
    )
|
||||
|
||||
|
||||
def get_family_schemas(family: str) -> dict[str, Any]:
    """Build schema payloads for every type in *family*.

    Returns:
        A dict with the canonical family name under ``family`` and one schema
        payload per discovered type under ``items``.
    """
    family_spec = get_family_spec(family)
    schemas: list[dict[str, Any]] = []
    for discriminator, model_cls in discover_family_types(family).items():
        schemas.append(_build_schema_dict(family_spec.name, discriminator, model_cls))
    return {"family": family_spec.name, "items": schemas}
|
||||
|
||||
|
||||
def get_builder_api() -> dict[str, Any]:
    """Describe ``DataDesignerConfigBuilder``: class name, import path and method listing."""
    builder_cls = DataDesignerConfigBuilder
    api: dict[str, Any] = {"class_name": builder_cls.__name__}
    api["import_path"] = get_import_path(builder_cls)
    api["methods"] = _get_builder_methods()
    return api
|
||||
|
||||
|
||||
def get_operations() -> list[dict[str, str]]:
    """Return one record per entry in ``AGENT_COMMANDS``.

    Each record exposes the command's ``name``, ``command_pattern``, its help
    text (under ``description``) and ``returns``.
    """
    operations: list[dict[str, str]] = []
    for command in AGENT_COMMANDS:
        operations.append(
            {
                "name": command.name,
                "command_pattern": command.command_pattern,
                "description": command.help,
                "returns": command.returns,
            }
        )
    return operations
|
||||
|
||||
|
||||
def get_context(config_dir: Path) -> dict[str, Any]:
    """Assemble the full context payload.

    The payload combines the operation list, per-family counts and catalogs,
    the model-alias and persona-dataset state read from *config_dir*, and the
    builder API description.
    """
    catalogs: dict[str, list[dict[str, str]]] = {}
    for family_name in get_family_names():
        catalogs[family_name] = get_family_catalog(family_name)

    context: dict[str, Any] = {"operations": get_operations()}
    context["families"] = [{"family": name, "count": len(entries)} for name, entries in catalogs.items()]
    context["types"] = catalogs
    context["state"] = {
        "model_aliases": get_model_aliases_state(config_dir),
        "persona_datasets": get_persona_datasets_state(config_dir),
    }
    context["builder"] = get_builder_api()
    return context
|
||||
|
||||
|
||||
def get_types(family: str | None) -> dict[str, Any]:
    """Return the type catalog for one family, or for all families when *family* is ``None``.

    With a family name the result holds the canonical name under ``family``
    and a list of catalog records under ``items``. With ``None`` it holds
    per-family counts under ``families`` and a name-to-catalog mapping under
    ``items``.
    """
    if family is not None:
        return {"family": get_family_spec(family).name, "items": get_family_catalog(family)}

    catalogs: dict[str, list[dict[str, str]]] = {}
    for family_name in get_family_names():
        catalogs[family_name] = get_family_catalog(family_name)
    summaries = [{"family": name, "count": len(entries)} for name, entries in catalogs.items()]
    return {"families": summaries, "items": catalogs}
|
||||
|
||||
|
||||
def get_schema(family: str, type_name: str | None, *, all_types: bool) -> dict[str, Any]:
    """Return schema data for a single type or, with *all_types*, for the whole family.

    Raises:
        AgentIntrospectionError: ``invalid_schema_request`` when both a type
            name and *all_types* are supplied; ``missing_type_name`` when
            neither is supplied.
    """
    if all_types:
        if type_name is not None:
            raise AgentIntrospectionError(
                code="invalid_schema_request",
                message="Provide either a type name or --all, but not both.",
                details={"family": family, "type_name": type_name, "all": all_types},
            )
        return get_family_schemas(family)

    if type_name is not None:
        return get_family_schema(family, type_name)

    raise AgentIntrospectionError(
        code="missing_type_name",
        message="A type name is required unless --all is provided.",
        details={"family": family},
    )
|
||||
|
||||
|
||||
def _resolve_default_provider(provider_registry: Any) -> str | None:
    """Return the provider name applied to model configs that name no provider.

    Prefers the registry's explicit ``default`` and otherwise falls back to
    the first listed provider. Returns ``None`` when there is no registry or
    it lists no providers.
    """
    if provider_registry is None:
        return None
    providers = provider_registry.providers
    return provider_registry.default or (providers[0].name if providers else None)


def get_model_aliases_state(config_dir: Path) -> dict[str, Any]:
    """Summarize the model aliases configured under *config_dir* and whether each is usable.

    Args:
        config_dir: Directory handed to ``ModelRepository`` and
            ``ProviderRepository``.

    Returns:
        A dict with ``model_config_present``, ``provider_config_present``,
        ``default_provider`` and ``items``. ``items`` holds one record per
        model config, sorted by alias, giving its configured and effective
        provider, a ``usable`` flag and, when unusable, a ``reason``.
        ``default_provider`` is resolved the same way whether or not a model
        config exists.

    Raises:
        AgentIntrospectionError: Propagated from ``_load_registry``.
    """
    model_registry = _load_registry(ModelRepository(config_dir))
    provider_registry = _load_registry(ProviderRepository(config_dir))
    default_provider = _resolve_default_provider(provider_registry)

    items: list[dict[str, Any]] = []
    if model_registry is None:
        return {
            "model_config_present": False,
            "provider_config_present": provider_registry is not None,
            # Report the same effective default as the populated path below,
            # so the value does not change meaning when a model config appears.
            "default_provider": default_provider,
            "items": items,
        }

    providers_by_name: dict[str, Any] = {}
    missing_key_names: set[str] = set()
    if provider_registry is not None:
        providers_by_name = {p.name: p for p in provider_registry.providers}
        missing_key_names = {p.name for p in get_providers_with_missing_api_keys(provider_registry.providers)}

    for mc in sorted(model_registry.model_configs, key=lambda m: m.alias):
        effective = mc.provider or default_provider
        usable = True
        reason: str | None = None
        # Checks run from most to least fundamental; the first failure wins.
        if effective is None:
            usable, reason = False, "No model provider is configured."
        elif effective not in providers_by_name:
            usable, reason = False, f"Provider {effective!r} is not configured."
        elif effective in missing_key_names:
            usable, reason = False, f"Provider {effective!r} is missing an API key."

        items.append(
            {
                "model_alias": mc.alias,
                "model": mc.model,
                # Uses ``.value`` when generation_type has one, else its str().
                "generation_type": getattr(mc.generation_type, "value", str(mc.generation_type)),
                "configured_provider": mc.provider,
                "effective_provider": effective,
                "usable": usable,
                "reason": reason,
            }
        )

    return {
        "model_config_present": True,
        "provider_config_present": provider_registry is not None,
        "default_provider": default_provider,
        "items": items,
    }
|
||||
|
||||
|
||||
def get_persona_datasets_state(config_dir: Path) -> dict[str, Any]:
    """Report the managed-assets directory and the install status of every persona locale.

    Locales come from ``PersonaRepository.list_all()`` and are listed in
    ascending ``code`` order; ``installed`` is whatever
    ``DownloadService.is_locale_downloaded`` reports for that code.
    """
    repository = PersonaRepository()
    downloads = DownloadService(config_dir, repository)
    assets_directory = str(downloads.get_managed_assets_directory())

    locale_records: list[dict[str, Any]] = []
    for locale in sorted(repository.list_all(), key=lambda entry: entry.code):
        locale_records.append(
            {
                "locale": locale.code,
                "dataset_name": locale.dataset_name,
                "size": locale.size,
                "installed": downloads.is_locale_downloaded(locale.code),
            }
        )
    return {"managed_assets_directory": assets_directory, "items": locale_records}
|
||||
|
||||
|
||||
def _build_schema_dict(family_name: str, type_name: str, cls: type) -> dict[str, Any]:
    """Assemble the schema payload for one config class.

    The payload holds the identifying names, the import path, the class's
    ``model_json_schema()`` output and its ``schema_text()`` rendering.
    """
    payload: dict[str, Any] = {
        "family": family_name,
        "type_name": type_name,
        "class_name": cls.__name__,
    }
    payload["import_path"] = get_import_path(cls)
    payload["schema"] = cls.model_json_schema()
    payload["schema_text"] = cls.schema_text()
    return payload
|
||||
|
||||
|
||||
def _normalize_family_name(family: str) -> str:
    """Canonicalize a user-supplied family name.

    The input is stripped and lower-cased. If that form is not registered but
    the form with a trailing ``s`` is, the plural form is returned. Otherwise
    the cleaned input is returned as-is, even when it is not registered.
    """
    cleaned = family.strip().lower()
    for candidate in (cleaned, f"{cleaned}s"):
        if candidate in _FAMILY_SPECS:
            return candidate
    return cleaned
|
||||
|
||||
|
||||
def _extract_literal_value(annotation: Any) -> str:
|
||||
if get_origin(annotation) is not Literal or not get_args(annotation):
|
||||
raise AgentIntrospectionError(
|
||||
code="invalid_discriminator_annotation",
|
||||
message=f"Expected non-empty Literal annotation, got {annotation!r}.",
|
||||
)
|
||||
value = get_args(annotation)[0]
|
||||
return str(value.value) if isinstance(value, Enum) else str(value)
|
||||
|
||||
|
||||
def _get_builder_methods() -> list[dict[str, Any]]:
    """Describe the callable members of ``DataDesignerConfigBuilder``.

    Underscore-prefixed members are skipped, with ``__init__`` as the one
    exception. Members whose signature cannot be introspected are skipped as
    well. Each record carries the name, formatted signature, first docstring
    line and full docstring.
    """
    described: list[dict[str, Any]] = []
    for member_name, member in inspect.getmembers(DataDesignerConfigBuilder):
        is_exposed = member_name == "__init__" or not member_name.startswith("_")
        if not (is_exposed and callable(member)):
            continue
        try:
            signature = inspect.signature(member)
        except (TypeError, ValueError):
            continue

        doc = inspect.getdoc(member)
        described.append(
            {
                "name": member_name,
                "signature": _format_signature(member_name, signature),
                "summary": _get_first_line(doc),
                "docstring": doc,
            }
        )
    return described
|
||||
|
||||
|
||||
def _format_signature(method_name: str, sig: inspect.Signature) -> str:
|
||||
params = [p for p in sig.parameters.values() if p.name not in {"self", "cls"}]
|
||||
sig_str = str(sig.replace(parameters=params))
|
||||
sig_str = re.sub(r"\b[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)+", lambda m: m.group().rsplit(".", 1)[-1], sig_str)
|
||||
# Strip quotes left by `from __future__ import annotations` around type annotations.
|
||||
sig_str = re.sub(r"(?<=: )'([^']+)'", r"\1", sig_str)
|
||||
sig_str = re.sub(r"(?<=-> )'([^']+)'", r"\1", sig_str)
|
||||
return f"{method_name}{sig_str}"
|
||||
|
||||
|
||||
def _get_first_line(text: str | None) -> str | None:
|
||||
return next((line.strip() for line in text.strip().splitlines() if line.strip()), None) if text else None
|
||||
|
||||
|
||||
def _load_registry(repo: Any) -> Any:
|
||||
if not repo.exists():
|
||||
return None
|
||||
registry = repo.load()
|
||||
if registry is None:
|
||||
raise AgentIntrospectionError(
|
||||
code="invalid_registry",
|
||||
message=f"Failed to load registry from {str(repo.config_file)!r}.",
|
||||
details={"config_file": str(repo.config_file)},
|
||||
)
|
||||
return registry
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from data_designer.cli.utils.agent_introspection import get_library_version
|
||||
|
||||
|
||||
def format_context_text(data: dict[str, Any]) -> str:
    """Render the full context payload as sectioned plain text.

    The output opens with a version banner and import hint, followed by the
    Families, Model Aliases, Persona Datasets, Builder and Commands sections.
    """
    out: list[str] = [
        f"Data Designer v{get_library_version()}",
        "",
        "import data_designer.config as dd",
        "",
        'A "family" is a group of related config types that share a discriminator field.',
        "Use dd.<ClassName> to reference any type below.",
        "",
    ]

    families_text = format_types_text({"families": data["families"], "items": data["types"]})
    out += ["## Families", "", families_text, ""]

    state = data["state"]
    out += ["## Model Aliases", "", format_model_aliases_text(state["model_aliases"]), ""]
    out += ["## Persona Datasets", "", format_persona_datasets_text(state["persona_datasets"]), ""]
    out += ["## Builder", "", format_builder_text(data["builder"]), ""]
    out += ["## Commands", "", _format_table(data["operations"], ["command_pattern", "description"])]
    return "\n".join(out)
|
||||
|
||||
|
||||
def format_types_text(data: dict[str, Any]) -> str:
    """Render type listings for a single family or for all families.

    A payload containing ``families`` is treated as the all-families form:
    one count line per family, then one table per family. Any other payload
    is rendered as a single table titled with its ``family`` value.
    """
    table_columns = ["type_name", "class_name"]
    if "families" not in data:
        return _format_table(data["items"], table_columns, title=data.get("family"))

    out: list[str] = [f"{entry['family']}: {entry['count']} types" for entry in data["families"]]
    out.append("")
    for family_name, catalog in data["items"].items():
        out.extend((_format_table(catalog, table_columns, title=family_name), ""))
    return "\n".join(out).rstrip()
|
||||
|
||||
|
||||
def format_schema_text(data: dict[str, Any]) -> str:
    """Render schema data as text.

    A single-type payload yields its ``schema_text`` unchanged. A payload
    with ``items`` yields a header line followed by every item's
    ``schema_text``, separated by blank lines.
    """
    if "items" not in data:
        return data["schema_text"]

    entries = data["items"]
    heading = f"# {data['family']} schemas ({len(entries)} types)"
    body = "\n\n".join(entry["schema_text"] for entry in entries)
    return "\n\n".join((heading, body))
|
||||
|
||||
|
||||
def format_builder_text(data: dict[str, Any]) -> str:
    """Render the builder description: class name, usage hint and method signatures.

    An import path under ``data_designer.config.`` is shortened to the
    ``dd.<name>`` form; any other path is shown unchanged. A method's summary
    line is emitted only when it is non-empty.
    """
    prefix = "data_designer.config."
    import_path = data["import_path"]
    usage = import_path
    if import_path.startswith(prefix):
        usage = "dd." + import_path[len(prefix):]

    out: list[str] = [f"{data['class_name']}:", f" usage: {usage}", " methods:", ""]
    for method in data["methods"]:
        out.append(f" {method['signature']}")
        summary = method.get("summary")
        if summary:
            out.append(f" {summary}")
        out.append("")
    return "\n".join(out).rstrip()
|
||||
|
||||
|
||||
def format_model_aliases_text(state: dict[str, Any]) -> str:
    """Render the default-provider line followed by the model-alias table.

    A missing or falsy default provider is shown as ``(none)``. The
    ``effective_provider`` column is labelled ``provider``.
    """
    default_label = state.get("default_provider") or "(none)"
    table = _format_table(
        state.get("items", []),
        ["model_alias", "model", "generation_type", "effective_provider", "usable", "reason"],
        column_labels={"effective_provider": "provider"},
        title="model aliases",
    )
    return "\n".join([f"default_provider: {default_label}", "", table])
|
||||
|
||||
|
||||
def format_persona_datasets_text(state: dict[str, Any]) -> str:
    """Render the persona-dataset records as a ``locale`` / ``size`` / ``installed`` table."""
    records = state.get("items", [])
    shown_columns = ["locale", "size", "installed"]
    return _format_table(records, shown_columns, title="persona datasets")
|
||||
|
||||
|
||||
def _format_table(
|
||||
items: list[dict[str, Any]],
|
||||
columns: list[str],
|
||||
*,
|
||||
title: str | None = None,
|
||||
column_labels: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
labels = {col: (column_labels or {}).get(col, col) for col in columns}
|
||||
|
||||
if not items:
|
||||
header = f"# {title}" if title else "# table"
|
||||
return f"{header}\n(no items)"
|
||||
|
||||
col_widths = {col: max(len(labels[col]), max(len(_cell(row.get(col))) for row in items)) for col in columns}
|
||||
|
||||
lines: list[str] = []
|
||||
if title:
|
||||
lines.append(f"# {title}")
|
||||
lines.append("")
|
||||
lines.append(" ".join(f"{labels[col]:<{col_widths[col]}}" for col in columns))
|
||||
lines.append(" ".join("-" * col_widths[col] for col in columns))
|
||||
for row in items:
|
||||
lines.append(" ".join(f"{_cell(row.get(col)):<{col_widths[col]}}" for col in columns))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _cell(value: Any) -> str:
|
||||
"""Convert a cell value to a string, rendering None as empty."""
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from data_designer.cli.main import app
|
||||
|
||||
_PATCH = "data_designer.cli.commands.agent"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "args,data_fn,format_fn,expected_text",
    [
        (["agent", "context"], "get_context", "format_context_text", "Data Designer"),
        (["agent", "types", "columns"], "get_types", "format_types_text", "columns"),
        (["agent", "builder"], "get_builder_api", "format_builder_text", "Builder:"),
        (["agent", "state", "model-aliases"], "get_model_aliases_state", "format_model_aliases_text", "model aliases"),
        (
            ["agent", "state", "persona-datasets"],
            "get_persona_datasets_state",
            "format_persona_datasets_text",
            "persona",
        ),
    ],
    ids=["context", "types", "builder", "model-aliases", "persona-datasets"],
)
def test_commands_default_text_mode(args: list[str], data_fn: str, format_fn: str, expected_text: str) -> None:
    """Each agent command calls its data function once and prints the formatter's text.

    Both the data function and the formatter are patched in the command
    module, so this checks only the CLI wiring: exit code 0, the formatter's
    return value appears in the output, and the data function ran exactly once.
    """
    runner = CliRunner()
    with (
        # The stub payload is never inspected because the formatter is patched too.
        patch(f"{_PATCH}.{data_fn}", return_value={"stub": True}) as mock_get,
        patch(f"{_PATCH}.{format_fn}", return_value=expected_text),
    ):
        result = runner.invoke(app, args)

    assert result.exit_code == 0
    assert expected_text in result.output
    mock_get.assert_called_once()
|
||||
|
||||
|
||||
def test_schema_command_default_outputs_text() -> None:
    """``agent schema columns llm-text`` exits 0 and its output contains the type name."""
    stub_payload = {"type_name": "llm-text", "schema": {}}
    rendered = "# llm-text\n{}"
    with (
        patch(f"{_PATCH}.get_schema", return_value=stub_payload),
        patch(f"{_PATCH}.format_schema_text", return_value=rendered),
    ):
        outcome = CliRunner().invoke(app, ["agent", "schema", "columns", "llm-text"])

    assert outcome.exit_code == 0
    assert "llm-text" in outcome.output
|
||||
|
||||
|
||||
def test_error_outputs_message_to_stderr() -> None:
    """An unexpected exception from ``get_schema`` exits 1 with the error on stderr only."""
    # NOTE(review): reading ``.stderr`` assumes the CliRunner captures stderr
    # separately from stdout; confirm for the pinned click/typer versions.
    failing_get_schema = patch(f"{_PATCH}.get_schema", side_effect=ValueError("boom"))
    with failing_get_schema:
        outcome = CliRunner().invoke(app, ["agent", "schema", "columns", "missing"])

    assert outcome.exit_code == 1
    assert outcome.stdout == ""
    for fragment in ("internal_error", "boom"):
        assert fragment in outcome.stderr
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from data_designer.cli.utils.agent_introspection import (
|
||||
AgentIntrospectionError,
|
||||
discover_family_types,
|
||||
get_builder_api,
|
||||
get_family_catalog,
|
||||
get_family_schema,
|
||||
get_family_spec,
|
||||
get_operations,
|
||||
get_types,
|
||||
)
|
||||
|
||||
|
||||
def test_get_family_catalog_accepts_singular_family_names() -> None:
    """The singular spelling of a family name yields the same catalog as the plural."""
    singular_catalog = get_family_catalog("validator")
    plural_catalog = get_family_catalog("validators")
    assert singular_catalog == plural_catalog
|
||||
|
||||
|
||||
def test_get_family_catalog_returns_sorted_type_names() -> None:
    """The columns catalog is non-empty and ordered by ``type_name``."""
    type_names = [entry["type_name"] for entry in get_family_catalog("columns")]
    assert type_names
    assert type_names == sorted(type_names)
|
||||
|
||||
|
||||
def test_get_family_schema_returns_json_schema_payload() -> None:
    """The validator ``code`` payload carries the canonical family name, class info and JSON schema."""
    payload = get_family_schema("validator", "code")

    expected_fields = {
        "family": "validators",
        "type_name": "code",
        "class_name": "CodeValidatorParams",
        "import_path": "data_designer.config.CodeValidatorParams",
    }
    for key, expected in expected_fields.items():
        assert payload[key] == expected
    assert payload["schema"]["title"] == "CodeValidatorParams"
|
||||
|
||||
|
||||
def test_get_family_schema_includes_schema_text() -> None:
    """The ``llm-text`` payload's ``schema_text`` names the class, its fields and descriptions."""
    rendered = get_family_schema("columns", "llm-text")["schema_text"]

    assert rendered.startswith("LLMTextColumnConfig:")
    for fragment in ("column_type:", "name:", "Configuration for text generation", "Jinja2 template"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_family_schema_raises_for_unknown_type() -> None:
    """An unknown type raises ``unknown_type`` with the family and available types in ``details``."""
    with pytest.raises(AgentIntrospectionError) as raised:
        get_family_schema("validators", "does-not-exist")

    error = raised.value
    assert error.code == "unknown_type"
    assert error.details["family"] == "validators"
    assert "code" in error.details["available_types"]
|
||||
|
||||
|
||||
def test_discover_family_types_returns_pydantic_classes() -> None:
    """Every class discovered for the columns family exposes ``model_fields``."""
    discovered = discover_family_types("columns")

    assert discovered
    for model_cls in discovered.values():
        assert hasattr(model_cls, "model_fields")
|
||||
|
||||
|
||||
def test_get_family_spec_returns_discriminator_field() -> None:
    """The columns spec reports its name and the ``column_type`` discriminator field."""
    columns_spec = get_family_spec("columns")

    assert (columns_spec.name, columns_spec.discriminator_field) == ("columns", "column_type")
|
||||
|
||||
|
||||
def test_get_builder_api_includes_docstrings() -> None:
    """The builder API names the builder class and every method record has a ``docstring`` key."""
    api = get_builder_api()

    assert api["class_name"] == "DataDesignerConfigBuilder"
    assert api["import_path"] == "data_designer.config.DataDesignerConfigBuilder"
    method_records = api["methods"]
    assert method_records
    for record in method_records:
        assert "docstring" in record
|
||||
|
||||
|
||||
def test_get_types_returns_all_families_when_no_family_given() -> None:
    """With no family, ``get_types`` lists every family and has a catalog entry for each."""
    listing = get_types(None)

    for key in ("families", "items"):
        assert key in listing
    assert len(listing["families"]) > 0
    for summary in listing["families"]:
        assert summary["family"] in listing["items"]
|
||||
|
||||
|
||||
def test_get_types_returns_single_family() -> None:
    """With a family name, ``get_types`` returns that family and a non-empty list of items."""
    listing = get_types("columns")

    assert listing["family"] == "columns"
    catalog = listing["items"]
    assert isinstance(catalog, list)
    assert len(catalog) > 0
|
||||
|
||||
|
||||
def test_get_operations_returns_all_commands() -> None:
    """Six operations are reported, each with name, command pattern and description."""
    operations = get_operations()
    required_keys = ("name", "command_pattern", "description")

    assert len(operations) == 6
    for operation in operations:
        assert all(key in operation for key in required_keys)
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
|
||||
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
|
||||
from data_designer.cli.utils.agent_introspection import (
|
||||
get_context,
|
||||
get_model_aliases_state,
|
||||
get_persona_datasets_state,
|
||||
)
|
||||
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
|
||||
|
||||
|
||||
def test_get_model_aliases_state_reports_provider_status(tmp_path: Path) -> None:
    """Model aliases are reported with their effective provider and usability.

    Two providers and three model configs are saved under ``tmp_path``. The
    three aliases cover: no provider configured (uses the default), a
    provider the test expects to be flagged as missing its API key, and a
    provider name that is not in the registry.
    """
    provider_repository = ProviderRepository(tmp_path)
    provider_repository.save(
        ModelProviderRegistry(
            providers=[
                ModelProvider(
                    name="provider-a",
                    endpoint="https://api.example.com/a",
                    provider_type="openai",
                    api_key="test-api-key",
                ),
                ModelProvider(
                    name="provider-b",
                    endpoint="https://api.example.com/b",
                    provider_type="openai",
                    # NOTE(review): presumably read as the name of an unset
                    # environment variable; confirm against
                    # get_providers_with_missing_api_keys.
                    api_key="MISSING_PROVIDER_KEY",
                ),
            ],
            default="provider-a",
        )
    )

    model_repository = ModelRepository(tmp_path)
    model_repository.save(
        ModelConfigRegistry(
            model_configs=[
                ModelConfig(
                    alias="alpha",
                    model="model-alpha",
                    provider=None,
                    inference_parameters=ChatCompletionInferenceParams(),
                ),
                ModelConfig(
                    alias="beta",
                    model="model-beta",
                    provider="provider-b",
                    inference_parameters=ChatCompletionInferenceParams(),
                ),
                ModelConfig(
                    alias="gamma",
                    model="model-gamma",
                    provider="provider-missing",
                    inference_parameters=ChatCompletionInferenceParams(),
                ),
            ]
        )
    )

    payload = get_model_aliases_state(tmp_path)

    assert payload["model_config_present"] is True
    assert payload["provider_config_present"] is True
    assert payload["default_provider"] == "provider-a"
    # Items are expected in alias order: alpha, beta, gamma.
    assert payload["items"] == [
        {
            "model_alias": "alpha",
            "model": "model-alpha",
            "generation_type": "chat-completion",
            "configured_provider": None,
            "effective_provider": "provider-a",
            "usable": True,
            "reason": None,
        },
        {
            "model_alias": "beta",
            "model": "model-beta",
            "generation_type": "chat-completion",
            "configured_provider": "provider-b",
            "effective_provider": "provider-b",
            "usable": False,
            "reason": "Provider 'provider-b' is missing an API key.",
        },
        {
            "model_alias": "gamma",
            "model": "model-gamma",
            "generation_type": "chat-completion",
            "configured_provider": "provider-missing",
            "effective_provider": "provider-missing",
            "usable": False,
            "reason": "Provider 'provider-missing' is not configured.",
        },
    ]
|
||||
|
||||
|
||||
def test_get_model_aliases_state_handles_missing_local_files(tmp_path: Path) -> None:
    """An empty config directory yields an all-absent state with no items."""
    expected_state = {
        "model_config_present": False,
        "provider_config_present": False,
        "default_provider": None,
        "items": [],
    }

    assert get_model_aliases_state(tmp_path) == expected_state
|
||||
|
||||
|
||||
def test_get_persona_datasets_state_reports_installed_locales(tmp_path: Path) -> None:
    """A locale with a parquet file in the managed-assets directory is reported as installed."""
    datasets_dir = tmp_path.joinpath("managed-assets", "datasets")
    datasets_dir.mkdir(parents=True)
    datasets_dir.joinpath("en_US.parquet").write_text("stub")

    state = get_persona_datasets_state(tmp_path)
    records = state["items"]

    assert state["managed_assets_directory"] == str(datasets_dir)
    installed_by_locale = {record["locale"]: record["installed"] for record in records}
    assert installed_by_locale["en_US"] is True
    # At least one other locale must be reported as not installed.
    assert any(not record["installed"] for record in records if record["locale"] != "en_US")
|
||||
|
||||
|
||||
def test_get_context_returns_self_describing_payload(tmp_path: Path) -> None:
    """The context payload lists the six operations in order plus families, types and builder methods."""
    context = get_context(tmp_path)

    expected_names = [
        "context",
        "types",
        "schema",
        "builder",
        "state.model-aliases",
        "state.persona-datasets",
    ]
    actual_names = [entry["name"] for entry in context["operations"]]
    assert actual_names == expected_names
    assert context["families"]
    assert "columns" in context["types"]
    assert context["builder"]["methods"]
|
||||
|
|
@ -0,0 +1,223 @@
|
|||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from data_designer.cli.utils.agent_introspection import get_family_schema
|
||||
from data_designer.cli.utils.agent_text_formatter import (
|
||||
format_builder_text,
|
||||
format_context_text,
|
||||
format_model_aliases_text,
|
||||
format_persona_datasets_text,
|
||||
format_schema_text,
|
||||
format_types_text,
|
||||
)
|
||||
|
||||
# --- format_context_text ---
|
||||
|
||||
|
||||
def test_format_context_text_includes_builder_section() -> None:
    """The rendered context contains a Builder section with the class name and method signature."""
    payload: dict[str, Any] = {
        "families": [{"family": "columns", "count": 1}],
        "types": {
            "columns": [{"type_name": "a", "class_name": "A", "import_path": "m.A"}],
        },
        "state": {
            "model_aliases": {"default_provider": None, "items": []},
            "persona_datasets": {"items": []},
        },
        "builder": {
            "class_name": "DataDesignerConfigBuilder",
            "import_path": "data_designer.config.DataDesignerConfigBuilder",
            "methods": [{"name": "add_column", "signature": "add_column(col)", "summary": "Add a column."}],
        },
        "operations": [{"command_pattern": "agent context", "description": "Bootstrap payload."}],
    }

    rendered = format_context_text(payload)

    for fragment in ("## Builder", "DataDesignerConfigBuilder:", "add_column(col)"):
        assert fragment in rendered
|
||||
|
||||
|
||||
# --- format_types_text ---
|
||||
|
||||
|
||||
def test_format_types_text_single_family() -> None:
    """A single-family payload renders a titled table containing its type and class names."""
    payload: dict[str, Any] = {
        "family": "columns",
        "items": [
            {"type_name": "alpha", "class_name": "AlphaConfig", "import_path": "mod.AlphaConfig"},
            {"type_name": "beta", "class_name": "BetaConfig", "import_path": "mod.BetaConfig"},
        ],
    }

    rendered = format_types_text(payload)

    for fragment in ("# columns", "alpha", "AlphaConfig"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_types_text_all_families() -> None:
    """An all-families payload renders the per-family count line and the type rows."""
    payload: dict[str, Any] = {
        "families": [{"family": "columns", "count": 2}],
        "items": {
            "columns": [
                {"type_name": "a", "class_name": "A", "import_path": "m.A"},
                {"type_name": "b", "class_name": "B", "import_path": "m.B"},
            ],
        },
    }

    rendered = format_types_text(payload)

    for fragment in ("columns: 2 types", "a", "b"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_types_text_empty_items() -> None:
    """A family with no types renders the ``(no items)`` placeholder."""
    empty_family: dict[str, Any] = {"family": "columns", "items": []}

    assert "(no items)" in format_types_text(empty_family)
|
||||
|
||||
|
||||
# --- format_schema_text ---
|
||||
|
||||
|
||||
def test_format_schema_text_single_type() -> None:
    """A single-type payload renders its class heading and field line."""
    payload: dict[str, Any] = {
        "type_name": "llm-text",
        "class_name": "LLMTextColumnConfig",
        "schema_text": "LLMTextColumnConfig:\n name: str [required]",
    }

    rendered = format_schema_text(payload)

    for fragment in ("LLMTextColumnConfig:", "name: str [required]"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_schema_text_all_types() -> None:
    """A multi-type payload renders a counted header followed by every item's schema text."""
    payload: dict[str, Any] = {
        "family": "columns",
        "items": [
            {"type_name": "a", "class_name": "A", "schema_text": "A:\n x: int [required]"},
            {"type_name": "b", "class_name": "B", "schema_text": "B:\n y: str = 'hi'"},
        ],
    }

    rendered = format_schema_text(payload)

    for fragment in ("# columns schemas (2 types)", "A:\n x: int [required]", "B:\n y: str = 'hi'"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_schema_text_passes_through_schema_text() -> None:
    """For a single type, the formatter returns ``schema_text`` unchanged."""
    original_text = "TestModel:\n name: str [required]\n count: int = 0"
    payload: dict[str, Any] = {"type_name": "test", "class_name": "TestModel", "schema_text": original_text}

    assert format_schema_text(payload) == original_text
|
||||
|
||||
|
||||
# --- format_builder_text ---
|
||||
|
||||
|
||||
def test_format_builder_text_renders_methods() -> None:
    """The builder text shows the class heading, the ``dd.`` usage hint, signatures and summaries."""
    payload: dict[str, Any] = {
        "class_name": "MyBuilder",
        "import_path": "data_designer.config.MyBuilder",
        "methods": [
            {"name": "add_column", "signature": "add_column(column: ColumnConfig)", "summary": "Add a column."},
            {"name": "build", "signature": "build()", "summary": "Build the config."},
        ],
    }

    rendered = format_builder_text(payload)

    expected_fragments = (
        "MyBuilder:",
        "usage: dd.MyBuilder",
        "add_column(column: ColumnConfig)",
        "Add a column.",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_builder_text_handles_method_without_summary() -> None:
    """A method whose summary is ``None`` still has its signature rendered."""
    payload: dict[str, Any] = {
        "class_name": "Builder",
        "import_path": "mod.Builder",
        "methods": [{"name": "reset", "signature": "reset()", "summary": None}],
    }

    assert "reset()" in format_builder_text(payload)
|
||||
|
||||
|
||||
# --- format_model_aliases_text ---
|
||||
|
||||
|
||||
def test_format_model_aliases_text_with_items() -> None:
    """The model-alias text shows the default provider line and the alias row."""
    alias_row = {
        "model_alias": "test",
        "model": "meta/llama-3",
        "generation_type": "chat",
        "effective_provider": "nvidia",
        "usable": True,
        "reason": None,
    }
    aliases_state: dict[str, Any] = {"default_provider": "nvidia", "items": [alias_row]}

    rendered = format_model_aliases_text(aliases_state)

    for fragment in ("default_provider: nvidia", "test", "meta/llama-3"):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_format_model_aliases_text_empty() -> None:
    """With no default provider and no aliases, both placeholders are rendered."""
    empty_state: dict[str, Any] = {"default_provider": None, "items": []}

    rendered = format_model_aliases_text(empty_state)

    for fragment in ("default_provider: (none)", "(no items)"):
        assert fragment in rendered
|
||||
|
||||
|
||||
# --- format_persona_datasets_text ---
|
||||
|
||||
|
||||
def test_format_persona_datasets_text() -> None:
    """The persona table has its title, the locale code and the installed flag rendered as text."""
    datasets_state: dict[str, Any] = {
        "items": [{"locale": "en_US", "size": "10MB", "installed": True}],
    }

    rendered = format_persona_datasets_text(datasets_state)

    for fragment in ("# persona datasets", "en_US", "True"):
        assert fragment in rendered
|
||||
|
||||
|
||||
# --- Real config models ---
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "family,type_name",
    [
        ("columns", "llm-text"),
        ("columns", "sampler"),
        ("samplers", "category"),
        ("validators", "code"),
        ("constraints", "scalar_inequality"),
    ],
    ids=["columns-llm-text", "columns-sampler", "samplers-category", "validators-code", "constraints-scalar"],
)
def test_format_schema_text_on_real_config_models(family: str, type_name: str) -> None:
    """For real config models, the formatted schema equals ``schema_text`` and names the class."""
    payload = get_family_schema(family, type_name)
    rendered = format_schema_text(payload)

    assert payload["class_name"] in rendered
    assert rendered == payload["schema_text"]
|
||||
Loading…
Reference in a new issue