mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
feat: add seed dataset support with configuration, preview, and builder utilities
This commit is contained in:
parent
2bd20d7d15
commit
f2a00d6e44
25 changed files with 1035 additions and 34 deletions
|
|
@ -4,6 +4,8 @@
|
|||
"ignore": [
|
||||
"dist",
|
||||
"node_modules",
|
||||
"test",
|
||||
"test/**",
|
||||
"**/._*",
|
||||
"._*",
|
||||
"**/.DS_Store",
|
||||
|
|
|
|||
|
|
@ -28,6 +28,19 @@ export type ToolsResponse = {
|
|||
tools: string[];
|
||||
};
|
||||
|
||||
/**
 * Typed result of `inspectSeedDataset` (POST `/seed/inspect`).
 *
 * Field names are snake_case because they mirror the API schema.
 */
export type SeedInspectResponse = {
  /** Dataset repository id reported back by the backend. */
  // biome-ignore lint/style/useNamingConvention: api schema
  repo_id: string;
  /** Names of the splits found for the dataset. */
  splits: Array<string>;
  /**
   * One string per split, keyed by split name.
   * NOTE(review): values look like file-path globs for that split — confirm against the backend.
   */
  // biome-ignore lint/style/useNamingConvention: api schema
  globs_by_split: Record<string, string>;
  /** Column names reported for the dataset. */
  columns: Array<string>;
};
|
||||
|
||||
/** Typed result of `previewSeedDataset` (POST `/seed/preview`). */
export type SeedPreviewResponse = {
  /** Sample rows; each row is an object whose values are not typed further. */
  rows: Array<Record<string, unknown>>;
};
|
||||
|
||||
async function parseErrorResponse(response: Response): Promise<string> {
|
||||
const text = (await response.text()).trim();
|
||||
if (!text) {
|
||||
|
|
@ -80,3 +93,11 @@ export async function validateRecipe(
|
|||
/**
 * POST `payload` to the `/tools` endpoint.
 *
 * @param payload Request body, forwarded to `postJson` unchanged.
 * @returns Whatever `postJson` resolves with, typed as {@link ToolsResponse}.
 */
export async function listRecipeTools(payload: unknown): Promise<ToolsResponse> {
  const endpoint = "/tools";
  return postJson<ToolsResponse>(endpoint, payload);
}
|
||||
|
||||
/**
 * POST `payload` to the `/seed/inspect` endpoint.
 *
 * @param payload Request body, forwarded to `postJson` unchanged.
 * @returns Whatever `postJson` resolves with, typed as {@link SeedInspectResponse}.
 */
export async function inspectSeedDataset(payload: unknown): Promise<SeedInspectResponse> {
  const endpoint = "/seed/inspect";
  return postJson<SeedInspectResponse>(endpoint, payload);
}
|
||||
|
||||
/**
 * POST `payload` to the `/seed/preview` endpoint.
 *
 * @param payload Request body, forwarded to `postJson` unchanged.
 * @returns Whatever `postJson` resolves with, typed as {@link SeedPreviewResponse}.
 */
export async function previewSeedDataset(payload: unknown): Promise<SeedPreviewResponse> {
  const endpoint = "/seed/preview";
  return postJson<SeedPreviewResponse>(endpoint, payload);
}
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ import {
|
|||
makeModelConfig,
|
||||
makeModelProviderConfig,
|
||||
makeSamplerConfig,
|
||||
makeSeedConfig,
|
||||
} from "../utils";
|
||||
import { ExpressionDialog } from "../dialogs/expression/expression-dialog";
|
||||
import { LlmDialog } from "../dialogs/llm/llm-dialog";
|
||||
import { ModelConfigDialog } from "../dialogs/models/model-config-dialog";
|
||||
import { ModelProviderDialog } from "../dialogs/models/model-provider-dialog";
|
||||
import { SeedDialog } from "../dialogs/seed/seed-dialog";
|
||||
import { CategoryDialog } from "../dialogs/samplers/category-dialog";
|
||||
import { DatetimeDialog } from "../dialogs/samplers/datetime-dialog";
|
||||
import { BernoulliDialog } from "../dialogs/samplers/bernoulli-dialog";
|
||||
|
|
@ -38,11 +40,12 @@ import { TimedeltaDialog } from "../dialogs/samplers/timedelta-dialog";
|
|||
import { UniformDialog } from "../dialogs/samplers/uniform-dialog";
|
||||
import { UuidDialog } from "../dialogs/samplers/uuid-dialog";
|
||||
|
||||
export type BlockKind = "sampler" | "llm" | "expression";
|
||||
export type BlockKind = "sampler" | "llm" | "expression" | "seed";
|
||||
export type BlockType =
|
||||
| SamplerType
|
||||
| LlmType
|
||||
| "expression"
|
||||
| "seed"
|
||||
| "model_provider"
|
||||
| "model_config";
|
||||
|
||||
|
|
@ -81,6 +84,12 @@ export const BLOCK_GROUPS: BlockGroup[] = [
|
|||
description: "Numeric + categorical blocks.",
|
||||
icon: DiceFaces03Icon,
|
||||
},
|
||||
{
|
||||
kind: "seed",
|
||||
title: "Seed",
|
||||
description: "Columns from a seed dataset.",
|
||||
icon: Plant01Icon,
|
||||
},
|
||||
{
|
||||
kind: "llm",
|
||||
title: "LLM",
|
||||
|
|
@ -96,6 +105,21 @@ export const BLOCK_GROUPS: BlockGroup[] = [
|
|||
];
|
||||
|
||||
const BLOCK_DEFINITIONS: BlockDefinition[] = [
|
||||
{
|
||||
kind: "seed",
|
||||
type: "seed",
|
||||
title: "Seed (Hugging Face)",
|
||||
description: "Configure a HF seed dataset.",
|
||||
icon: Plant01Icon,
|
||||
createConfig: (id, existing) => makeSeedConfig(id, existing),
|
||||
renderDialog: ({ config, onUpdate }) =>
|
||||
config.kind === "seed" ? (
|
||||
<SeedDialog
|
||||
config={config}
|
||||
onUpdate={(patch) => onUpdate(config.id, patch)}
|
||||
/>
|
||||
) : null,
|
||||
},
|
||||
{
|
||||
kind: "sampler",
|
||||
type: "category",
|
||||
|
|
@ -372,6 +396,9 @@ export function getBlockDefinitionForConfig(
|
|||
if (!config) {
|
||||
return null;
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
return getBlockDefinition("seed", "seed");
|
||||
}
|
||||
if (config.kind === "sampler") {
|
||||
const samplerType =
|
||||
config.sampler_type === "person_from_faker"
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ import { RECIPE_FLOATING_ICON_BUTTON_CLASS } from "./recipe-floating-icon-button
|
|||
import type { LlmType, SamplerType } from "../types";
|
||||
import { BLOCK_GROUPS, getBlocksForKind } from "../blocks/registry";
|
||||
|
||||
type SheetView = "root" | "sampler" | "llm" | "expression" | "processor";
|
||||
type SheetKind = "sampler" | "llm" | "expression";
|
||||
type SheetView = "root" | "sampler" | "seed" | "llm" | "expression" | "processor";
|
||||
type SheetKind = "sampler" | "seed" | "llm" | "expression";
|
||||
type RootSheetView = Exclude<SheetView, "root">;
|
||||
type RootGroup = {
|
||||
kind: RootSheetView;
|
||||
|
|
@ -37,6 +37,7 @@ type BlockSheetProps = {
|
|||
sheetView: SheetView;
|
||||
onViewChange: (sheetView: SheetView) => void;
|
||||
onAddSampler: (type: SamplerType) => void;
|
||||
onAddSeed: () => void;
|
||||
onAddLlm: (type: LlmType) => void;
|
||||
onAddModelProvider: () => void;
|
||||
onAddModelConfig: () => void;
|
||||
|
|
@ -54,6 +55,9 @@ function getSheetTitle(sheetView: SheetView): string {
|
|||
if (sheetView === "sampler") {
|
||||
return "Sampler blocks";
|
||||
}
|
||||
if (sheetView === "seed") {
|
||||
return "Seed blocks";
|
||||
}
|
||||
if (sheetView === "expression") {
|
||||
return "Expression blocks";
|
||||
}
|
||||
|
|
@ -66,6 +70,7 @@ function getSheetTitle(sheetView: SheetView): string {
|
|||
const VIEW_KIND: Record<SheetView, SheetKind | null> = {
|
||||
root: null,
|
||||
sampler: "sampler",
|
||||
seed: "seed",
|
||||
llm: "llm",
|
||||
expression: "expression",
|
||||
processor: null,
|
||||
|
|
@ -124,6 +129,7 @@ export function BlockSheet({
|
|||
sheetView,
|
||||
onViewChange,
|
||||
onAddSampler,
|
||||
onAddSeed,
|
||||
onAddLlm,
|
||||
onAddModelProvider,
|
||||
onAddModelConfig,
|
||||
|
|
@ -136,6 +142,7 @@ export function BlockSheet({
|
|||
const sheetTitle = getSheetTitle(sheetView);
|
||||
const [open, setOpen] = useState(false);
|
||||
const expressionBlocks = useMemo(() => getBlocksForKind("expression"), []);
|
||||
const seedBlocks = useMemo(() => getBlocksForKind("seed"), []);
|
||||
return (
|
||||
<div className="flex flex-col items-end gap-2">
|
||||
<Sheet
|
||||
|
|
@ -198,6 +205,11 @@ export function BlockSheet({
|
|||
onOpenProcessors();
|
||||
return;
|
||||
}
|
||||
if (item.kind === "seed" && seedBlocks.length === 1) {
|
||||
setOpen(false);
|
||||
onAddSeed();
|
||||
return;
|
||||
}
|
||||
if (item.kind === "expression" && expressionBlocks.length === 1) {
|
||||
setOpen(false);
|
||||
onAddExpression();
|
||||
|
|
@ -229,6 +241,8 @@ export function BlockSheet({
|
|||
onClick={() => {
|
||||
if (item.kind === "sampler") {
|
||||
onAddSampler(item.type as SamplerType);
|
||||
} else if (item.kind === "seed") {
|
||||
onAddSeed();
|
||||
} else if (item.kind === "llm") {
|
||||
if (item.type === "model_provider") {
|
||||
onAddModelProvider();
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import {
|
|||
FunctionIcon,
|
||||
Parabola02Icon,
|
||||
PencilEdit02Icon,
|
||||
Plant01Icon,
|
||||
Tag01Icon,
|
||||
TagsIcon,
|
||||
UserAccountIcon,
|
||||
|
|
@ -76,6 +77,7 @@ function getJinjaContext(
|
|||
|
||||
function getItemIcon(item: AvailableRefItem) {
|
||||
if (item.kind === "expression") return FunctionIcon;
|
||||
if (item.kind === "seed") return Plant01Icon;
|
||||
if (item.kind === "llm") {
|
||||
if (item.subtype === "structured") return CodeIcon;
|
||||
if (item.subtype === "code") return CodeSimpleIcon;
|
||||
|
|
|
|||
|
|
@ -61,6 +61,9 @@ const NODE_META = {
|
|||
expression: {
|
||||
tone: "bg-indigo-50 text-indigo-600 border-indigo-100",
|
||||
},
|
||||
seed: {
|
||||
tone: "bg-lime-50 text-lime-700 border-lime-100",
|
||||
},
|
||||
model_provider: {
|
||||
tone: "bg-amber-50 text-amber-600 border-amber-100",
|
||||
},
|
||||
|
|
@ -108,6 +111,9 @@ function resolveNodeIcon(
|
|||
if (kind === "model_config") {
|
||||
return Plant01Icon;
|
||||
}
|
||||
if (kind === "seed") {
|
||||
return Plant01Icon;
|
||||
}
|
||||
return DiceFaces03Icon;
|
||||
}
|
||||
|
||||
|
|
@ -163,6 +169,13 @@ function getConfigSummary(config: NodeConfig | undefined): string {
|
|||
return "Prompt/system via linked input nodes";
|
||||
}
|
||||
|
||||
if (config.kind === "seed") {
|
||||
if (config.hf_path.trim()) {
|
||||
return config.hf_path.trim();
|
||||
}
|
||||
return "Set HF dataset path";
|
||||
}
|
||||
|
||||
return "Open details for config";
|
||||
}
|
||||
|
||||
|
|
@ -292,7 +305,9 @@ function RecipeGraphNodeBase({
|
|||
}, [id, layoutDirection, config, updateNodeInternals]);
|
||||
|
||||
const showDataHandles =
|
||||
data.kind === "llm" || data.kind === "expression" || data.kind === "sampler";
|
||||
data.kind === "llm" ||
|
||||
data.kind === "expression" ||
|
||||
data.kind === "sampler";
|
||||
const showSemanticIn = data.kind === "llm" || data.kind === "model_config";
|
||||
const showSemanticOut = data.kind === "model_config" || data.kind === "model_provider";
|
||||
const isTopBottom = layoutDirection === "TB";
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import { Button } from "@/components/ui/button";
|
|||
import { Dialog, DialogContent, DialogFooter } from "@/components/ui/dialog";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import type { ReactElement } from "react";
|
||||
import type { NodeConfig, SamplerConfig } from "../types";
|
||||
import { renderBlockDialog } from "../blocks/registry";
|
||||
import type { NodeConfig, SamplerConfig } from "../types";
|
||||
import { DialogShell } from "./shared/dialog-shell";
|
||||
import { ValidationBanner } from "./shared/validation-banner";
|
||||
|
||||
|
|
@ -50,7 +50,8 @@ export function ConfigDialog({
|
|||
<ValidationBanner config={config} />
|
||||
{(config.kind === "sampler" ||
|
||||
config.kind === "llm" ||
|
||||
config.kind === "expression") && (
|
||||
config.kind === "expression" ||
|
||||
config.kind === "seed") && (
|
||||
<div className="flex items-center corner-squircle justify-between gap-3 rounded-2xl border border-border/60 px-3 py-2">
|
||||
<div>
|
||||
<p className="text-sm font-semibold">Drop from final dataset</p>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,486 @@
|
|||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Empty,
|
||||
EmptyContent,
|
||||
EmptyDescription,
|
||||
EmptyHeader,
|
||||
EmptyTitle,
|
||||
} from "@/components/ui/empty";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Spinner } from "@/components/ui/spinner";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import {
|
||||
Tabs,
|
||||
TabsContent,
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "@/components/ui/tabs";
|
||||
import { type ReactElement, useMemo, useState } from "react";
|
||||
import { inspectSeedDataset, previewSeedDataset } from "../../api";
|
||||
import type {
|
||||
SeedConfig,
|
||||
SeedSamplingStrategy,
|
||||
SeedSelectionType,
|
||||
} from "../../types";
|
||||
|
||||
/** Entries of the "Sampling strategy" select, in display order. */
const SAMPLING_OPTIONS: { value: SeedSamplingStrategy; label: string }[] = [
  {
    value: "ordered",
    label: "Ordered",
  },
  {
    value: "shuffle",
    label: "Shuffle",
  },
];
|
||||
|
||||
/** Entries of the "Selection strategy" select, in display order. */
const SELECTION_OPTIONS: { value: SeedSelectionType; label: string }[] = [
  {
    value: "none",
    label: "None",
  },
  {
    value: "index_range",
    label: "Index range",
  },
  {
    value: "partition_block",
    label: "Partition block",
  },
];
|
||||
|
||||
/** Props accepted by `SeedDialog`. */
type SeedDialogProps = {
  /** Seed block configuration being edited. */
  config: SeedConfig;
  /** Called with the subset of fields that should change. */
  onUpdate: (patch: Partial<SeedConfig>) => void;
};
|
||||
|
||||
/**
 * Derive an `org/repo` dataset id from what the user typed.
 *
 * Accepted forms:
 * - a bare two-segment id such as `org/repo` (returned trimmed, otherwise as typed);
 * - anything whose path contains `datasets/<org>/<repo>`, with or without a
 *   scheme, e.g. `https://huggingface.co/datasets/org/repo/tree/main`,
 *   `huggingface.co/datasets/org/repo` or `datasets/org/repo`.
 *
 * @param input Raw text from the dataset URL field.
 * @returns The `org/repo` id, or `null` when none can be derived.
 */
function parseHfDatasetRepoId(input: string): string | null {
  const raw = input.trim();
  if (!raw) return null;

  const hasScheme = raw.includes("://");
  if (!hasScheme) {
    const segments = raw.split("/");
    if (segments.length === 2) {
      // "org/" and "/repo" split into two segments too, but are not repo ids.
      return segments.every((segment) => segment.length > 0) ? raw : null;
    }
  }

  try {
    // Scheme-less input is resolved as a relative path against a throwaway
    // origin, so "host/datasets/org/repo" ends up entirely in `pathname`.
    const url = hasScheme ? new URL(raw) : new URL(raw, "https://placeholder.invalid/");
    const parts = url.pathname.split("/").filter(Boolean);
    const datasetsIdx = parts.indexOf("datasets");
    if (datasetsIdx === -1) return null;
    const org = parts[datasetsIdx + 1];
    const repo = parts[datasetsIdx + 2];
    if (!org || !repo) return null;
    return `${org}/${repo}`;
  } catch {
    // `new URL` throws on input it cannot parse.
    return null;
  }
}
|
||||
|
||||
/**
 * Turn an arbitrary cell value into text for the preview table.
 *
 * - `null` / `undefined` become the empty string;
 * - strings are returned unchanged; numbers and booleans go through `String`;
 * - everything else is JSON-encoded, falling back to `String(value)` when
 *   JSON encoding throws or produces no output.
 *
 * @param value Cell value of unknown type.
 * @returns A string in every case.
 */
function stringifyCell(value: unknown): string {
  if (value === null || value === undefined) return "";
  if (typeof value === "string") return value;
  if (typeof value === "number" || typeof value === "boolean") return String(value);
  try {
    // JSON.stringify returns `undefined` (not a string) for functions and
    // symbols, even though its declared return type is `string`.
    const json: string | undefined = JSON.stringify(value);
    return json ?? String(value);
  } catch {
    // e.g. BigInt values or circular structures.
    return String(value);
  }
}
|
||||
|
||||
/**
 * Parse an optional integer typed into a text field.
 *
 * The text is trimmed and converted with `Number`, so any spelling `Number`
 * understands is accepted as long as the result is a safe integer
 * (`"12"`, `"-3"`, `"1e3"`). Fractions, `Infinity`, `NaN` and values outside
 * the safe-integer range are rejected.
 *
 * @param value Field text; may be `undefined`.
 * @returns The integer, or `null` when the field is blank or not an integer.
 */
function parseOptionalInt(value: string | undefined): number | null {
  const trimmed = value?.trim();
  if (!trimmed) return null;
  const num = Number(trimmed);
  return Number.isSafeInteger(num) ? num : null;
}
|
||||
|
||||
export function SeedDialog({ config, onUpdate }: SeedDialogProps): ReactElement {
|
||||
const [inspectLoading, setInspectLoading] = useState(false);
|
||||
const [inspectError, setInspectError] = useState<string | null>(null);
|
||||
const [previewLoading, setPreviewLoading] = useState(false);
|
||||
const [previewError, setPreviewError] = useState<string | null>(null);
|
||||
const [previewRows, setPreviewRows] = useState<Record<string, unknown>[]>([]);
|
||||
|
||||
const repoId = useMemo(
|
||||
() => parseHfDatasetRepoId(config.hf_url ?? ""),
|
||||
[config.hf_url],
|
||||
);
|
||||
|
||||
const pathId = `${config.id}-hf-path`;
|
||||
const urlId = `${config.id}-hf-url`;
|
||||
const tokenId = `${config.id}-hf-token`;
|
||||
const splitId = `${config.id}-hf-split`;
|
||||
const samplingId = `${config.id}-sampling`;
|
||||
const selectionId = `${config.id}-selection`;
|
||||
|
||||
/**
 * Inspect the dataset named by the URL field and store the result in the config.
 *
 * Clears the previous inspect error, then bails out with an error message when
 * no repo id could be derived from the URL. Otherwise it calls
 * `inspectSeedDataset` and patches the config with the reported repo id,
 * splits, per-split globs and columns. The current split is kept when the
 * backend still lists it; otherwise the first reported split is selected.
 * `hf_path` is replaced only when a non-empty glob exists for the chosen split.
 */
async function onInspect(): Promise<void> {
  setInspectError(null);
  if (!repoId) {
    setInspectError("Invalid HF dataset URL (need /datasets/org/repo).");
    return;
  }

  // Normalize once so the request and the "keep current split" check below
  // compare the same value.
  const currentSplit = config.hf_split?.trim() ?? "";

  setInspectLoading(true);
  try {
    const res = await inspectSeedDataset({
      // biome-ignore lint/style/useNamingConvention: api schema
      repo_id: repoId,
      // biome-ignore lint/style/useNamingConvention: api schema
      hf_token: config.hf_token?.trim() || null,
      split: currentSplit || null,
    });
    const splits = res.splits ?? [];
    const globs = res.globs_by_split ?? {};

    let nextSplit = "";
    if (currentSplit && splits.includes(currentSplit)) {
      nextSplit = currentSplit;
    } else if (splits[0]) {
      nextSplit = splits[0];
    }

    const nextPath = nextSplit ? (globs[nextSplit] ?? "") : "";
    onUpdate({
      hf_repo_id: res.repo_id,
      seed_splits: splits,
      seed_globs_by_split: globs,
      seed_columns: res.columns ?? [],
      hf_split: nextSplit,
      // Keep whatever path is already set when no glob is known for the split.
      hf_path: nextPath || config.hf_path,
    });
  } catch (err) {
    setInspectError(err instanceof Error ? err.message : "Inspect failed.");
  } finally {
    setInspectLoading(false);
  }
}
|
||||
|
||||
/**
 * Fetch up to 10 preview rows for the configured path and show them.
 *
 * Clears the previous preview error first. Without a non-blank `hf_path` it
 * reports an error and returns. On success the rows are stored in local state;
 * if the config has no seed columns yet, the keys of the first row are saved
 * as the column list.
 */
async function onPreview(): Promise<void> {
  setPreviewError(null);

  const path = config.hf_path.trim();
  if (path === "") {
    setPreviewError("HF path missing (Load first).");
    return;
  }

  setPreviewLoading(true);
  try {
    const response = await previewSeedDataset({
      // biome-ignore lint/style/useNamingConvention: api schema
      hf_path: path,
      // biome-ignore lint/style/useNamingConvention: api schema
      hf_token: config.hf_token?.trim() || null,
      // biome-ignore lint/style/useNamingConvention: api schema
      sampling_strategy: config.sampling_strategy,
      // biome-ignore lint/style/useNamingConvention: api schema
      selection_type: config.selection_type,
      // biome-ignore lint/style/useNamingConvention: api schema
      selection_start: parseOptionalInt(config.selection_start),
      // biome-ignore lint/style/useNamingConvention: api schema
      selection_end: parseOptionalInt(config.selection_end),
      // biome-ignore lint/style/useNamingConvention: api schema
      selection_index: parseOptionalInt(config.selection_index),
      // biome-ignore lint/style/useNamingConvention: api schema
      selection_num_partitions: parseOptionalInt(config.selection_num_partitions),
      limit: 10,
    });

    const fetchedRows = response.rows ?? [];
    setPreviewRows(fetchedRows);

    // Seed the column list from the data only when none is known yet.
    const knownColumnCount = config.seed_columns?.length ?? 0;
    const firstRow = fetchedRows[0];
    if (knownColumnCount === 0 && firstRow) {
      onUpdate({ seed_columns: Object.keys(firstRow) });
    }
  } catch (err) {
    const message = err instanceof Error ? err.message : "Preview failed.";
    setPreviewError(message);
  } finally {
    setPreviewLoading(false);
  }
}
|
||||
|
||||
// Columns shown in the preview table: the configured seed columns when there
// are any, otherwise the keys of the first preview row, otherwise none.
const previewColumns = useMemo(() => {
  const configured = config.seed_columns ?? [];
  if (configured.length > 0) {
    return configured;
  }
  const firstRow = previewRows[0];
  return firstRow ? Object.keys(firstRow) : [];
}, [config.seed_columns, previewRows]);
|
||||
|
||||
return (
|
||||
<Tabs defaultValue="config" className="w-full">
|
||||
<TabsList className="w-full">
|
||||
<TabsTrigger value="config">Config</TabsTrigger>
|
||||
<TabsTrigger value="preview">Preview</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="config" className="pt-3">
|
||||
<div className="space-y-4">
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={urlId}
|
||||
>
|
||||
HF dataset URL
|
||||
</label>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
id={urlId}
|
||||
className="nodrag flex-1"
|
||||
placeholder="https://huggingface.co/datasets/org/repo"
|
||||
value={config.hf_url ?? ""}
|
||||
onChange={(e) => onUpdate({ hf_url: e.target.value })}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
className="nodrag"
|
||||
onClick={() => void onInspect()}
|
||||
disabled={inspectLoading}
|
||||
>
|
||||
{inspectLoading ? "Loading..." : "Load"}
|
||||
</Button>
|
||||
{inspectLoading && <Spinner />}
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Repo: {repoId ?? "-"}
|
||||
</p>
|
||||
</div>
|
||||
{inspectError && (
|
||||
<p className="text-xs text-red-600">{inspectError}</p>
|
||||
)}
|
||||
|
||||
{(config.seed_splits?.length ?? 0) > 0 && (
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={splitId}
|
||||
>
|
||||
Split
|
||||
</label>
|
||||
<Select
|
||||
value={config.hf_split ?? ""}
|
||||
onValueChange={(value) => {
|
||||
const nextPath = config.seed_globs_by_split?.[value] ?? "";
|
||||
onUpdate({ hf_split: value, hf_path: nextPath || config.hf_path });
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="nodrag w-full" id={splitId}>
|
||||
<SelectValue placeholder="Select split" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{(config.seed_splits ?? []).map((value) => (
|
||||
<SelectItem key={value} value={value}>
|
||||
{value}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={pathId}
|
||||
>
|
||||
HF path (auto)
|
||||
</label>
|
||||
<Input
|
||||
id={pathId}
|
||||
className="nodrag"
|
||||
placeholder="datasets/org/repo/data/train-*.parquet"
|
||||
value={config.hf_path}
|
||||
onChange={(e) => onUpdate({ hf_path: e.target.value })}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={tokenId}
|
||||
>
|
||||
HF token (optional)
|
||||
</label>
|
||||
<Input
  id={tokenId}
  // The token is a credential: mask it on screen and ask the browser not
  // to offer it for autofill.
  type="password"
  autoComplete="off"
  className="nodrag"
  placeholder="hf_..."
  value={config.hf_token ?? ""}
  onChange={(e) => onUpdate({ hf_token: e.target.value })}
/>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={samplingId}
|
||||
>
|
||||
Sampling strategy
|
||||
</label>
|
||||
<Select
|
||||
value={config.sampling_strategy}
|
||||
onValueChange={(value) =>
|
||||
onUpdate({ sampling_strategy: value as SeedSamplingStrategy })
|
||||
}
|
||||
>
|
||||
<SelectTrigger className="nodrag w-full" id={samplingId}>
|
||||
<SelectValue placeholder="Select sampling" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{SAMPLING_OPTIONS.map((opt) => (
|
||||
<SelectItem key={opt.value} value={opt.value}>
|
||||
{opt.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-2">
|
||||
<label
|
||||
className="text-xs font-semibold uppercase text-muted-foreground"
|
||||
htmlFor={selectionId}
|
||||
>
|
||||
Selection strategy
|
||||
</label>
|
||||
<Select
|
||||
value={config.selection_type}
|
||||
onValueChange={(value) =>
|
||||
onUpdate({ selection_type: value as SeedSelectionType })
|
||||
}
|
||||
>
|
||||
<SelectTrigger className="nodrag w-full" id={selectionId}>
|
||||
<SelectValue placeholder="Select selection" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{SELECTION_OPTIONS.map((opt) => (
|
||||
<SelectItem key={opt.value} value={opt.value}>
|
||||
{opt.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{config.selection_type === "index_range" && (
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<div className="grid gap-2">
|
||||
<label className="text-xs font-semibold uppercase text-muted-foreground">
|
||||
Start
|
||||
</label>
|
||||
<Input
|
||||
className="nodrag"
|
||||
inputMode="numeric"
|
||||
value={config.selection_start ?? ""}
|
||||
onChange={(e) => onUpdate({ selection_start: e.target.value })}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<label className="text-xs font-semibold uppercase text-muted-foreground">
|
||||
End
|
||||
</label>
|
||||
<Input
|
||||
className="nodrag"
|
||||
inputMode="numeric"
|
||||
value={config.selection_end ?? ""}
|
||||
onChange={(e) => onUpdate({ selection_end: e.target.value })}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{config.selection_type === "partition_block" && (
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<div className="grid gap-2">
|
||||
<label className="text-xs font-semibold uppercase text-muted-foreground">
|
||||
Index
|
||||
</label>
|
||||
<Input
|
||||
className="nodrag"
|
||||
inputMode="numeric"
|
||||
value={config.selection_index ?? ""}
|
||||
onChange={(e) => onUpdate({ selection_index: e.target.value })}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<label className="text-xs font-semibold uppercase text-muted-foreground">
|
||||
Partitions
|
||||
</label>
|
||||
<Input
|
||||
className="nodrag"
|
||||
inputMode="numeric"
|
||||
value={config.selection_num_partitions ?? ""}
|
||||
onChange={(e) =>
|
||||
onUpdate({ selection_num_partitions: e.target.value })
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Seed columns auto-add. Reference by name (ex: {"{{ rubrics }}"}).
|
||||
</p>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="preview" className="pt-3">
|
||||
<div className="space-y-4">
|
||||
{previewRows.length === 0 ? (
|
||||
<div className="flex w-full items-center justify-center">
|
||||
<Empty className="max-w-lg">
|
||||
<EmptyHeader>
|
||||
<EmptyTitle>Preview samples</EmptyTitle>
|
||||
<EmptyDescription>
|
||||
Load 10 rows to see columns and sample values.
|
||||
</EmptyDescription>
|
||||
</EmptyHeader>
|
||||
<EmptyContent>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
className="nodrag"
|
||||
onClick={() => void onPreview()}
|
||||
disabled={previewLoading}
|
||||
>
|
||||
{previewLoading ? "Loading..." : "Load 10 rows"}
|
||||
</Button>
|
||||
</EmptyContent>
|
||||
</Empty>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center gap-3">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
className="nodrag"
|
||||
onClick={() => void onPreview()}
|
||||
disabled={previewLoading}
|
||||
>
|
||||
{previewLoading ? "Loading..." : "Reload 10 rows"}
|
||||
</Button>
|
||||
{previewLoading && <Spinner />}
|
||||
</div>
|
||||
<Table className="border border-border/60 rounded-xl">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
{previewColumns.map((col) => (
|
||||
<TableHead key={col} className="max-w-[260px]">
|
||||
{col}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{previewRows.map((row, idx) => (
|
||||
<TableRow key={idx}>
|
||||
{previewColumns.map((col) => (
|
||||
<TableCell key={col} className="max-w-[260px]">
|
||||
<div className="truncate">{stringifyCell(row[col])}</div>
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)}
|
||||
{previewError && <p className="text-xs text-red-600">{previewError}</p>}
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
);
|
||||
}
|
||||
|
|
@ -76,6 +76,33 @@ function toErrorMessage(error: unknown, fallback: string): string {
|
|||
return fallback;
|
||||
}
|
||||
|
||||
/**
 * Copy `text` to the clipboard.
 *
 * Tries the async Clipboard API first. If it is unavailable or rejects, falls
 * back to selecting the text in a temporary off-screen textarea and issuing
 * `document.execCommand("copy")`.
 *
 * @param text Text to copy.
 * @returns `true` when one of the two paths reported success, `false` otherwise.
 *   Never rejects.
 */
async function copyTextToClipboard(text: string): Promise<boolean> {
  try {
    if (navigator.clipboard?.writeText) {
      await navigator.clipboard.writeText(text);
      return true;
    }
  } catch {
    // fallthrough to legacy path
  }

  let textarea: HTMLTextAreaElement | null = null;
  try {
    textarea = document.createElement("textarea");
    textarea.value = text;
    textarea.setAttribute("readonly", "");
    // Keep the helper element out of view and out of the layout flow.
    textarea.style.position = "fixed";
    textarea.style.top = "0";
    textarea.style.left = "-9999px";
    document.body.appendChild(textarea);
    textarea.select();
    return document.execCommand("copy");
  } catch {
    return false;
  } finally {
    // Always detach the helper element, even when select/execCommand throws.
    textarea?.remove();
  }
}
|
||||
|
||||
export function useRecipeStudioActions({
|
||||
recipeId,
|
||||
initialRecipeName,
|
||||
|
|
@ -237,13 +264,11 @@ export function useRecipeStudioActions({
|
|||
toastError("Copy failed", payloadErrorMessage);
|
||||
return;
|
||||
}
|
||||
if (!navigator.clipboard) {
|
||||
console.error("Clipboard not available.");
|
||||
toastError("Copy failed", "Clipboard not available.");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await navigator.clipboard.writeText(JSON.stringify(payload, null, 2));
|
||||
const ok = await copyTextToClipboard(JSON.stringify(payload, null, 2));
|
||||
if (!ok) {
|
||||
throw new Error("Clipboard not available.");
|
||||
}
|
||||
setCopied(true);
|
||||
window.setTimeout(() => setCopied(false), 1500);
|
||||
toastSuccess("Payload copied");
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ export function RecipeStudioPage({
|
|||
onEdgesChange,
|
||||
onConnect,
|
||||
addSamplerNode,
|
||||
addSeedNode,
|
||||
addLlmNode,
|
||||
addModelProviderNode,
|
||||
addModelConfigNode,
|
||||
|
|
@ -126,6 +127,7 @@ export function RecipeStudioPage({
|
|||
onEdgesChange: state.onEdgesChange,
|
||||
onConnect: state.onConnect,
|
||||
addSamplerNode: state.addSamplerNode,
|
||||
addSeedNode: state.addSeedNode,
|
||||
addLlmNode: state.addLlmNode,
|
||||
addModelProviderNode: state.addModelProviderNode,
|
||||
addModelConfigNode: state.addModelConfigNode,
|
||||
|
|
@ -377,6 +379,7 @@ export function RecipeStudioPage({
|
|||
sheetView={sheetView}
|
||||
onViewChange={setSheetView}
|
||||
onAddSampler={addSamplerNode}
|
||||
onAddSeed={addSeedNode}
|
||||
onAddLlm={addLlmNode}
|
||||
onAddModelProvider={addModelProviderNode}
|
||||
onAddModelConfig={addModelConfigNode}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ import {
|
|||
updateNodeData,
|
||||
} from "./recipe-studio-helpers";
|
||||
|
||||
type SheetView = "root" | "sampler" | "llm" | "expression" | "processor";
|
||||
type SheetView = "root" | "sampler" | "seed" | "llm" | "expression" | "processor";
|
||||
|
||||
type RecipeStudioState = {
|
||||
nodes: RecipeNode[];
|
||||
|
|
@ -63,6 +63,7 @@ type RecipeStudioState = {
|
|||
setLayoutDirection: (direction: LayoutDirection) => void;
|
||||
applyLayout: () => void;
|
||||
addSamplerNode: (type: SamplerType) => void;
|
||||
addSeedNode: () => void;
|
||||
addLlmNode: (type: LlmType) => void;
|
||||
addModelProviderNode: () => void;
|
||||
addModelConfigNode: () => void;
|
||||
|
|
@ -165,6 +166,23 @@ export const useRecipeStudioStore = create<RecipeStudioState>((set, get) => ({
|
|||
}),
|
||||
addSamplerNode: (type) =>
|
||||
set((state) => buildAddedNodeState(state, "sampler", type)),
|
||||
// At most one seed node: when a config of kind "seed" already exists, no node
// is added. Instead that config becomes active, the dialog opens, and its
// node becomes the only selected one.
addSeedNode: () =>
  set((state) => {
    const seedConfig = Object.values(state.configs).find(
      (candidate) => candidate.kind === "seed",
    );
    if (seedConfig) {
      const nodes = state.nodes.map((node) => ({
        ...node,
        selected: node.id === seedConfig.id,
      }));
      return {
        activeConfigId: seedConfig.id,
        dialogOpen: true,
        nodes,
      };
    }
    return buildAddedNodeState(state, "seed", "seed");
  }),
|
||||
addLlmNode: (type) => set((state) => buildAddedNodeState(state, "llm", type)),
|
||||
addModelProviderNode: () =>
|
||||
set((state) => buildAddedNodeState(state, "llm", "model_provider")),
|
||||
|
|
|
|||
|
|
@ -18,15 +18,25 @@ export type ExpressionDtype = "str" | "int" | "float" | "bool";
|
|||
|
||||
export type LayoutDirection = "LR" | "TB";
|
||||
|
||||
/** Allowed values of `SeedConfig.sampling_strategy`. */
export type SeedSamplingStrategy =
  | "ordered"
  | "shuffle";

/** Allowed values of `SeedConfig.selection_type`. */
export type SeedSelectionType =
  | "none"
  | "index_range"
  | "partition_block";
|
||||
|
||||
export type RecipeNodeData = {
|
||||
title: string;
|
||||
name: string;
|
||||
kind: "sampler" | "llm" | "expression" | "model_provider" | "model_config";
|
||||
kind:
|
||||
| "sampler"
|
||||
| "llm"
|
||||
| "expression"
|
||||
| "seed"
|
||||
| "model_provider"
|
||||
| "model_config";
|
||||
subtype: string;
|
||||
blockType:
|
||||
| SamplerType
|
||||
| LlmType
|
||||
| "expression"
|
||||
| "seed"
|
||||
| "model_provider"
|
||||
| "model_config";
|
||||
layoutDirection?: LayoutDirection;
|
||||
|
|
@ -204,6 +214,31 @@ export type ExpressionConfig = {
|
|||
dtype: ExpressionDtype;
|
||||
};
|
||||
|
||||
/**
 * Configuration of a seed block (`kind: "seed"`).
 *
 * The `selection_*` values are kept as strings, exactly as typed into the
 * dialog's inputs.
 */
export type SeedConfig = {
  id: string;
  kind: "seed";
  name: string;
  drop?: boolean;

  // ui-only (serialized in seed_config)
  /** Text of the dataset URL field. */
  hf_url?: string;
  /** Repo id reported by the last inspect call. */
  hf_repo_id?: string;
  /** Currently selected split. */
  hf_split?: string;
  /** Dataset path; filled from the selected split's glob or edited by hand. */
  hf_path: string;
  hf_token?: string;
  hf_endpoint?: string;
  /** Split names reported by the last inspect call. */
  seed_splits?: string[];

  // ui-only
  // biome-ignore lint/style/useNamingConvention: ui schema
  seed_globs_by_split?: Record<string, string>;
  /** Column names, from inspect or from the first preview row. */
  seed_columns?: string[];

  sampling_strategy: SeedSamplingStrategy;
  selection_type: SeedSelectionType;
  /** Shown when `selection_type` is "index_range". */
  selection_start?: string;
  /** Shown when `selection_type` is "index_range". */
  selection_end?: string;
  /** Shown when `selection_type` is "partition_block". */
  selection_index?: string;
  /** Shown when `selection_type` is "partition_block". */
  selection_num_partitions?: string;
};
|
||||
|
||||
export type SchemaTransformProcessorConfig = {
|
||||
id: string;
|
||||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
|
|
@ -218,5 +253,6 @@ export type NodeConfig =
|
|||
| SamplerConfig
|
||||
| LlmConfig
|
||||
| ExpressionConfig
|
||||
| SeedConfig
|
||||
| ModelProviderConfig
|
||||
| ModelConfig;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import type {
|
|||
ModelConfig,
|
||||
ModelProviderConfig,
|
||||
NodeConfig,
|
||||
SeedConfig,
|
||||
SamplerConfig,
|
||||
SamplerType,
|
||||
} from "../types";
|
||||
|
|
@ -281,3 +282,30 @@ export function makeExpressionConfig(
|
|||
dtype: "str",
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Creates a new seed node configuration with editor defaults.
 *
 * @param id - Id assigned to the new node.
 * @param existing - Configs already present; passed to `nextName` to pick the name.
 * @returns A fresh `SeedConfig` (new array/object instances on every call).
 */
export function makeSeedConfig(
  id: string,
  existing: NodeConfig[],
): SeedConfig {
  const name = nextName(existing, "seed");

  // Hugging Face source fields start blank except for the public hub endpoint.
  const hfSource: Pick<
    SeedConfig,
    "hf_url" | "hf_repo_id" | "hf_split" | "hf_path" | "hf_token" | "hf_endpoint"
  > = {
    hf_url: "",
    hf_repo_id: "",
    hf_split: "",
    hf_path: "",
    hf_token: "",
    hf_endpoint: "https://huggingface.co",
  };

  // Nothing has been inspected yet, so the dataset metadata is empty.
  const inspection: Pick<
    SeedConfig,
    "seed_splits" | "seed_globs_by_split" | "seed_columns"
  > = {
    seed_splits: [],
    seed_globs_by_split: {},
    seed_columns: [],
  };

  // Selection inputs are kept as strings (form-field text).
  const selectionInputs: Pick<
    SeedConfig,
    | "selection_start"
    | "selection_end"
    | "selection_index"
    | "selection_num_partitions"
  > = {
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };

  return {
    id,
    kind: "seed",
    name,
    drop: false,
    ...hfSource,
    ...inspection,
    sampling_strategy: "ordered",
    selection_type: "none",
    ...selectionInputs,
  };
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import {
|
|||
parseModelConfig,
|
||||
parseModelProvider,
|
||||
} from "./parsers";
|
||||
import { parseSeedConfig } from "./parsers/seed-config-parser";
|
||||
import { buildNodes, parseUi } from "./ui";
|
||||
import type { ImportResult } from "./types";
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ type RecipeInput = {
|
|||
mcp_providers?: unknown;
|
||||
tool_configs?: unknown;
|
||||
processors?: unknown;
|
||||
seed_config?: unknown;
|
||||
};
|
||||
|
||||
type UiInput = {
|
||||
|
|
@ -218,6 +220,15 @@ export function importRecipePayload(input: string): ImportResult {
|
|||
|
||||
let nextId = 1;
|
||||
|
||||
if (recipe.seed_config) {
|
||||
const id = `n${nextId}`;
|
||||
nextId += 1;
|
||||
const seedConfig = parseSeedConfig(recipe.seed_config, id);
|
||||
if (seedConfig) {
|
||||
configs.push(seedConfig);
|
||||
}
|
||||
}
|
||||
|
||||
if (Array.isArray(recipe.model_providers)) {
|
||||
recipe.model_providers.forEach((provider, index) => {
|
||||
if (!isRecord(provider)) {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ type ColumnParser = (
|
|||
) => NodeConfig | null;
|
||||
|
||||
const COLUMN_PARSERS: Record<string, ColumnParser> = {
|
||||
sampler: (column, name, id, errors) => parseSampler(column, name, id, errors),
|
||||
sampler: (column, name, id, errors) =>
|
||||
parseSampler(column, name, id, errors),
|
||||
expression: (column, name, id) => parseExpression(column, name, id),
|
||||
"llm-text": (column, name, id) => parseLlm(column, name, id),
|
||||
"llm-structured": (column, name, id) => parseLlm(column, name, id),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
import type {
|
||||
SeedConfig,
|
||||
SeedSamplingStrategy,
|
||||
SeedSelectionType,
|
||||
} from "../../../types";
|
||||
import { isRecord, readString } from "../helpers";
|
||||
|
||||
/**
 * Maps a raw payload value to a sampling strategy.
 *
 * Only the exact string `"shuffle"` selects shuffling; every other value
 * falls back to `"ordered"`.
 */
function normalizeSampling(value: unknown): SeedSamplingStrategy {
  return readString(value) === "shuffle" ? "shuffle" : "ordered";
}
|
||||
|
||||
/**
 * Builds the baseline seed config used when importing a recipe.
 *
 * @param id - Id assigned to the imported seed node.
 * @returns A fresh `SeedConfig` named `"seed"` with editor defaults
 *   (new array/object instances on every call).
 */
function makeDefaultSeedConfig(id: string): SeedConfig {
  // Hugging Face source fields start blank except for the public hub endpoint.
  const hfSource: Pick<
    SeedConfig,
    "hf_url" | "hf_repo_id" | "hf_split" | "hf_path" | "hf_token" | "hf_endpoint"
  > = {
    hf_url: "",
    hf_repo_id: "",
    hf_split: "",
    hf_path: "",
    hf_token: "",
    hf_endpoint: "https://huggingface.co",
  };

  // Dataset metadata is empty until the dataset is inspected.
  const inspection: Pick<
    SeedConfig,
    "seed_splits" | "seed_globs_by_split" | "seed_columns"
  > = {
    seed_splits: [],
    seed_globs_by_split: {},
    seed_columns: [],
  };

  // Selection inputs are kept as strings (form-field text).
  const selectionInputs: Pick<
    SeedConfig,
    | "selection_start"
    | "selection_end"
    | "selection_index"
    | "selection_num_partitions"
  > = {
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };

  return {
    id,
    kind: "seed",
    name: "seed",
    drop: false,
    ...hfSource,
    ...inspection,
    sampling_strategy: "ordered",
    selection_type: "none",
    ...selectionInputs,
  };
}
|
||||
|
||||
/**
 * Extracts the seed settings that are present in a recipe's `seed_config`.
 *
 * @param seedConfigRaw - The raw `seed_config` value from the imported recipe.
 * @returns An empty object when the input is not a record; otherwise the
 *   source, sampling and selection fields, with defaults for anything the
 *   payload does not provide.
 */
function parseSeedSettings(seedConfigRaw: unknown): Partial<SeedConfig> {
  if (!isRecord(seedConfigRaw)) {
    return {};
  }

  const defaultEndpoint = "https://huggingface.co";
  const settings: Partial<SeedConfig> = {
    hf_path: "",
    hf_token: "",
    hf_endpoint: defaultEndpoint,
    sampling_strategy: normalizeSampling(seedConfigRaw.sampling_strategy),
    selection_type: "none",
    selection_start: "0",
    selection_end: "10",
    selection_index: "0",
    selection_num_partitions: "1",
  };

  // Source fields are only read from a Hugging Face ("hf") source.
  const source = seedConfigRaw.source;
  if (isRecord(source) && readString(source.seed_type) === "hf") {
    settings.hf_path = readString(source.path) ?? "";
    settings.hf_token = readString(source.token) ?? "";
    settings.hf_endpoint = readString(source.endpoint) ?? defaultEndpoint;
  }

  const selection = seedConfigRaw.selection_strategy;
  if (!isRecord(selection)) {
    return settings;
  }

  // The selection kind is inferred from which pair of numeric fields is
  // present; an index range takes precedence over a partition block.
  const hasRange =
    typeof selection.start === "number" && typeof selection.end === "number";
  if (hasRange) {
    settings.selection_type = "index_range";
    settings.selection_start = String(selection.start);
    settings.selection_end = String(selection.end);
    return settings;
  }

  const hasPartition =
    typeof selection.index === "number" &&
    typeof selection.num_partitions === "number";
  if (hasPartition) {
    settings.selection_type = "partition_block";
    settings.selection_index = String(selection.index);
    settings.selection_num_partitions = String(selection.num_partitions);
  }

  return settings;
}
|
||||
|
||||
/**
 * Converts a recipe's raw `seed_config` into a seed node configuration.
 *
 * @param seedConfigRaw - The raw `seed_config` value from the imported recipe.
 * @param id - Id to assign to the resulting node.
 * @returns `null` when the raw value is falsy; otherwise the default seed
 *   config with every field parsed from the payload layered on top.
 */
export function parseSeedConfig(
  seedConfigRaw: unknown,
  id: string,
): SeedConfig | null {
  if (seedConfigRaw) {
    const defaults = makeDefaultSeedConfig(id);
    // payload-only fields override ui defaults
    const fromPayload = parseSeedSettings(seedConfigRaw);
    return { ...defaults, ...fromPayload };
  }
  return null;
}
|
||||
|
|
@ -4,6 +4,7 @@ export {
|
|||
makeModelConfig,
|
||||
makeModelProviderConfig,
|
||||
makeSamplerConfig,
|
||||
makeSeedConfig,
|
||||
} from "./config-factories";
|
||||
export {
|
||||
labelForExpression,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,16 @@ export function nodeDataFromConfig(
|
|||
layoutDirection,
|
||||
};
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
return {
|
||||
title: "Seed",
|
||||
kind: "seed",
|
||||
subtype: "Hugging Face",
|
||||
blockType: "seed",
|
||||
name: config.name,
|
||||
layoutDirection,
|
||||
};
|
||||
}
|
||||
if (config.kind === "model_provider") {
|
||||
return {
|
||||
title: "Model Provider",
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ import {
|
|||
buildModelProvider,
|
||||
buildProcessors,
|
||||
buildSamplerColumn,
|
||||
buildSeedConfig,
|
||||
buildSeedDropProcessor,
|
||||
pickFirstSeedConfig,
|
||||
} from "./builders";
|
||||
import type { RecipePayloadResult } from "./types";
|
||||
import {
|
||||
|
|
@ -65,6 +68,7 @@ export function buildRecipePayload(
|
|||
const toolConfigJsonByAlias = new Map<string, string>();
|
||||
const nameSet = new Set<string>();
|
||||
const nameToConfig = new Map<string, NodeConfig>();
|
||||
const firstSeed = pickFirstSeedConfig(configs);
|
||||
|
||||
for (const node of nodes) {
|
||||
const config = configs[node.id];
|
||||
|
|
@ -74,10 +78,12 @@ export function buildRecipePayload(
|
|||
for (const error of getConfigErrors(config)) {
|
||||
errors.push(`${config.name}: ${error}`);
|
||||
}
|
||||
if (nameSet.has(config.name)) {
|
||||
errors.push(`Duplicate node name: ${config.name}.`);
|
||||
if (config.kind !== "seed") {
|
||||
if (nameSet.has(config.name)) {
|
||||
errors.push(`Duplicate node name: ${config.name}.`);
|
||||
}
|
||||
nameSet.add(config.name);
|
||||
}
|
||||
nameSet.add(config.name);
|
||||
|
||||
if (config.kind === "sampler") {
|
||||
nameToConfig.set(config.name, config);
|
||||
|
|
@ -135,6 +141,10 @@ export function buildRecipePayload(
|
|||
nameToConfig.set(config.name, config);
|
||||
continue;
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
// SeedConfig is global config (seed_config); seed-dataset columns are added by DataDesigner.
|
||||
continue;
|
||||
}
|
||||
if (config.kind === "model_provider") {
|
||||
modelProviderNames.add(config.name);
|
||||
modelProviders.push(buildModelProvider(config, errors));
|
||||
|
|
@ -166,6 +176,9 @@ export function buildRecipePayload(
|
|||
if (!config) {
|
||||
return [];
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
return [];
|
||||
}
|
||||
const width = getNodeWidth(node);
|
||||
return [
|
||||
{
|
||||
|
|
@ -195,6 +208,13 @@ export function buildRecipePayload(
|
|||
];
|
||||
});
|
||||
const recipeProcessors = buildProcessors(processors, errors);
|
||||
const seedConfig = firstSeed ? buildSeedConfig(firstSeed, errors) : undefined;
|
||||
const seedDropProcessor = firstSeed
|
||||
? buildSeedDropProcessor(firstSeed, errors)
|
||||
: null;
|
||||
if (seedDropProcessor) {
|
||||
recipeProcessors.push(seedDropProcessor);
|
||||
}
|
||||
|
||||
return {
|
||||
errors,
|
||||
|
|
@ -207,6 +227,8 @@ export function buildRecipePayload(
|
|||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
model_configs: modelConfigs,
|
||||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
seed_config: seedConfig,
|
||||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
tool_configs: toolConfigs,
|
||||
columns,
|
||||
processors: recipeProcessors,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
import type { NodeConfig, SeedConfig } from "../../types";
|
||||
|
||||
/**
 * Parses form-field text as a base-10 integer.
 *
 * Surrounding whitespace is ignored. Only an optional sign followed by
 * decimal digits is accepted, and the result must be a safe integer.
 *
 * @param value - Raw text from the form field; may be `undefined`.
 * @returns The parsed integer, or `null` when the text is missing, blank,
 *   not a plain decimal integer, or outside the safe-integer range.
 */
function parseIntStrict(value: string | undefined): number | null {
  const trimmed = value?.trim();
  if (!trimmed) return null;
  // Number() alone would also accept "0x10", "0b11", "1e3" and "1.0", none of
  // which is a plain integer a user would intentionally type for an index.
  if (!/^[+-]?\d+$/.test(trimmed)) return null;
  const num = Number(trimmed);
  // Digit strings beyond 2^53 - 1 lose precision; reject rather than send a
  // silently rounded index.
  return Number.isSafeInteger(num) ? num : null;
}
|
||||
|
||||
/**
 * Builds the `seed_config` section of the recipe payload from a seed node.
 *
 * @param config - Seed node configuration.
 * @param errors - Collector that receives a message when the selection
 *   inputs cannot be parsed as integers.
 * @returns `undefined` when the dataset path is blank (no error is recorded)
 *   or when the selection inputs are invalid; otherwise the payload object
 *   with `source`, `sampling_strategy` and `selection_strategy`.
 */
export function buildSeedConfig(
  config: SeedConfig,
  errors: string[],
): Record<string, unknown> | undefined {
  const datasetPath = config.hf_path.trim();
  if (datasetPath === "") {
    return undefined;
  }

  // Stays null for "none": the API receives an explicit null strategy.
  let selection: Record<string, unknown> | null = null;
  switch (config.selection_type) {
    case "index_range": {
      const start = parseIntStrict(config.selection_start);
      const end = parseIntStrict(config.selection_end);
      if (start === null || end === null) {
        errors.push(`Seed ${config.name}: selection index range invalid.`);
        return undefined;
      }
      selection = { start, end };
      break;
    }
    case "partition_block": {
      const index = parseIntStrict(config.selection_index);
      const numPartitions = parseIntStrict(config.selection_num_partitions);
      if (index === null || numPartitions === null) {
        errors.push(`Seed ${config.name}: selection partition invalid.`);
        return undefined;
      }
      // biome-ignore lint/style/useNamingConvention: api schema
      selection = { index, num_partitions: numPartitions };
      break;
    }
    default:
      break;
  }

  const source = {
    // biome-ignore lint/style/useNamingConvention: api schema
    seed_type: "hf",
    path: datasetPath,
    // A blank token is sent as null rather than an empty string.
    token: config.hf_token?.trim() || null,
    // A blank endpoint falls back to the public hub.
    endpoint: config.hf_endpoint?.trim() || "https://huggingface.co",
  };

  return {
    source,
    // biome-ignore lint/style/useNamingConvention: api schema
    sampling_strategy: config.sampling_strategy,
    // biome-ignore lint/style/useNamingConvention: api schema
    selection_strategy: selection,
  };
}
|
||||
|
||||
/**
 * Finds the first seed configuration among the node configs.
 *
 * @param configs - Node configs keyed by node id.
 * @returns The first config (in `Object.values` order) whose kind is
 *   `"seed"`, or `null` when there is none.
 */
export function pickFirstSeedConfig(
  configs: Record<string, NodeConfig>,
): SeedConfig | null {
  const seed = Object.values(configs).find(
    (candidate): candidate is SeedConfig => candidate.kind === "seed",
  );
  return seed ?? null;
}
|
||||
|
||||
/**
 * Builds the processor that removes the seed columns from the output.
 *
 * @param config - Seed node configuration.
 * @param errors - Collector that receives a message when dropping is enabled
 *   but no seed column names are available.
 * @returns `null` when dropping is disabled or no usable column names exist;
 *   otherwise a `drop_columns` processor listing the trimmed, non-blank
 *   seed column names.
 */
export function buildSeedDropProcessor(
  config: SeedConfig,
  errors: string[],
): Record<string, unknown> | null {
  if (config.drop) {
    const columnNames: string[] = [];
    for (const rawName of config.seed_columns ?? []) {
      const trimmedName = rawName.trim();
      if (trimmedName) {
        columnNames.push(trimmedName);
      }
    }

    if (columnNames.length > 0) {
      return {
        // biome-ignore lint/style/useNamingConvention: api schema
        processor_type: "drop_columns",
        name: "drop_seed_columns",
        // biome-ignore lint/style/useNamingConvention: api schema
        build_stage: "post_batch",
        // biome-ignore lint/style/useNamingConvention: api schema
        column_names: columnNames,
      };
    }

    errors.push(`Seed ${config.name}: drop enabled but no seed columns loaded.`);
  }
  return null;
}
|
||||
|
|
@ -2,3 +2,8 @@ export { buildLlmColumn, buildLlmMcpProvider, buildLlmToolConfig } from "./build
|
|||
export { buildModelConfig, buildModelProvider } from "./builders-model";
|
||||
export { buildExpressionColumn, buildProcessors } from "./builders-processors";
|
||||
export { buildSamplerColumn } from "./builders-sampler";
|
||||
export {
|
||||
buildSeedConfig,
|
||||
buildSeedDropProcessor,
|
||||
pickFirstSeedConfig,
|
||||
} from "./builders-seed";
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ export type RecipePayload = {
|
|||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
model_configs: Record<string, unknown>[];
|
||||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
seed_config?: Record<string, unknown>;
|
||||
// biome-ignore lint/style/useNamingConvention: api schema
|
||||
tool_configs: Record<string, unknown>[];
|
||||
columns: Record<string, unknown>[];
|
||||
processors: Record<string, unknown>[];
|
||||
|
|
|
|||
|
|
@ -8,6 +8,14 @@ function parseNumber(value?: string): number | null {
|
|||
return Number.isFinite(num) ? num : null;
|
||||
}
|
||||
|
||||
/**
 * Parses form-field text via `parseNumber` and keeps it only if it is an integer.
 *
 * @param value - Raw text from the form field; may be omitted.
 * @returns The parsed integer, or `null` when `parseNumber` yields `null`
 *   or a non-integer value.
 */
function parseIntNumber(value?: string): number | null {
  const parsed = parseNumber(value);
  return parsed !== null && Number.isInteger(parsed) ? parsed : null;
}
|
||||
|
||||
function parseAgeRange(value?: string): [number, number] | null {
|
||||
if (!value) {
|
||||
return null;
|
||||
|
|
@ -207,5 +215,44 @@ export function getConfigErrors(config: NodeConfig | null): string[] {
|
|||
errors.push("Expression is required.");
|
||||
}
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
if (!config.hf_path.trim()) {
|
||||
errors.push("HF dataset path is required.");
|
||||
}
|
||||
if (config.hf_endpoint?.trim() && !config.hf_endpoint.trim().startsWith("http")) {
|
||||
errors.push("HF endpoint must start with http.");
|
||||
}
|
||||
if (config.drop && (config.seed_columns?.length ?? 0) === 0) {
|
||||
errors.push("Seed drop needs loaded columns (open Seed Preview).");
|
||||
}
|
||||
|
||||
if (config.selection_type === "index_range") {
|
||||
const start = parseIntNumber(config.selection_start);
|
||||
const end = parseIntNumber(config.selection_end);
|
||||
if (start === null || end === null) {
|
||||
errors.push("Index range start/end must be integers.");
|
||||
} else {
|
||||
if (start < 0 || end < 0) {
|
||||
errors.push("Index range start/end must be >= 0.");
|
||||
}
|
||||
if (end < start) {
|
||||
errors.push("Index range end must be >= start.");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (config.selection_type === "partition_block") {
|
||||
const index = parseIntNumber(config.selection_index);
|
||||
const parts = parseIntNumber(config.selection_num_partitions);
|
||||
if (index === null || parts === null) {
|
||||
errors.push("Partition index/num_partitions must be integers.");
|
||||
} else {
|
||||
if (index < 0) errors.push("Partition index must be >= 0.");
|
||||
if (parts < 1) errors.push("Partition num_partitions must be >= 1.");
|
||||
if (parts >= 1 && index >= parts) {
|
||||
errors.push("Partition index must be < num_partitions.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return errors;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,31 +36,55 @@ export function getAvailableRefItems(
|
|||
const items: AvailableRefItem[] = [];
|
||||
|
||||
for (const config of Object.values(configs)) {
|
||||
if (config.id === currentId) continue;
|
||||
if (config.kind === "model_provider" || config.kind === "model_config") continue;
|
||||
if (config.id === currentId) {
|
||||
continue;
|
||||
}
|
||||
if (config.kind === "model_provider" || config.kind === "model_config") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (config.kind === "sampler") {
|
||||
items.push({ ref: config.name, kind: "sampler", subtype: config.sampler_type });
|
||||
items.push({
|
||||
ref: config.name,
|
||||
kind: "sampler",
|
||||
subtype: config.sampler_type,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (config.kind === "expression") {
|
||||
items.push({ ref: config.name, kind: "expression", subtype: config.dtype });
|
||||
items.push({
|
||||
ref: config.name,
|
||||
kind: "expression",
|
||||
subtype: config.dtype,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (config.kind === "llm") {
|
||||
items.push({ ref: config.name, kind: "llm", subtype: config.llm_type });
|
||||
if (config.llm_type === "structured" && config.output_format) {
|
||||
for (const ref of getStructuredRefs(config.name, config.output_format)) {
|
||||
items.push({
|
||||
ref: ref.ref,
|
||||
kind: "llm",
|
||||
subtype: config.llm_type,
|
||||
valueType: ref.valueType,
|
||||
});
|
||||
}
|
||||
if (config.kind === "seed") {
|
||||
for (const col of config.seed_columns ?? []) {
|
||||
const name = col.trim();
|
||||
if (!name) continue;
|
||||
items.push({ ref: name, kind: "seed", subtype: "seed" });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (config.kind !== "llm") {
|
||||
continue;
|
||||
}
|
||||
|
||||
items.push({ ref: config.name, kind: "llm", subtype: config.llm_type });
|
||||
if (config.llm_type !== "structured" || !config.output_format) {
|
||||
continue;
|
||||
}
|
||||
for (const ref of getStructuredRefs(config.name, config.output_format)) {
|
||||
items.push({
|
||||
ref: ref.ref,
|
||||
kind: "llm",
|
||||
subtype: config.llm_type,
|
||||
valueType: ref.valueType,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -73,4 +97,3 @@ export function getAvailableVariables(
|
|||
): string[] {
|
||||
return getAvailableRefItems(configs, currentId).map((item) => item.ref);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,14 @@ export default defineConfig({
|
|||
target: "http://127.0.0.1:8000",
|
||||
changeOrigin: true,
|
||||
},
|
||||
"/seed/inspect": {
|
||||
target: "http://127.0.0.1:8004",
|
||||
changeOrigin: true,
|
||||
},
|
||||
"/seed/preview": {
|
||||
target: "http://127.0.0.1:8004",
|
||||
changeOrigin: true,
|
||||
},
|
||||
"/preview": {
|
||||
target: "http://127.0.0.1:8004",
|
||||
changeOrigin: true,
|
||||
|
|
|
|||
Loading…
Reference in a new issue