feat: add seed dataset support with configuration, preview, and builder utilities

This commit is contained in:
Shine1i 2026-02-14 18:44:38 +01:00
parent 2bd20d7d15
commit f2a00d6e44
25 changed files with 1035 additions and 34 deletions

View file

@ -4,6 +4,8 @@
"ignore": [
"dist",
"node_modules",
"test",
"test/**",
"**/._*",
"._*",
"**/.DS_Store",

View file

@ -28,6 +28,19 @@ export type ToolsResponse = {
tools: string[];
};
/**
 * Response body of the `/seed/inspect` endpoint (see `inspectSeedDataset`).
 * Field names are snake_case because they mirror the API schema.
 */
export type SeedInspectResponse = {
// Dataset repository id as reported by the backend.
// biome-ignore lint/style/useNamingConvention: api schema
repo_id: string;
// Names of the splits reported for the dataset.
splits: string[];
// Maps a split name to a path string for that split.
// biome-ignore lint/style/useNamingConvention: api schema
globs_by_split: Record<string, string>;
// Column names reported for the dataset.
columns: string[];
};
/**
 * Response body of the `/seed/preview` endpoint (see `previewSeedDataset`).
 */
export type SeedPreviewResponse = {
// Sample rows, each keyed by column name; cell values are untyped.
rows: Record<string, unknown>[];
};
async function parseErrorResponse(response: Response): Promise<string> {
const text = (await response.text()).trim();
if (!text) {
@ -80,3 +93,11 @@ export async function validateRecipe(
export async function listRecipeTools(payload: unknown): Promise<ToolsResponse> {
return postJson<ToolsResponse>("/tools", payload);
}
/**
 * POSTs `payload` to `/seed/inspect` and resolves with the parsed response.
 *
 * @param payload - Request body forwarded as-is to `postJson`.
 * @returns The inspect response (repo id, splits, per-split paths, columns).
 */
export async function inspectSeedDataset(payload: unknown): Promise<SeedInspectResponse> {
  const inspection = await postJson<SeedInspectResponse>("/seed/inspect", payload);
  return inspection;
}
/**
 * POSTs `payload` to `/seed/preview` and resolves with the parsed response.
 *
 * @param payload - Request body forwarded as-is to `postJson`.
 * @returns The preview response containing sample rows.
 */
export async function previewSeedDataset(payload: unknown): Promise<SeedPreviewResponse> {
  const preview = await postJson<SeedPreviewResponse>("/seed/preview", payload);
  return preview;
}

View file

@ -23,11 +23,13 @@ import {
makeModelConfig,
makeModelProviderConfig,
makeSamplerConfig,
makeSeedConfig,
} from "../utils";
import { ExpressionDialog } from "../dialogs/expression/expression-dialog";
import { LlmDialog } from "../dialogs/llm/llm-dialog";
import { ModelConfigDialog } from "../dialogs/models/model-config-dialog";
import { ModelProviderDialog } from "../dialogs/models/model-provider-dialog";
import { SeedDialog } from "../dialogs/seed/seed-dialog";
import { CategoryDialog } from "../dialogs/samplers/category-dialog";
import { DatetimeDialog } from "../dialogs/samplers/datetime-dialog";
import { BernoulliDialog } from "../dialogs/samplers/bernoulli-dialog";
@ -38,11 +40,12 @@ import { TimedeltaDialog } from "../dialogs/samplers/timedelta-dialog";
import { UniformDialog } from "../dialogs/samplers/uniform-dialog";
import { UuidDialog } from "../dialogs/samplers/uuid-dialog";
export type BlockKind = "sampler" | "llm" | "expression";
export type BlockKind = "sampler" | "llm" | "expression" | "seed";
export type BlockType =
| SamplerType
| LlmType
| "expression"
| "seed"
| "model_provider"
| "model_config";
@ -81,6 +84,12 @@ export const BLOCK_GROUPS: BlockGroup[] = [
description: "Numeric + categorical blocks.",
icon: DiceFaces03Icon,
},
{
kind: "seed",
title: "Seed",
description: "Columns from a seed dataset.",
icon: Plant01Icon,
},
{
kind: "llm",
title: "LLM",
@ -96,6 +105,21 @@ export const BLOCK_GROUPS: BlockGroup[] = [
];
const BLOCK_DEFINITIONS: BlockDefinition[] = [
{
kind: "seed",
type: "seed",
title: "Seed (Hugging Face)",
description: "Configure a HF seed dataset.",
icon: Plant01Icon,
createConfig: (id, existing) => makeSeedConfig(id, existing),
renderDialog: ({ config, onUpdate }) =>
config.kind === "seed" ? (
<SeedDialog
config={config}
onUpdate={(patch) => onUpdate(config.id, patch)}
/>
) : null,
},
{
kind: "sampler",
type: "category",
@ -372,6 +396,9 @@ export function getBlockDefinitionForConfig(
if (!config) {
return null;
}
if (config.kind === "seed") {
return getBlockDefinition("seed", "seed");
}
if (config.kind === "sampler") {
const samplerType =
config.sampler_type === "person_from_faker"

View file

@ -22,8 +22,8 @@ import { RECIPE_FLOATING_ICON_BUTTON_CLASS } from "./recipe-floating-icon-button
import type { LlmType, SamplerType } from "../types";
import { BLOCK_GROUPS, getBlocksForKind } from "../blocks/registry";
type SheetView = "root" | "sampler" | "llm" | "expression" | "processor";
type SheetKind = "sampler" | "llm" | "expression";
type SheetView = "root" | "sampler" | "seed" | "llm" | "expression" | "processor";
type SheetKind = "sampler" | "seed" | "llm" | "expression";
type RootSheetView = Exclude<SheetView, "root">;
type RootGroup = {
kind: RootSheetView;
@ -37,6 +37,7 @@ type BlockSheetProps = {
sheetView: SheetView;
onViewChange: (sheetView: SheetView) => void;
onAddSampler: (type: SamplerType) => void;
onAddSeed: () => void;
onAddLlm: (type: LlmType) => void;
onAddModelProvider: () => void;
onAddModelConfig: () => void;
@ -54,6 +55,9 @@ function getSheetTitle(sheetView: SheetView): string {
if (sheetView === "sampler") {
return "Sampler blocks";
}
if (sheetView === "seed") {
return "Seed blocks";
}
if (sheetView === "expression") {
return "Expression blocks";
}
@ -66,6 +70,7 @@ function getSheetTitle(sheetView: SheetView): string {
const VIEW_KIND: Record<SheetView, SheetKind | null> = {
root: null,
sampler: "sampler",
seed: "seed",
llm: "llm",
expression: "expression",
processor: null,
@ -124,6 +129,7 @@ export function BlockSheet({
sheetView,
onViewChange,
onAddSampler,
onAddSeed,
onAddLlm,
onAddModelProvider,
onAddModelConfig,
@ -136,6 +142,7 @@ export function BlockSheet({
const sheetTitle = getSheetTitle(sheetView);
const [open, setOpen] = useState(false);
const expressionBlocks = useMemo(() => getBlocksForKind("expression"), []);
const seedBlocks = useMemo(() => getBlocksForKind("seed"), []);
return (
<div className="flex flex-col items-end gap-2">
<Sheet
@ -198,6 +205,11 @@ export function BlockSheet({
onOpenProcessors();
return;
}
if (item.kind === "seed" && seedBlocks.length === 1) {
setOpen(false);
onAddSeed();
return;
}
if (item.kind === "expression" && expressionBlocks.length === 1) {
setOpen(false);
onAddExpression();
@ -229,6 +241,8 @@ export function BlockSheet({
onClick={() => {
if (item.kind === "sampler") {
onAddSampler(item.type as SamplerType);
} else if (item.kind === "seed") {
onAddSeed();
} else if (item.kind === "llm") {
if (item.type === "model_provider") {
onAddModelProvider();

View file

@ -15,6 +15,7 @@ import {
FunctionIcon,
Parabola02Icon,
PencilEdit02Icon,
Plant01Icon,
Tag01Icon,
TagsIcon,
UserAccountIcon,
@ -76,6 +77,7 @@ function getJinjaContext(
function getItemIcon(item: AvailableRefItem) {
if (item.kind === "expression") return FunctionIcon;
if (item.kind === "seed") return Plant01Icon;
if (item.kind === "llm") {
if (item.subtype === "structured") return CodeIcon;
if (item.subtype === "code") return CodeSimpleIcon;

View file

@ -61,6 +61,9 @@ const NODE_META = {
expression: {
tone: "bg-indigo-50 text-indigo-600 border-indigo-100",
},
seed: {
tone: "bg-lime-50 text-lime-700 border-lime-100",
},
model_provider: {
tone: "bg-amber-50 text-amber-600 border-amber-100",
},
@ -108,6 +111,9 @@ function resolveNodeIcon(
if (kind === "model_config") {
return Plant01Icon;
}
if (kind === "seed") {
return Plant01Icon;
}
return DiceFaces03Icon;
}
@ -163,6 +169,13 @@ function getConfigSummary(config: NodeConfig | undefined): string {
return "Prompt/system via linked input nodes";
}
if (config.kind === "seed") {
if (config.hf_path.trim()) {
return config.hf_path.trim();
}
return "Set HF dataset path";
}
return "Open details for config";
}
@ -292,7 +305,9 @@ function RecipeGraphNodeBase({
}, [id, layoutDirection, config, updateNodeInternals]);
const showDataHandles =
data.kind === "llm" || data.kind === "expression" || data.kind === "sampler";
data.kind === "llm" ||
data.kind === "expression" ||
data.kind === "sampler";
const showSemanticIn = data.kind === "llm" || data.kind === "model_config";
const showSemanticOut = data.kind === "model_config" || data.kind === "model_provider";
const isTopBottom = layoutDirection === "TB";

View file

@ -2,8 +2,8 @@ import { Button } from "@/components/ui/button";
import { Dialog, DialogContent, DialogFooter } from "@/components/ui/dialog";
import { Switch } from "@/components/ui/switch";
import type { ReactElement } from "react";
import type { NodeConfig, SamplerConfig } from "../types";
import { renderBlockDialog } from "../blocks/registry";
import type { NodeConfig, SamplerConfig } from "../types";
import { DialogShell } from "./shared/dialog-shell";
import { ValidationBanner } from "./shared/validation-banner";
@ -50,7 +50,8 @@ export function ConfigDialog({
<ValidationBanner config={config} />
{(config.kind === "sampler" ||
config.kind === "llm" ||
config.kind === "expression") && (
config.kind === "expression" ||
config.kind === "seed") && (
<div className="flex items-center corner-squircle justify-between gap-3 rounded-2xl border border-border/60 px-3 py-2">
<div>
<p className="text-sm font-semibold">Drop from final dataset</p>

View file

@ -0,0 +1,486 @@
import { Button } from "@/components/ui/button";
import {
Empty,
EmptyContent,
EmptyDescription,
EmptyHeader,
EmptyTitle,
} from "@/components/ui/empty";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Spinner } from "@/components/ui/spinner";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import {
Tabs,
TabsContent,
TabsList,
TabsTrigger,
} from "@/components/ui/tabs";
import { type ReactElement, useMemo, useState } from "react";
import { inspectSeedDataset, previewSeedDataset } from "../../api";
import type {
SeedConfig,
SeedSamplingStrategy,
SeedSelectionType,
} from "../../types";
// Choices rendered in the "Sampling strategy" select; `value` is stored in
// `config.sampling_strategy`, `label` is the text shown to the user.
const SAMPLING_OPTIONS: Array<{ value: SeedSamplingStrategy; label: string }> = [
{ value: "ordered", label: "Ordered" },
{ value: "shuffle", label: "Shuffle" },
];
// Choices rendered in the "Selection strategy" select; `value` is stored in
// `config.selection_type` and decides which extra numeric inputs are shown.
const SELECTION_OPTIONS: Array<{ value: SeedSelectionType; label: string }> = [
{ value: "none", label: "None" },
{ value: "index_range", label: "Index range" },
{ value: "partition_block", label: "Partition block" },
];
// Props of `SeedDialog`: the seed config being edited and a callback that
// receives partial patches to apply to it.
type SeedDialogProps = {
config: SeedConfig;
onUpdate: (patch: Partial<SeedConfig>) => void;
};
/**
 * Extracts an `org/repo` dataset id from user input.
 *
 * Accepted forms:
 * - the bare shorthand `org/repo` (both segments must be non-empty);
 * - a URL whose path contains `/datasets/<org>/<repo>`, with or without a
 *   scheme (e.g. `huggingface.co/datasets/org/repo`); anything after the repo
 *   segment is ignored.
 *
 * @param input - Raw text typed or pasted by the user.
 * @returns The `org/repo` id, or `null` when the input cannot be interpreted.
 */
function parseHfDatasetRepoId(input: string): string | null {
  const raw = input.trim();
  if (!raw) return null;

  const hasScheme = raw.includes("://");
  if (!hasScheme) {
    const segments = raw.split("/");
    if (segments.length === 2) {
      // Reject "org/", "/repo" and "/" instead of passing them on as ids.
      const [org, repo] = segments;
      return org && repo ? raw : null;
    }
  }

  try {
    // Scheme-less pastes are parsed as https URLs so the same path logic applies.
    const url = new URL(hasScheme ? raw : `https://${raw}`);
    const parts = url.pathname.split("/").filter(Boolean);
    const datasetsIdx = parts.indexOf("datasets");
    if (datasetsIdx === -1) return null;
    const org = parts[datasetsIdx + 1];
    const repo = parts[datasetsIdx + 2];
    if (!org || !repo) return null;
    return `${org}/${repo}`;
  } catch {
    // `new URL` throws on input that is not a parseable URL.
    return null;
  }
}
/**
 * Converts an arbitrary preview cell value into display text.
 *
 * - `null` / `undefined` become an empty string.
 * - Strings are returned unchanged; numbers and booleans use `String()`.
 * - Everything else is JSON-encoded. Values JSON cannot represent fall back to
 *   `String(value)`: circular structures and bigints (where `JSON.stringify`
 *   throws) and symbols or functions (where it returns `undefined`).
 *
 * @param value - Cell value taken from a preview row.
 * @returns A string that is always safe to render.
 */
function stringifyCell(value: unknown): string {
  if (value === null || value === undefined) return "";
  if (typeof value === "string") return value;
  if (typeof value === "number" || typeof value === "boolean") return String(value);
  try {
    // JSON.stringify yields undefined (not a string) for symbols and functions.
    const json: string | undefined = JSON.stringify(value);
    return json ?? String(value);
  } catch {
    return String(value);
  }
}
/**
 * Parses an optional integer typed into a text input.
 *
 * @param value - Raw input text; may be `undefined`.
 * @returns The integer value, or `null` when the input is missing, blank, or
 *   not an integer (e.g. `"abc"`, `"1.5"`, `"Infinity"`).
 */
function parseOptionalInt(value: string | undefined): number | null {
  const trimmed = value?.trim();
  if (!trimmed) return null;
  const num = Number(trimmed);
  // Number.isInteger also rejects NaN and +/-Infinity.
  return Number.isInteger(num) ? num : null;
}
/**
 * Dialog body for configuring a Hugging Face seed dataset block.
 *
 * The "Config" tab edits the dataset URL, split, path, token, sampling strategy
 * and selection strategy. Its "Load" button calls `inspectSeedDataset` and
 * stores the returned repo id, splits, per-split paths and columns on the
 * config. The "Preview" tab calls `previewSeedDataset` for up to 10 rows and
 * renders them in a table.
 *
 * The token field is rendered as a password input so the secret is not shown
 * in clear text, and every numeric selection input is associated with its
 * label through `htmlFor` / `id`.
 *
 * @param config - Seed configuration currently being edited.
 * @param onUpdate - Called with a partial patch whenever a field changes or a
 *   network call produces new metadata.
 */
export function SeedDialog({ config, onUpdate }: SeedDialogProps): ReactElement {
  const [inspectLoading, setInspectLoading] = useState(false);
  const [inspectError, setInspectError] = useState<string | null>(null);
  const [previewLoading, setPreviewLoading] = useState(false);
  const [previewError, setPreviewError] = useState<string | null>(null);
  const [previewRows, setPreviewRows] = useState<Record<string, unknown>[]>([]);

  // Repo id derived from the URL field; null while the input is not parseable.
  const repoId = useMemo(
    () => parseHfDatasetRepoId(config.hf_url ?? ""),
    [config.hf_url],
  );

  // Element ids are prefixed with the config id so labels stay unique per node.
  const pathId = `${config.id}-hf-path`;
  const urlId = `${config.id}-hf-url`;
  const tokenId = `${config.id}-hf-token`;
  const splitId = `${config.id}-hf-split`;
  const samplingId = `${config.id}-sampling`;
  const selectionId = `${config.id}-selection`;
  const selectionStartId = `${config.id}-selection-start`;
  const selectionEndId = `${config.id}-selection-end`;
  const selectionIndexId = `${config.id}-selection-index`;
  const selectionPartitionsId = `${config.id}-selection-num-partitions`;

  /** Fetches dataset metadata for the parsed repo id and patches the config. */
  async function onInspect(): Promise<void> {
    setInspectError(null);
    const repo_id = repoId;
    if (!repo_id) {
      setInspectError("Invalid HF dataset URL (need /datasets/org/repo).");
      return;
    }
    setInspectLoading(true);
    try {
      const res = await inspectSeedDataset({
        // biome-ignore lint/style/useNamingConvention: api schema
        repo_id,
        // biome-ignore lint/style/useNamingConvention: api schema
        hf_token: config.hf_token?.trim() || null,
        split: config.hf_split?.trim() || null,
      });
      const splits = res.splits ?? [];
      const globs = res.globs_by_split ?? {};

      // Keep the current split when it still exists, otherwise use the first one.
      let nextSplit = "";
      if (config.hf_split && splits.includes(config.hf_split)) {
        nextSplit = config.hf_split;
      } else if (splits[0]) {
        nextSplit = splits[0];
      }
      const nextPath = nextSplit ? (globs[nextSplit] ?? "") : "";

      onUpdate({
        hf_repo_id: res.repo_id,
        seed_splits: splits,
        seed_globs_by_split: globs,
        seed_columns: res.columns ?? [],
        hf_split: nextSplit,
        // Do not wipe a manually entered path when the backend has none.
        hf_path: nextPath || config.hf_path,
      });
    } catch (err) {
      setInspectError(err instanceof Error ? err.message : "Inspect failed.");
    } finally {
      setInspectLoading(false);
    }
  }

  /** Requests up to 10 sample rows for the configured path and stores them. */
  async function onPreview(): Promise<void> {
    setPreviewError(null);
    const hf_path = config.hf_path.trim();
    if (!hf_path) {
      setPreviewError("HF path missing (Load first).");
      return;
    }
    setPreviewLoading(true);
    try {
      const res = await previewSeedDataset({
        // biome-ignore lint/style/useNamingConvention: api schema
        hf_path,
        // biome-ignore lint/style/useNamingConvention: api schema
        hf_token: config.hf_token?.trim() || null,
        // biome-ignore lint/style/useNamingConvention: api schema
        sampling_strategy: config.sampling_strategy,
        // biome-ignore lint/style/useNamingConvention: api schema
        selection_type: config.selection_type,
        // biome-ignore lint/style/useNamingConvention: api schema
        selection_start: parseOptionalInt(config.selection_start),
        // biome-ignore lint/style/useNamingConvention: api schema
        selection_end: parseOptionalInt(config.selection_end),
        // biome-ignore lint/style/useNamingConvention: api schema
        selection_index: parseOptionalInt(config.selection_index),
        // biome-ignore lint/style/useNamingConvention: api schema
        selection_num_partitions: parseOptionalInt(config.selection_num_partitions),
        limit: 10,
      });
      const rows = res.rows ?? [];
      setPreviewRows(rows);
      // Backfill the column list from the first row when inspect has not run.
      if ((config.seed_columns?.length ?? 0) === 0 && rows[0]) {
        onUpdate({ seed_columns: Object.keys(rows[0]) });
      }
    } catch (err) {
      setPreviewError(err instanceof Error ? err.message : "Preview failed.");
    } finally {
      setPreviewLoading(false);
    }
  }

  // Prefer the stored column list; fall back to the keys of the first preview row.
  const previewColumns = useMemo(() => {
    const cols = config.seed_columns ?? [];
    if (cols.length > 0) return cols;
    if (previewRows[0]) return Object.keys(previewRows[0]);
    return [];
  }, [config.seed_columns, previewRows]);

  return (
    <Tabs defaultValue="config" className="w-full">
      <TabsList className="w-full">
        <TabsTrigger value="config">Config</TabsTrigger>
        <TabsTrigger value="preview">Preview</TabsTrigger>
      </TabsList>
      <TabsContent value="config" className="pt-3">
        <div className="space-y-4">
          <div className="grid gap-2">
            <label
              className="text-xs font-semibold uppercase text-muted-foreground"
              htmlFor={urlId}
            >
              HF dataset URL
            </label>
            <div className="flex items-center gap-2">
              <Input
                id={urlId}
                className="nodrag flex-1"
                placeholder="https://huggingface.co/datasets/org/repo"
                value={config.hf_url ?? ""}
                onChange={(e) => onUpdate({ hf_url: e.target.value })}
              />
              <Button
                type="button"
                variant="outline"
                className="nodrag"
                onClick={() => void onInspect()}
                disabled={inspectLoading}
              >
                {inspectLoading ? "Loading..." : "Load"}
              </Button>
              {inspectLoading && <Spinner />}
            </div>
            <p className="text-xs text-muted-foreground">
              Repo: {repoId ?? "-"}
            </p>
          </div>
          {inspectError && (
            <p className="text-xs text-red-600">{inspectError}</p>
          )}
          {(config.seed_splits?.length ?? 0) > 0 && (
            <div className="grid gap-2">
              <label
                className="text-xs font-semibold uppercase text-muted-foreground"
                htmlFor={splitId}
              >
                Split
              </label>
              <Select
                value={config.hf_split ?? ""}
                onValueChange={(value) => {
                  const nextPath = config.seed_globs_by_split?.[value] ?? "";
                  onUpdate({ hf_split: value, hf_path: nextPath || config.hf_path });
                }}
              >
                <SelectTrigger className="nodrag w-full" id={splitId}>
                  <SelectValue placeholder="Select split" />
                </SelectTrigger>
                <SelectContent>
                  {(config.seed_splits ?? []).map((value) => (
                    <SelectItem key={value} value={value}>
                      {value}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
          )}
          <div className="grid gap-2">
            <label
              className="text-xs font-semibold uppercase text-muted-foreground"
              htmlFor={pathId}
            >
              HF path (auto)
            </label>
            <Input
              id={pathId}
              className="nodrag"
              placeholder="datasets/org/repo/data/train-*.parquet"
              value={config.hf_path}
              onChange={(e) => onUpdate({ hf_path: e.target.value })}
            />
          </div>
          <div className="grid gap-2">
            <label
              className="text-xs font-semibold uppercase text-muted-foreground"
              htmlFor={tokenId}
            >
              HF token (optional)
            </label>
            {/* Masked so the access token is not displayed in clear text. */}
            <Input
              id={tokenId}
              type="password"
              autoComplete="off"
              className="nodrag"
              placeholder="hf_..."
              value={config.hf_token ?? ""}
              onChange={(e) => onUpdate({ hf_token: e.target.value })}
            />
          </div>
          <div className="grid gap-2">
            <label
              className="text-xs font-semibold uppercase text-muted-foreground"
              htmlFor={samplingId}
            >
              Sampling strategy
            </label>
            <Select
              value={config.sampling_strategy}
              onValueChange={(value) =>
                onUpdate({ sampling_strategy: value as SeedSamplingStrategy })
              }
            >
              <SelectTrigger className="nodrag w-full" id={samplingId}>
                <SelectValue placeholder="Select sampling" />
              </SelectTrigger>
              <SelectContent>
                {SAMPLING_OPTIONS.map((opt) => (
                  <SelectItem key={opt.value} value={opt.value}>
                    {opt.label}
                  </SelectItem>
                ))}
              </SelectContent>
            </Select>
          </div>
          <div className="grid gap-2">
            <label
              className="text-xs font-semibold uppercase text-muted-foreground"
              htmlFor={selectionId}
            >
              Selection strategy
            </label>
            <Select
              value={config.selection_type}
              onValueChange={(value) =>
                onUpdate({ selection_type: value as SeedSelectionType })
              }
            >
              <SelectTrigger className="nodrag w-full" id={selectionId}>
                <SelectValue placeholder="Select selection" />
              </SelectTrigger>
              <SelectContent>
                {SELECTION_OPTIONS.map((opt) => (
                  <SelectItem key={opt.value} value={opt.value}>
                    {opt.label}
                  </SelectItem>
                ))}
              </SelectContent>
            </Select>
          </div>
          {config.selection_type === "index_range" && (
            <div className="grid grid-cols-2 gap-3">
              <div className="grid gap-2">
                <label
                  className="text-xs font-semibold uppercase text-muted-foreground"
                  htmlFor={selectionStartId}
                >
                  Start
                </label>
                <Input
                  id={selectionStartId}
                  className="nodrag"
                  inputMode="numeric"
                  value={config.selection_start ?? ""}
                  onChange={(e) => onUpdate({ selection_start: e.target.value })}
                />
              </div>
              <div className="grid gap-2">
                <label
                  className="text-xs font-semibold uppercase text-muted-foreground"
                  htmlFor={selectionEndId}
                >
                  End
                </label>
                <Input
                  id={selectionEndId}
                  className="nodrag"
                  inputMode="numeric"
                  value={config.selection_end ?? ""}
                  onChange={(e) => onUpdate({ selection_end: e.target.value })}
                />
              </div>
            </div>
          )}
          {config.selection_type === "partition_block" && (
            <div className="grid grid-cols-2 gap-3">
              <div className="grid gap-2">
                <label
                  className="text-xs font-semibold uppercase text-muted-foreground"
                  htmlFor={selectionIndexId}
                >
                  Index
                </label>
                <Input
                  id={selectionIndexId}
                  className="nodrag"
                  inputMode="numeric"
                  value={config.selection_index ?? ""}
                  onChange={(e) => onUpdate({ selection_index: e.target.value })}
                />
              </div>
              <div className="grid gap-2">
                <label
                  className="text-xs font-semibold uppercase text-muted-foreground"
                  htmlFor={selectionPartitionsId}
                >
                  Partitions
                </label>
                <Input
                  id={selectionPartitionsId}
                  className="nodrag"
                  inputMode="numeric"
                  value={config.selection_num_partitions ?? ""}
                  onChange={(e) =>
                    onUpdate({ selection_num_partitions: e.target.value })
                  }
                />
              </div>
            </div>
          )}
          <p className="text-xs text-muted-foreground">
            Seed columns auto-add. Reference by name (ex: {"{{ rubrics }}"}).
          </p>
        </div>
      </TabsContent>
      <TabsContent value="preview" className="pt-3">
        <div className="space-y-4">
          {previewRows.length === 0 ? (
            <div className="flex w-full items-center justify-center">
              <Empty className="max-w-lg">
                <EmptyHeader>
                  <EmptyTitle>Preview samples</EmptyTitle>
                  <EmptyDescription>
                    Load 10 rows to see columns and sample values.
                  </EmptyDescription>
                </EmptyHeader>
                <EmptyContent>
                  <Button
                    type="button"
                    variant="outline"
                    className="nodrag"
                    onClick={() => void onPreview()}
                    disabled={previewLoading}
                  >
                    {previewLoading ? "Loading..." : "Load 10 rows"}
                  </Button>
                </EmptyContent>
              </Empty>
            </div>
          ) : (
            <div className="space-y-3">
              <div className="flex items-center gap-3">
                <Button
                  type="button"
                  variant="outline"
                  className="nodrag"
                  onClick={() => void onPreview()}
                  disabled={previewLoading}
                >
                  {previewLoading ? "Loading..." : "Reload 10 rows"}
                </Button>
                {previewLoading && <Spinner />}
              </div>
              <Table className="border border-border/60 rounded-xl">
                <TableHeader>
                  <TableRow>
                    {previewColumns.map((col) => (
                      <TableHead key={col} className="max-w-[260px]">
                        {col}
                      </TableHead>
                    ))}
                  </TableRow>
                </TableHeader>
                <TableBody>
                  {previewRows.map((row, idx) => (
                    <TableRow key={idx}>
                      {previewColumns.map((col) => (
                        <TableCell key={col} className="max-w-[260px]">
                          <div className="truncate">{stringifyCell(row[col])}</div>
                        </TableCell>
                      ))}
                    </TableRow>
                  ))}
                </TableBody>
              </Table>
            </div>
          )}
          {previewError && <p className="text-xs text-red-600">{previewError}</p>}
        </div>
      </TabsContent>
    </Tabs>
  );
}

View file

@ -76,6 +76,33 @@ function toErrorMessage(error: unknown, fallback: string): string {
return fallback;
}
/**
 * Copies `text` to the clipboard.
 *
 * Tries the async Clipboard API first. If it is missing or rejects, falls back
 * to a hidden, read-only textarea and `document.execCommand("copy")`.
 *
 * @param text - Text to place on the clipboard.
 * @returns `true` when one of the two strategies reported success, otherwise
 *   `false`. This function never throws.
 */
async function copyTextToClipboard(text: string): Promise<boolean> {
  try {
    if (navigator.clipboard?.writeText) {
      await navigator.clipboard.writeText(text);
      return true;
    }
  } catch {
    // fallthrough to legacy path
  }
  try {
    const textarea = document.createElement("textarea");
    textarea.value = text;
    textarea.setAttribute("readonly", "");
    // Keep the helper element off-screen while it is attached.
    textarea.style.position = "fixed";
    textarea.style.top = "0";
    textarea.style.left = "-9999px";
    document.body.appendChild(textarea);
    try {
      textarea.select();
      return document.execCommand("copy");
    } finally {
      // Always detach the helper, even when select/execCommand throws,
      // so failed copies do not leave stray textareas in the DOM.
      document.body.removeChild(textarea);
    }
  } catch {
    return false;
  }
}
export function useRecipeStudioActions({
recipeId,
initialRecipeName,
@ -237,13 +264,11 @@ export function useRecipeStudioActions({
toastError("Copy failed", payloadErrorMessage);
return;
}
if (!navigator.clipboard) {
console.error("Clipboard not available.");
toastError("Copy failed", "Clipboard not available.");
return;
}
try {
await navigator.clipboard.writeText(JSON.stringify(payload, null, 2));
const ok = await copyTextToClipboard(JSON.stringify(payload, null, 2));
if (!ok) {
throw new Error("Clipboard not available.");
}
setCopied(true);
window.setTimeout(() => setCopied(false), 1500);
toastSuccess("Payload copied");

View file

@ -91,6 +91,7 @@ export function RecipeStudioPage({
onEdgesChange,
onConnect,
addSamplerNode,
addSeedNode,
addLlmNode,
addModelProviderNode,
addModelConfigNode,
@ -126,6 +127,7 @@ export function RecipeStudioPage({
onEdgesChange: state.onEdgesChange,
onConnect: state.onConnect,
addSamplerNode: state.addSamplerNode,
addSeedNode: state.addSeedNode,
addLlmNode: state.addLlmNode,
addModelProviderNode: state.addModelProviderNode,
addModelConfigNode: state.addModelConfigNode,
@ -377,6 +379,7 @@ export function RecipeStudioPage({
sheetView={sheetView}
onViewChange={setSheetView}
onAddSampler={addSamplerNode}
onAddSeed={addSeedNode}
onAddLlm={addLlmNode}
onAddModelProvider={addModelProviderNode}
onAddModelConfig={addModelConfigNode}

View file

@ -37,7 +37,7 @@ import {
updateNodeData,
} from "./recipe-studio-helpers";
type SheetView = "root" | "sampler" | "llm" | "expression" | "processor";
type SheetView = "root" | "sampler" | "seed" | "llm" | "expression" | "processor";
type RecipeStudioState = {
nodes: RecipeNode[];
@ -63,6 +63,7 @@ type RecipeStudioState = {
setLayoutDirection: (direction: LayoutDirection) => void;
applyLayout: () => void;
addSamplerNode: (type: SamplerType) => void;
addSeedNode: () => void;
addLlmNode: (type: LlmType) => void;
addModelProviderNode: () => void;
addModelConfigNode: () => void;
@ -165,6 +166,23 @@ export const useRecipeStudioStore = create<RecipeStudioState>((set, get) => ({
}),
addSamplerNode: (type) =>
set((state) => buildAddedNodeState(state, "sampler", type)),
// Adds the seed node, or - when a seed config already exists - selects that
// node and opens its dialog instead of creating a second one.
addSeedNode: () =>
  set((state) => {
    const seedConfig = Object.values(state.configs).find(
      ({ kind }) => kind === "seed",
    );
    if (seedConfig === undefined) {
      return buildAddedNodeState(state, "seed", "seed");
    }
    const seedId = seedConfig.id;
    const nodes = state.nodes.map((node) => ({
      ...node,
      selected: node.id === seedId,
    }));
    return {
      activeConfigId: seedId,
      dialogOpen: true,
      nodes,
    };
  }),
addLlmNode: (type) => set((state) => buildAddedNodeState(state, "llm", type)),
addModelProviderNode: () =>
set((state) => buildAddedNodeState(state, "llm", "model_provider")),

View file

@ -18,15 +18,25 @@ export type ExpressionDtype = "str" | "int" | "float" | "bool";
export type LayoutDirection = "LR" | "TB";
export type SeedSamplingStrategy = "ordered" | "shuffle";
export type SeedSelectionType = "none" | "index_range" | "partition_block";
export type RecipeNodeData = {
title: string;
name: string;
kind: "sampler" | "llm" | "expression" | "model_provider" | "model_config";
kind:
| "sampler"
| "llm"
| "expression"
| "seed"
| "model_provider"
| "model_config";
subtype: string;
blockType:
| SamplerType
| LlmType
| "expression"
| "seed"
| "model_provider"
| "model_config";
layoutDirection?: LayoutDirection;
@ -204,6 +214,31 @@ export type ExpressionConfig = {
dtype: ExpressionDtype;
};
/**
 * Editor-side configuration of a seed dataset node (`kind: "seed"`).
 *
 * Field names are snake_case. The `selection_*` values are kept as strings
 * because they are bound directly to text inputs in the seed dialog; they are
 * converted to numbers only when a request payload is built.
 */
export type SeedConfig = {
id: string;
kind: "seed";
name: string;
// Backs the "Drop from final dataset" switch in the config dialog.
drop?: boolean;
// ui-only (serialized in seed_config)
// Raw text of the "HF dataset URL" input.
hf_url?: string;
// Repo id returned by the inspect call.
hf_repo_id?: string;
// Currently selected split name.
hf_split?: string;
// Dataset path sent as `hf_path` to the preview endpoint.
hf_path: string;
// Optional access token forwarded to inspect/preview requests.
hf_token?: string;
// Hub endpoint; the config factories default it to "https://huggingface.co".
hf_endpoint?: string;
// Split names returned by the inspect call.
seed_splits?: string[];
// ui-only
// Maps a split name to the path used for `hf_path` when that split is chosen.
// biome-ignore lint/style/useNamingConvention: ui schema
seed_globs_by_split?: Record<string, string>;
// Column names of the seed dataset (from inspect, or from the first preview row).
seed_columns?: string[];
sampling_strategy: SeedSamplingStrategy;
selection_type: SeedSelectionType;
// Used when selection_type is "index_range".
selection_start?: string;
selection_end?: string;
// Used when selection_type is "partition_block".
selection_index?: string;
selection_num_partitions?: string;
};
export type SchemaTransformProcessorConfig = {
id: string;
// biome-ignore lint/style/useNamingConvention: api schema
@ -218,5 +253,6 @@ export type NodeConfig =
| SamplerConfig
| LlmConfig
| ExpressionConfig
| SeedConfig
| ModelProviderConfig
| ModelConfig;

View file

@ -5,6 +5,7 @@ import type {
ModelConfig,
ModelProviderConfig,
NodeConfig,
SeedConfig,
SamplerConfig,
SamplerType,
} from "../types";
@ -281,3 +282,30 @@ export function makeExpressionConfig(
dtype: "str",
};
}
/**
 * Creates a fresh seed configuration with empty Hugging Face fields and
 * default sampling/selection settings.
 *
 * @param id - Node id to assign to the new config.
 * @param existing - Configs already present; passed to `nextName` to pick the name.
 * @returns A new `SeedConfig` with `kind: "seed"`.
 */
export function makeSeedConfig(id: string, existing: NodeConfig[]): SeedConfig {
  const name = nextName(existing, "seed");
  const seed: SeedConfig = {
    id,
    kind: "seed",
    name,
    drop: false,
    // Hugging Face source fields start empty until the user loads a dataset.
    hf_url: "",
    hf_repo_id: "",
    hf_split: "",
    hf_path: "",
    hf_token: "",
    hf_endpoint: "https://huggingface.co",
    // Metadata filled in by the inspect/preview calls.
    seed_splits: [],
    seed_globs_by_split: {},
    seed_columns: [],
    // Sampling and selection defaults.
    sampling_strategy: "ordered",
    selection_type: "none",
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };
  return seed;
}

View file

@ -12,6 +12,7 @@ import {
parseModelConfig,
parseModelProvider,
} from "./parsers";
import { parseSeedConfig } from "./parsers/seed-config-parser";
import { buildNodes, parseUi } from "./ui";
import type { ImportResult } from "./types";
@ -22,6 +23,7 @@ type RecipeInput = {
mcp_providers?: unknown;
tool_configs?: unknown;
processors?: unknown;
seed_config?: unknown;
};
type UiInput = {
@ -218,6 +220,15 @@ export function importRecipePayload(input: string): ImportResult {
let nextId = 1;
if (recipe.seed_config) {
const id = `n${nextId}`;
nextId += 1;
const seedConfig = parseSeedConfig(recipe.seed_config, id);
if (seedConfig) {
configs.push(seedConfig);
}
}
if (Array.isArray(recipe.model_providers)) {
recipe.model_providers.forEach((provider, index) => {
if (!isRecord(provider)) {

View file

@ -13,7 +13,8 @@ type ColumnParser = (
) => NodeConfig | null;
const COLUMN_PARSERS: Record<string, ColumnParser> = {
sampler: (column, name, id, errors) => parseSampler(column, name, id, errors),
sampler: (column, name, id, errors) =>
parseSampler(column, name, id, errors),
expression: (column, name, id) => parseExpression(column, name, id),
"llm-text": (column, name, id) => parseLlm(column, name, id),
"llm-structured": (column, name, id) => parseLlm(column, name, id),

View file

@ -0,0 +1,103 @@
import type {
SeedConfig,
SeedSamplingStrategy,
SeedSelectionType,
} from "../../../types";
import { isRecord, readString } from "../helpers";
/**
 * Maps a raw payload value to a sampling strategy.
 *
 * @param value - Untrusted value read from an imported recipe.
 * @returns `"shuffle"` only when `readString(value)` is exactly `"shuffle"`;
 *   `"ordered"` for everything else.
 */
function normalizeSampling(value: unknown): SeedSamplingStrategy {
  return readString(value) === "shuffle" ? "shuffle" : "ordered";
}
/**
 * Builds the baseline seed configuration used when importing a recipe; parsed
 * payload fields are later spread over it.
 *
 * @param id - Node id to assign to the config.
 * @returns A `SeedConfig` named `"seed"` with empty source fields and default
 *   sampling/selection values.
 */
function makeDefaultSeedConfig(id: string): SeedConfig {
  const defaults: SeedConfig = {
    id,
    kind: "seed",
    name: "seed",
    drop: false,
    // Hugging Face source fields.
    hf_url: "",
    hf_repo_id: "",
    hf_split: "",
    hf_path: "",
    hf_token: "",
    hf_endpoint: "https://huggingface.co",
    // Dataset metadata.
    seed_splits: [],
    seed_globs_by_split: {},
    seed_columns: [],
    // Sampling and selection defaults.
    sampling_strategy: "ordered",
    selection_type: "none",
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };
  return defaults;
}
/**
 * Extracts the payload-backed seed settings from a raw `seed_config` value.
 *
 * - Source fields (`path`, `token`, `endpoint`) are read only when
 *   `source.seed_type` is `"hf"`.
 * - The selection type is inferred from the shape of `selection_strategy`:
 *   numeric `start` and `end` mean `"index_range"`; numeric `index` and
 *   `num_partitions` mean `"partition_block"`; anything else is `"none"`.
 *
 * @param seedConfigRaw - Untrusted `seed_config` value from an imported recipe.
 * @returns The parsed settings, or an empty object when the input is not a record.
 */
function parseSeedSettings(seedConfigRaw: unknown): Partial<SeedConfig> {
  if (!isRecord(seedConfigRaw)) {
    return {};
  }

  const defaultEndpoint = "https://huggingface.co";

  // Start from the defaults and overwrite whatever the payload provides.
  const settings: Partial<SeedConfig> = {
    hf_path: "",
    hf_token: "",
    hf_endpoint: defaultEndpoint,
    sampling_strategy: normalizeSampling(seedConfigRaw.sampling_strategy),
    selection_type: "none",
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };

  const source = seedConfigRaw.source;
  if (isRecord(source) && readString(source.seed_type) === "hf") {
    settings.hf_path = readString(source.path) ?? "";
    settings.hf_token = readString(source.token) ?? "";
    settings.hf_endpoint = readString(source.endpoint) ?? defaultEndpoint;
  }

  const selection = seedConfigRaw.selection_strategy;
  if (isRecord(selection)) {
    const { start, end, index, num_partitions: numPartitions } = selection;
    if (typeof start === "number" && typeof end === "number") {
      settings.selection_type = "index_range";
      settings.selection_start = String(start);
      settings.selection_end = String(end);
    } else if (typeof index === "number" && typeof numPartitions === "number") {
      settings.selection_type = "partition_block";
      settings.selection_index = String(index);
      settings.selection_num_partitions = String(numPartitions);
    }
  }

  return settings;
}
/**
 * Turns a raw `seed_config` payload value into a full `SeedConfig`.
 *
 * @param seedConfigRaw - Untrusted value from an imported recipe.
 * @param id - Node id to assign to the resulting config.
 * @returns `null` when the raw value is falsy; otherwise the default seed
 *   config with the parsed payload settings spread over it.
 */
export function parseSeedConfig(
  seedConfigRaw: unknown,
  id: string,
): SeedConfig | null {
  if (seedConfigRaw) {
    const defaults = makeDefaultSeedConfig(id);
    // payload-only fields override ui defaults
    const overrides = parseSeedSettings(seedConfigRaw);
    return { ...defaults, ...overrides };
  }
  return null;
}

View file

@ -4,6 +4,7 @@ export {
makeModelConfig,
makeModelProviderConfig,
makeSamplerConfig,
makeSeedConfig,
} from "./config-factories";
export {
labelForExpression,

View file

@ -29,6 +29,16 @@ export function nodeDataFromConfig(
layoutDirection,
};
}
if (config.kind === "seed") {
return {
title: "Seed",
kind: "seed",
subtype: "Hugging Face",
blockType: "seed",
name: config.name,
layoutDirection,
};
}
if (config.kind === "model_provider") {
return {
title: "Model Provider",

View file

@ -16,6 +16,9 @@ import {
buildModelProvider,
buildProcessors,
buildSamplerColumn,
buildSeedConfig,
buildSeedDropProcessor,
pickFirstSeedConfig,
} from "./builders";
import type { RecipePayloadResult } from "./types";
import {
@ -65,6 +68,7 @@ export function buildRecipePayload(
const toolConfigJsonByAlias = new Map<string, string>();
const nameSet = new Set<string>();
const nameToConfig = new Map<string, NodeConfig>();
const firstSeed = pickFirstSeedConfig(configs);
for (const node of nodes) {
const config = configs[node.id];
@ -74,10 +78,12 @@ export function buildRecipePayload(
for (const error of getConfigErrors(config)) {
errors.push(`${config.name}: ${error}`);
}
if (nameSet.has(config.name)) {
errors.push(`Duplicate node name: ${config.name}.`);
if (config.kind !== "seed") {
if (nameSet.has(config.name)) {
errors.push(`Duplicate node name: ${config.name}.`);
}
nameSet.add(config.name);
}
nameSet.add(config.name);
if (config.kind === "sampler") {
nameToConfig.set(config.name, config);
@ -135,6 +141,10 @@ export function buildRecipePayload(
nameToConfig.set(config.name, config);
continue;
}
if (config.kind === "seed") {
// SeedConfig is global config (seed_config); seed-dataset columns are added by DataDesigner.
continue;
}
if (config.kind === "model_provider") {
modelProviderNames.add(config.name);
modelProviders.push(buildModelProvider(config, errors));
@ -166,6 +176,9 @@ export function buildRecipePayload(
if (!config) {
return [];
}
if (config.kind === "seed") {
return [];
}
const width = getNodeWidth(node);
return [
{
@ -195,6 +208,13 @@ export function buildRecipePayload(
];
});
const recipeProcessors = buildProcessors(processors, errors);
const seedConfig = firstSeed ? buildSeedConfig(firstSeed, errors) : undefined;
const seedDropProcessor = firstSeed
? buildSeedDropProcessor(firstSeed, errors)
: null;
if (seedDropProcessor) {
recipeProcessors.push(seedDropProcessor);
}
return {
errors,
@ -207,6 +227,8 @@ export function buildRecipePayload(
// biome-ignore lint/style/useNamingConvention: api schema
model_configs: modelConfigs,
// biome-ignore lint/style/useNamingConvention: api schema
seed_config: seedConfig,
// biome-ignore lint/style/useNamingConvention: api schema
tool_configs: toolConfigs,
columns,
processors: recipeProcessors,

View file

@ -0,0 +1,90 @@
import type { NodeConfig, SeedConfig } from "../../types";
/**
 * Converts a user-entered string into an integer.
 *
 * @param value - Raw text from the form; may be undefined.
 * @returns The numeric value when the text converts (via `Number`) to a finite
 *   integer, or `null` when the text is missing, blank after trimming, or not
 *   an integer.
 */
function parseIntStrict(value: string | undefined): number | null {
  const text = value?.trim();
  if (text === undefined || text === "") {
    return null;
  }
  // `Number` ignores surrounding whitespace, so converting the trimmed text
  // yields the same result as converting the raw input.
  const parsed = Number(text);
  return Number.isFinite(parsed) && Number.isInteger(parsed) ? parsed : null;
}
/**
 * Builds the `seed_config` payload object for a Hugging Face seed node.
 *
 * @param config - Seed node configuration taken from the editor state.
 * @param errors - Accumulator that receives a message when the selection
 *   bounds cannot be parsed as integers.
 * @returns The payload object, or `undefined` when the dataset path is blank
 *   (nothing is recorded in `errors`) or when a selection value fails to parse
 *   (a message is appended to `errors`).
 */
export function buildSeedConfig(
  config: SeedConfig,
  errors: string[],
): Record<string, unknown> | undefined {
  const datasetPath = config.hf_path.trim();
  if (datasetPath === "") {
    return undefined;
  }

  // A blank endpoint falls back to the public hub; a blank token becomes null.
  const trimmedEndpoint = config.hf_endpoint?.trim();
  const endpoint = trimmedEndpoint ? trimmedEndpoint : "https://huggingface.co";
  const trimmedToken = config.hf_token?.trim();
  const token = trimmedToken ? trimmedToken : null;

  // Stays null for every selection type other than the two handled below.
  let selection: Record<string, unknown> | null = null;
  switch (config.selection_type) {
    case "index_range": {
      const start = parseIntStrict(config.selection_start);
      const end = parseIntStrict(config.selection_end);
      if (start === null || end === null) {
        errors.push(`Seed ${config.name}: selection index range invalid.`);
        return undefined;
      }
      selection = { start, end };
      break;
    }
    case "partition_block": {
      const index = parseIntStrict(config.selection_index);
      const numPartitions = parseIntStrict(config.selection_num_partitions);
      if (index === null || numPartitions === null) {
        errors.push(`Seed ${config.name}: selection partition invalid.`);
        return undefined;
      }
      // biome-ignore lint/style/useNamingConvention: api schema
      selection = { index, num_partitions: numPartitions };
      break;
    }
    default:
      break;
  }

  const source = {
    // biome-ignore lint/style/useNamingConvention: api schema
    seed_type: "hf",
    path: datasetPath,
    token,
    endpoint,
  };
  return {
    source,
    // biome-ignore lint/style/useNamingConvention: api schema
    sampling_strategy: config.sampling_strategy,
    // biome-ignore lint/style/useNamingConvention: api schema
    selection_strategy: selection,
  };
}
/**
 * Finds the first seed node among the given configs.
 *
 * @param configs - Node configs keyed by id; they are scanned in
 *   `Object.values` order.
 * @returns The first config whose `kind` is `"seed"`, or `null` if none exists.
 */
export function pickFirstSeedConfig(
  configs: Record<string, NodeConfig>,
): SeedConfig | null {
  const firstSeed = Object.values(configs).find(
    (candidate): candidate is SeedConfig => candidate.kind === "seed",
  );
  return firstSeed ?? null;
}
/**
 * Builds a `drop_columns` processor covering the seed dataset's columns.
 *
 * @param config - Seed node configuration; only `drop`, `seed_columns`, and
 *   `name` are read.
 * @param errors - Accumulator that receives a message when dropping is
 *   enabled but no usable column names are available.
 * @returns The processor payload, or `null` when dropping is disabled or no
 *   non-blank column names remain after trimming.
 */
export function buildSeedDropProcessor(
  config: SeedConfig,
  errors: string[],
): Record<string, unknown> | null {
  if (!config.drop) {
    return null;
  }

  // Trim each column name and discard the ones that end up empty.
  const columnNames: string[] = [];
  for (const rawName of config.seed_columns ?? []) {
    const cleaned = rawName.trim();
    if (cleaned) {
      columnNames.push(cleaned);
    }
  }

  if (columnNames.length > 0) {
    return {
      // biome-ignore lint/style/useNamingConvention: api schema
      processor_type: "drop_columns",
      name: "drop_seed_columns",
      // biome-ignore lint/style/useNamingConvention: api schema
      build_stage: "post_batch",
      // biome-ignore lint/style/useNamingConvention: api schema
      column_names: columnNames,
    };
  }

  errors.push(`Seed ${config.name}: drop enabled but no seed columns loaded.`);
  return null;
}

View file

@ -2,3 +2,8 @@ export { buildLlmColumn, buildLlmMcpProvider, buildLlmToolConfig } from "./build
export { buildModelConfig, buildModelProvider } from "./builders-model";
export { buildExpressionColumn, buildProcessors } from "./builders-processors";
export { buildSamplerColumn } from "./builders-sampler";
export {
buildSeedConfig,
buildSeedDropProcessor,
pickFirstSeedConfig,
} from "./builders-seed";

View file

@ -7,6 +7,8 @@ export type RecipePayload = {
// biome-ignore lint/style/useNamingConvention: api schema
model_configs: Record<string, unknown>[];
// biome-ignore lint/style/useNamingConvention: api schema
seed_config?: Record<string, unknown>;
// biome-ignore lint/style/useNamingConvention: api schema
tool_configs: Record<string, unknown>[];
columns: Record<string, unknown>[];
processors: Record<string, unknown>[];

View file

@ -8,6 +8,14 @@ function parseNumber(value?: string): number | null {
return Number.isFinite(num) ? num : null;
}
/**
 * Parses a string as an integer by delegating to `parseNumber`.
 *
 * @param value - Raw text to parse; may be omitted.
 * @returns The parsed value when `parseNumber` yields an integer, otherwise
 *   `null` (including when `parseNumber` itself returns `null`).
 */
function parseIntNumber(value?: string): number | null {
  const parsed = parseNumber(value);
  return parsed !== null && Number.isInteger(parsed) ? parsed : null;
}
function parseAgeRange(value?: string): [number, number] | null {
if (!value) {
return null;
@ -207,5 +215,44 @@ export function getConfigErrors(config: NodeConfig | null): string[] {
errors.push("Expression is required.");
}
}
if (config.kind === "seed") {
if (!config.hf_path.trim()) {
errors.push("HF dataset path is required.");
}
if (config.hf_endpoint?.trim() && !config.hf_endpoint.trim().startsWith("http")) {
errors.push("HF endpoint must start with http.");
}
if (config.drop && (config.seed_columns?.length ?? 0) === 0) {
errors.push("Seed drop needs loaded columns (open Seed Preview).");
}
if (config.selection_type === "index_range") {
const start = parseIntNumber(config.selection_start);
const end = parseIntNumber(config.selection_end);
if (start === null || end === null) {
errors.push("Index range start/end must be integers.");
} else {
if (start < 0 || end < 0) {
errors.push("Index range start/end must be >= 0.");
}
if (end < start) {
errors.push("Index range end must be >= start.");
}
}
}
if (config.selection_type === "partition_block") {
const index = parseIntNumber(config.selection_index);
const parts = parseIntNumber(config.selection_num_partitions);
if (index === null || parts === null) {
errors.push("Partition index/num_partitions must be integers.");
} else {
if (index < 0) errors.push("Partition index must be >= 0.");
if (parts < 1) errors.push("Partition num_partitions must be >= 1.");
if (parts >= 1 && index >= parts) {
errors.push("Partition index must be < num_partitions.");
}
}
}
}
return errors;
}

View file

@ -36,31 +36,55 @@ export function getAvailableRefItems(
const items: AvailableRefItem[] = [];
for (const config of Object.values(configs)) {
if (config.id === currentId) continue;
if (config.kind === "model_provider" || config.kind === "model_config") continue;
if (config.id === currentId) {
continue;
}
if (config.kind === "model_provider" || config.kind === "model_config") {
continue;
}
if (config.kind === "sampler") {
items.push({ ref: config.name, kind: "sampler", subtype: config.sampler_type });
items.push({
ref: config.name,
kind: "sampler",
subtype: config.sampler_type,
});
continue;
}
if (config.kind === "expression") {
items.push({ ref: config.name, kind: "expression", subtype: config.dtype });
items.push({
ref: config.name,
kind: "expression",
subtype: config.dtype,
});
continue;
}
if (config.kind === "llm") {
items.push({ ref: config.name, kind: "llm", subtype: config.llm_type });
if (config.llm_type === "structured" && config.output_format) {
for (const ref of getStructuredRefs(config.name, config.output_format)) {
items.push({
ref: ref.ref,
kind: "llm",
subtype: config.llm_type,
valueType: ref.valueType,
});
}
if (config.kind === "seed") {
for (const col of config.seed_columns ?? []) {
const name = col.trim();
if (!name) continue;
items.push({ ref: name, kind: "seed", subtype: "seed" });
}
continue;
}
if (config.kind !== "llm") {
continue;
}
items.push({ ref: config.name, kind: "llm", subtype: config.llm_type });
if (config.llm_type !== "structured" || !config.output_format) {
continue;
}
for (const ref of getStructuredRefs(config.name, config.output_format)) {
items.push({
ref: ref.ref,
kind: "llm",
subtype: config.llm_type,
valueType: ref.valueType,
});
}
}
@ -73,4 +97,3 @@ export function getAvailableVariables(
): string[] {
return getAvailableRefItems(configs, currentId).map((item) => item.ref);
}

View file

@ -17,6 +17,14 @@ export default defineConfig({
target: "http://127.0.0.1:8000",
changeOrigin: true,
},
"/seed/inspect": {
target: "http://127.0.0.1:8004",
changeOrigin: true,
},
"/seed/preview": {
target: "http://127.0.0.1:8004",
changeOrigin: true,
},
"/preview": {
target: "http://127.0.0.1:8004",
changeOrigin: true,