♻️ refactor(store): class-based Zustand actions with flattenActions (#13383)

♻️ refactor(store): migrate slices to class actions with flattenActions

- Video store: generationConfig/Topic/Batch/createVideo as *ActionImpl; aggregate with flattenActions
- Eval store: benchmark/dataset/run/testCase as classes; top-level flattenActions
- Tool agentSkills: AgentSkillsActionImpl + Pick typing
- groupProfile: flattenActions around ActionImpl instead of spreading instance
- agentGroup: wrap chatGroupAction with flattenActions for consistent aggregation

Made-with: Cursor
This commit is contained in:
Innei 2026-03-30 23:46:35 +08:00 committed by GitHub
parent 6402656ec7
commit 491aba4dbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 568 additions and 584 deletions

View file

@ -4,6 +4,7 @@ import { type StateCreator } from 'zustand/vanilla';
import { createDevtools } from '../middleware/createDevtools';
import { expose } from '../middleware/expose';
import { flattenActions } from '../utils/flattenActions';
import { type ChatGroupAction } from './action';
import { chatGroupAction } from './action';
import { type ChatGroupState } from './initialState';
@ -15,7 +16,7 @@ const createStore: StateCreator<ChatGroupStore, [['zustand/devtools', never]]> =
...params: Parameters<StateCreator<ChatGroupStore, [['zustand/devtools', never]]>>
) => ({
...initialChatGroupState,
...chatGroupAction(...params),
...flattenActions<ChatGroupAction>([chatGroupAction(...params)]),
});
const devtools = createDevtools('agentGroup');

View file

@ -1,161 +1,147 @@
import isEqual from 'fast-deep-equal';
import { type SWRResponse } from 'swr';
import { type StateCreator } from 'zustand/vanilla';
import { type SWRResponse } from 'swr';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { agentEvalService } from '@/services/agentEval';
import { type EvalStore } from '@/store/eval/store';
import { type EvalStore } from '@/store/eval/store';
import { type StoreSetter } from '@/store/types';
import { type BenchmarkDetailDispatch,benchmarkDetailReducer } from './reducer';
import { type BenchmarkDetailDispatch, benchmarkDetailReducer } from './reducer';
const FETCH_BENCHMARKS_KEY = 'FETCH_BENCHMARKS';
const FETCH_BENCHMARK_DETAIL_KEY = 'FETCH_BENCHMARK_DETAIL';
export interface BenchmarkAction {
createBenchmark: (params: {
type Setter = StoreSetter<EvalStore>;
export const createBenchmarkSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
new BenchmarkActionImpl(set, get, _api);
export class BenchmarkActionImpl {
readonly #get: () => EvalStore;
readonly #set: Setter;
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
createBenchmark = async (params: {
description?: string;
identifier: string;
metadata?: Record<string, unknown>;
name: string;
rubrics?: any[];
tags?: string[];
}) => Promise<any>;
deleteBenchmark: (id: string) => Promise<void>;
// Internal methods
internal_dispatchBenchmarkDetail: (payload: BenchmarkDetailDispatch) => void;
internal_updateBenchmarkDetailLoading: (id: string, loading: boolean) => void;
refreshBenchmarkDetail: (id: string) => Promise<void>;
refreshBenchmarks: () => Promise<void>;
updateBenchmark: (params: {
}): Promise<any> => {
this.#set({ isCreatingBenchmark: true }, false, 'createBenchmark/start');
try {
const result = await agentEvalService.createBenchmark({
description: params.description,
identifier: params.identifier,
metadata: params.metadata,
name: params.name,
rubrics: params.rubrics ?? [],
tags: params.tags,
});
await this.#get().refreshBenchmarks();
return result;
} finally {
this.#set({ isCreatingBenchmark: false }, false, 'createBenchmark/end');
}
};
deleteBenchmark = async (id: string): Promise<void> => {
this.#set({ isDeletingBenchmark: true }, false, 'deleteBenchmark/start');
try {
await agentEvalService.deleteBenchmark(id);
await this.#get().refreshBenchmarks();
} finally {
this.#set({ isDeletingBenchmark: false }, false, 'deleteBenchmark/end');
}
};
refreshBenchmarkDetail = async (id: string): Promise<void> => {
await mutate([FETCH_BENCHMARK_DETAIL_KEY, id]);
};
refreshBenchmarks = async (): Promise<void> => {
await mutate(FETCH_BENCHMARKS_KEY);
};
updateBenchmark = async (params: {
description?: string;
id: string;
identifier: string;
metadata?: Record<string, unknown>;
name: string;
tags?: string[];
}) => Promise<void>;
useFetchBenchmarkDetail: (id?: string) => SWRResponse;
useFetchBenchmarks: () => SWRResponse;
}
export const createBenchmarkSlice: StateCreator<
EvalStore,
[['zustand/devtools', never]],
[],
BenchmarkAction
> = (set, get) => ({
createBenchmark: async (params) => {
set({ isCreatingBenchmark: true }, false, 'createBenchmark/start');
try {
const result = await agentEvalService.createBenchmark({
identifier: params.identifier,
name: params.name,
description: params.description,
metadata: params.metadata,
rubrics: params.rubrics ?? [],
tags: params.tags,
});
await get().refreshBenchmarks();
return result;
} finally {
set({ isCreatingBenchmark: false }, false, 'createBenchmark/end');
}
},
deleteBenchmark: async (id) => {
set({ isDeletingBenchmark: true }, false, 'deleteBenchmark/start');
try {
await agentEvalService.deleteBenchmark(id);
await get().refreshBenchmarks();
} finally {
set({ isDeletingBenchmark: false }, false, 'deleteBenchmark/end');
}
},
refreshBenchmarkDetail: async (id) => {
await mutate([FETCH_BENCHMARK_DETAIL_KEY, id]);
},
refreshBenchmarks: async () => {
await mutate(FETCH_BENCHMARKS_KEY);
},
updateBenchmark: async (params) => {
}): Promise<void> => {
const { id } = params;
// 1. Optimistic update
get().internal_dispatchBenchmarkDetail({
type: 'updateBenchmarkDetail',
this.#get().internal_dispatchBenchmarkDetail({
id,
type: 'updateBenchmarkDetail',
value: params,
});
// 2. Set loading
get().internal_updateBenchmarkDetailLoading(id, true);
this.#get().internal_updateBenchmarkDetailLoading(id, true);
try {
// 3. Call service
await agentEvalService.updateBenchmark({
description: params.description,
id: params.id,
identifier: params.identifier,
name: params.name,
description: params.description,
metadata: params.metadata,
name: params.name,
tags: params.tags,
});
// 4. Refresh from server
await get().refreshBenchmarks();
await get().refreshBenchmarkDetail(id);
await this.#get().refreshBenchmarks();
await this.#get().refreshBenchmarkDetail(id);
} finally {
get().internal_updateBenchmarkDetailLoading(id, false);
this.#get().internal_updateBenchmarkDetailLoading(id, false);
}
},
};
useFetchBenchmarkDetail: (id) => {
return useClientDataSWR(
useFetchBenchmarkDetail = (id?: string): SWRResponse =>
useClientDataSWR(
id ? [FETCH_BENCHMARK_DETAIL_KEY, id] : null,
() => agentEvalService.getBenchmark(id!),
{
onSuccess: (data: any) => {
get().internal_dispatchBenchmarkDetail({
type: 'setBenchmarkDetail',
this.#get().internal_dispatchBenchmarkDetail({
id: id!,
type: 'setBenchmarkDetail',
value: data,
});
get().internal_updateBenchmarkDetailLoading(id!, false);
this.#get().internal_updateBenchmarkDetailLoading(id!, false);
},
},
);
},
useFetchBenchmarks: () => {
return useClientDataSWR(FETCH_BENCHMARKS_KEY, () => agentEvalService.listBenchmarks(), {
useFetchBenchmarks = (): SWRResponse =>
useClientDataSWR(FETCH_BENCHMARKS_KEY, () => agentEvalService.listBenchmarks(), {
onSuccess: (data: any) => {
set(
this.#set(
{ benchmarkList: data, benchmarkListInit: true, isLoadingBenchmarkList: false },
false,
'useFetchBenchmarks/success',
);
},
});
},
// Internal - Dispatch to reducer
internal_dispatchBenchmarkDetail: (payload) => {
const currentMap = get().benchmarkDetailMap;
internal_dispatchBenchmarkDetail = (payload: BenchmarkDetailDispatch): void => {
const currentMap = this.#get().benchmarkDetailMap;
const nextMap = benchmarkDetailReducer(currentMap, payload);
// No need to update if map is the same
if (isEqual(nextMap, currentMap)) return;
set({ benchmarkDetailMap: nextMap }, false, `dispatchBenchmarkDetail/${payload.type}`);
},
this.#set({ benchmarkDetailMap: nextMap }, false, `dispatchBenchmarkDetail/${payload.type}`);
};
// Internal - Update loading state for specific detail
internal_updateBenchmarkDetailLoading: (id, loading) => {
set(
internal_updateBenchmarkDetailLoading = (id: string, loading: boolean): void => {
this.#set(
(state) => {
if (loading) {
return { loadingBenchmarkDetailIds: [...state.loadingBenchmarkDetailIds, id] };
@ -167,5 +153,7 @@ export const createBenchmarkSlice: StateCreator<
false,
'updateBenchmarkDetailLoading',
);
},
});
};
}
export type BenchmarkAction = Pick<BenchmarkActionImpl, keyof BenchmarkActionImpl>;

View file

@ -1,65 +1,62 @@
import isEqual from 'fast-deep-equal';
import { type SWRResponse } from 'swr';
import { type StateCreator } from 'zustand/vanilla';
import { type SWRResponse } from 'swr';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { agentEvalService } from '@/services/agentEval';
import { type EvalStore } from '@/store/eval/store';
import { type EvalStore } from '@/store/eval/store';
import { type StoreSetter } from '@/store/types';
import { type DatasetDetailDispatch,datasetDetailReducer } from './reducer';
import { type DatasetDetailDispatch, datasetDetailReducer } from './reducer';
const FETCH_DATASETS_KEY = 'FETCH_DATASETS';
const FETCH_DATASET_DETAIL_KEY = 'FETCH_DATASET_DETAIL';
export interface DatasetAction {
// Internal methods
internal_dispatchDatasetDetail: (payload: DatasetDetailDispatch) => void;
internal_updateDatasetDetailLoading: (id: string, loading: boolean) => void;
refreshDatasetDetail: (id: string) => Promise<void>;
refreshDatasets: (benchmarkId: string) => Promise<void>;
type Setter = StoreSetter<EvalStore>;
useFetchDatasetDetail: (id?: string) => SWRResponse;
useFetchDatasets: (benchmarkId?: string) => SWRResponse;
}
export const createDatasetSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
new DatasetActionImpl(set, get, _api);
export const createDatasetSlice: StateCreator<
EvalStore,
[['zustand/devtools', never]],
[],
DatasetAction
> = (set, get) => ({
refreshDatasetDetail: async (id) => {
export class DatasetActionImpl {
readonly #get: () => EvalStore;
readonly #set: Setter;
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
refreshDatasetDetail = async (id: string): Promise<void> => {
await mutate([FETCH_DATASET_DETAIL_KEY, id]);
},
};
refreshDatasets: async (benchmarkId) => {
refreshDatasets = async (benchmarkId: string): Promise<void> => {
await mutate([FETCH_DATASETS_KEY, benchmarkId]);
},
};
useFetchDatasetDetail: (id) => {
return useClientDataSWR(
useFetchDatasetDetail = (id?: string): SWRResponse =>
useClientDataSWR(
id ? [FETCH_DATASET_DETAIL_KEY, id] : null,
() => agentEvalService.getDataset(id!),
{
onSuccess: (data: any) => {
get().internal_dispatchDatasetDetail({
type: 'setDatasetDetail',
this.#get().internal_dispatchDatasetDetail({
id: id!,
type: 'setDatasetDetail',
value: data,
});
get().internal_updateDatasetDetailLoading(id!, false);
this.#get().internal_updateDatasetDetailLoading(id!, false);
},
},
);
},
useFetchDatasets: (benchmarkId) => {
return useClientDataSWR(
useFetchDatasets = (benchmarkId?: string): SWRResponse =>
useClientDataSWR(
benchmarkId ? [FETCH_DATASETS_KEY, benchmarkId] : null,
() => agentEvalService.listDatasets(benchmarkId!),
{
onSuccess: (data: any) => {
set(
this.#set(
{
datasetList: data,
isLoadingDatasets: false,
@ -70,22 +67,18 @@ export const createDatasetSlice: StateCreator<
},
},
);
},
// Internal - Dispatch to reducer
internal_dispatchDatasetDetail: (payload) => {
const currentMap = get().datasetDetailMap;
internal_dispatchDatasetDetail = (payload: DatasetDetailDispatch): void => {
const currentMap = this.#get().datasetDetailMap;
const nextMap = datasetDetailReducer(currentMap, payload);
// No need to update if map is the same
if (isEqual(nextMap, currentMap)) return;
set({ datasetDetailMap: nextMap }, false, `dispatchDatasetDetail/${payload.type}`);
},
this.#set({ datasetDetailMap: nextMap }, false, `dispatchDatasetDetail/${payload.type}`);
};
// Internal - Update loading state for specific detail
internal_updateDatasetDetailLoading: (id, loading) => {
set(
internal_updateDatasetDetailLoading = (id: string, loading: boolean): void => {
this.#set(
(state) => {
if (loading) {
return { loadingDatasetDetailIds: [...state.loadingDatasetDetailIds, id] };
@ -97,5 +90,7 @@ export const createDatasetSlice: StateCreator<
false,
'updateDatasetDetailLoading',
);
},
});
};
}
export type DatasetAction = Pick<DatasetActionImpl, keyof DatasetActionImpl>;

View file

@ -1,11 +1,11 @@
import type { EvalRunInputConfig } from '@lobechat/types';
import isEqual from 'fast-deep-equal';
import type { SWRResponse } from 'swr';
import type { StateCreator } from 'zustand/vanilla';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { agentEvalService } from '@/services/agentEval';
import type { EvalStore } from '@/store/eval/store';
import { type StoreSetter } from '@/store/types';
import { type RunDetailDispatch, runDetailReducer } from './reducer';
@ -14,76 +14,59 @@ const FETCH_DATASET_RUNS_KEY = 'FETCH_EVAL_DATASET_RUNS';
const FETCH_RUN_DETAIL_KEY = 'FETCH_EVAL_RUN_DETAIL';
const FETCH_RUN_RESULTS_KEY = 'FETCH_EVAL_RUN_RESULTS';
export interface RunAction {
abortRun: (id: string) => Promise<void>;
createRun: (params: {
type Setter = StoreSetter<EvalStore>;
export const createRunSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
new RunActionImpl(set, get, _api);
export class RunActionImpl {
readonly #get: () => EvalStore;
readonly #set: Setter;
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
abortRun = async (id: string): Promise<void> => {
await agentEvalService.abortRun(id);
await this.#get().refreshRunDetail(id);
};
createRun = async (params: {
config?: EvalRunInputConfig;
datasetId: string;
name?: string;
targetAgentId?: string;
}) => Promise<any>;
deleteRun: (id: string) => Promise<void>;
internal_dispatchRunDetail: (payload: RunDetailDispatch) => void;
internal_updateRunDetailLoading: (id: string, loading: boolean) => void;
internal_updateRunResultLoading: (id: string, loading: boolean) => void;
refreshDatasetRuns: (datasetId: string) => Promise<void>;
refreshRunDetail: (id: string) => Promise<void>;
refreshRuns: (benchmarkId?: string) => Promise<void>;
retryRunCase: (runId: string, testCaseId: string) => Promise<void>;
retryRunErrors: (id: string) => Promise<void>;
startRun: (id: string, force?: boolean) => Promise<void>;
updateRun: (params: {
config?: EvalRunInputConfig;
datasetId?: string;
id: string;
name?: string;
targetAgentId?: string | null;
}) => Promise<any>;
useFetchDatasetRuns: (datasetId?: string) => SWRResponse;
useFetchRunDetail: (id: string, config?: { refreshInterval?: number }) => SWRResponse;
useFetchRunResults: (id: string, config?: { refreshInterval?: number }) => SWRResponse;
useFetchRuns: (benchmarkId?: string) => SWRResponse;
}
export const createRunSlice: StateCreator<
EvalStore,
[['zustand/devtools', never]],
[],
RunAction
> = (set, get) => ({
abortRun: async (id) => {
await agentEvalService.abortRun(id);
await get().refreshRunDetail(id);
},
createRun: async (params) => {
set({ isCreatingRun: true }, false, 'createRun/start');
}): Promise<any> => {
this.#set({ isCreatingRun: true }, false, 'createRun/start');
try {
const result = await agentEvalService.createRun(params);
await get().refreshRuns();
await this.#get().refreshRuns();
return result;
} finally {
set({ isCreatingRun: false }, false, 'createRun/end');
this.#set({ isCreatingRun: false }, false, 'createRun/end');
}
},
};
deleteRun: async (id) => {
deleteRun = async (id: string): Promise<void> => {
await agentEvalService.deleteRun(id);
get().internal_dispatchRunDetail({ id, type: 'deleteRunDetail' });
await get().refreshRuns();
},
this.#get().internal_dispatchRunDetail({ id, type: 'deleteRunDetail' });
await this.#get().refreshRuns();
};
internal_dispatchRunDetail: (payload) => {
const currentMap = get().runDetailMap;
internal_dispatchRunDetail = (payload: RunDetailDispatch): void => {
const currentMap = this.#get().runDetailMap;
const nextMap = runDetailReducer(currentMap, payload);
if (isEqual(nextMap, currentMap)) return;
set({ runDetailMap: nextMap }, false, `dispatchRunDetail/${payload.type}`);
},
this.#set({ runDetailMap: nextMap }, false, `dispatchRunDetail/${payload.type}`);
};
internal_updateRunDetailLoading: (id, loading) => {
set(
internal_updateRunDetailLoading = (id: string, loading: boolean): void => {
this.#set(
(state) => {
if (loading) {
return { loadingRunDetailIds: [...state.loadingRunDetailIds, id] };
@ -95,10 +78,10 @@ export const createRunSlice: StateCreator<
false,
'updateRunDetailLoading',
);
},
};
internal_updateRunResultLoading: (id, loading) => {
set(
internal_updateRunResultLoading = (id: string, loading: boolean): void => {
this.#set(
(state) => {
if (loading) {
return { loadingRunResultIds: [...state.loadingRunResultIds, id] };
@ -110,92 +93,95 @@ export const createRunSlice: StateCreator<
false,
'updateRunResultLoading',
);
},
};
refreshDatasetRuns: async (datasetId) => {
refreshDatasetRuns = async (datasetId: string): Promise<void> => {
await mutate([FETCH_DATASET_RUNS_KEY, datasetId]);
},
};
refreshRunDetail: async (id) => {
refreshRunDetail = async (id: string): Promise<void> => {
await mutate([FETCH_RUN_DETAIL_KEY, id]);
},
};
refreshRuns: async (benchmarkId) => {
refreshRuns = async (benchmarkId?: string): Promise<void> => {
if (benchmarkId) {
await mutate([FETCH_RUNS_KEY, benchmarkId]);
} else {
// Revalidate all benchmark-level run list entries
await mutate((key) => Array.isArray(key) && key[0] === FETCH_RUNS_KEY);
}
},
};
retryRunCase: async (runId, testCaseId) => {
retryRunCase = async (runId: string, testCaseId: string): Promise<void> => {
await agentEvalService.retryRunCase(runId, testCaseId);
await get().refreshRunDetail(runId);
},
await this.#get().refreshRunDetail(runId);
};
retryRunErrors: async (id) => {
retryRunErrors = async (id: string): Promise<void> => {
await agentEvalService.retryRunErrors(id);
await get().refreshRunDetail(id);
},
await this.#get().refreshRunDetail(id);
};
startRun: async (id, force) => {
startRun = async (id: string, force?: boolean): Promise<void> => {
await agentEvalService.startRun(id, force);
await get().refreshRunDetail(id);
},
await this.#get().refreshRunDetail(id);
};
updateRun: async (params) => {
updateRun = async (params: {
config?: EvalRunInputConfig;
datasetId?: string;
id: string;
name?: string;
targetAgentId?: string | null;
}): Promise<any> => {
const result = await agentEvalService.updateRun(params);
await get().refreshRunDetail(params.id);
await get().refreshRuns();
await this.#get().refreshRunDetail(params.id);
await this.#get().refreshRuns();
return result;
},
};
useFetchRunDetail: (id, config) => {
return useClientDataSWR(
useFetchRunDetail = (id: string, config?: { refreshInterval?: number }): SWRResponse =>
useClientDataSWR(
id ? [FETCH_RUN_DETAIL_KEY, id] : null,
() => agentEvalService.getRunDetails(id),
{
...config,
onSuccess: (data: any) => {
get().internal_dispatchRunDetail({
this.#get().internal_dispatchRunDetail({
id,
type: 'setRunDetail',
value: data,
});
get().internal_updateRunDetailLoading(id, false);
this.#get().internal_updateRunDetailLoading(id, false);
},
},
);
},
useFetchRunResults: (id, config) => {
return useClientDataSWR(
useFetchRunResults = (id: string, config?: { refreshInterval?: number }): SWRResponse =>
useClientDataSWR(
id ? [FETCH_RUN_RESULTS_KEY, id] : null,
() => agentEvalService.getRunResults(id),
{
...config,
onSuccess: (data: any) => {
set(
this.#set(
(state) => ({
runResultsMap: { ...state.runResultsMap, [id]: data },
}),
false,
'useFetchRunResults/success',
);
get().internal_updateRunResultLoading(id, false);
this.#get().internal_updateRunResultLoading(id, false);
},
},
);
},
useFetchDatasetRuns: (datasetId) => {
return useClientDataSWR(
useFetchDatasetRuns = (datasetId?: string): SWRResponse =>
useClientDataSWR(
datasetId ? [FETCH_DATASET_RUNS_KEY, datasetId] : null,
() => agentEvalService.listRuns({ datasetId: datasetId! }),
{
onSuccess: (data: any) => {
set(
this.#set(
(state) => ({
datasetRunListMap: { ...state.datasetRunListMap, [datasetId!]: data.data },
}),
@ -205,17 +191,17 @@ export const createRunSlice: StateCreator<
},
},
);
},
useFetchRuns: (benchmarkId) => {
return useClientDataSWR(
useFetchRuns = (benchmarkId?: string): SWRResponse =>
useClientDataSWR(
benchmarkId ? [FETCH_RUNS_KEY, benchmarkId] : null,
() => agentEvalService.listRuns({ benchmarkId: benchmarkId! }),
{
onSuccess: (data: any) => {
set({ isLoadingRuns: false, runList: data.data }, false, 'useFetchRuns/success');
this.#set({ isLoadingRuns: false, runList: data.data }, false, 'useFetchRuns/success');
},
},
);
},
});
}
export type RunAction = Pick<RunActionImpl, keyof RunActionImpl>;

View file

@ -1,54 +1,50 @@
import type { SWRResponse } from 'swr';
import type { StateCreator } from 'zustand/vanilla';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { agentEvalService } from '@/services/agentEval';
import type { EvalStore } from '@/store/eval/store';
import { type StoreSetter } from '@/store/types';
const FETCH_TEST_CASES_KEY = 'FETCH_TEST_CASES';
export interface TestCaseAction {
getTestCasesByDatasetId: (datasetId: string) => any[];
getTestCasesTotalByDatasetId: (datasetId: string) => number;
isLoadingTestCases: (datasetId: string) => boolean;
refreshTestCases: (datasetId: string) => Promise<void>;
useFetchTestCases: (params: {
type Setter = StoreSetter<EvalStore>;
export const createTestCaseSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
new TestCaseActionImpl(set, get, _api);
export class TestCaseActionImpl {
readonly #get: () => EvalStore;
readonly #set: Setter;
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
getTestCasesByDatasetId = (datasetId: string): any[] => {
return this.#get().testCasesCache[datasetId]?.data || [];
};
getTestCasesTotalByDatasetId = (datasetId: string): number => {
return this.#get().testCasesCache[datasetId]?.total || 0;
};
isLoadingTestCases = (datasetId: string): boolean => {
return this.#get().loadingTestCaseIds.includes(datasetId);
};
refreshTestCases = async (datasetId: string): Promise<void> => {
await mutate(
(key) => Array.isArray(key) && key[0] === FETCH_TEST_CASES_KEY && key[1] === datasetId,
);
};
useFetchTestCases = (params: {
datasetId: string;
limit?: number;
offset?: number;
}) => SWRResponse;
}
export const createTestCaseSlice: StateCreator<
EvalStore,
[['zustand/devtools', never]],
[],
TestCaseAction
> = (set, get) => ({
// Get test cases for a specific dataset from cache
getTestCasesByDatasetId: (datasetId) => {
return get().testCasesCache[datasetId]?.data || [];
},
// Get total count for a specific dataset from cache
getTestCasesTotalByDatasetId: (datasetId) => {
return get().testCasesCache[datasetId]?.total || 0;
},
// Check if test cases are currently loading for a dataset
isLoadingTestCases: (datasetId) => {
return get().loadingTestCaseIds.includes(datasetId);
},
refreshTestCases: async (datasetId) => {
// Mutate all SWR keys that start with [FETCH_TEST_CASES_KEY, datasetId]
await mutate(
(key) =>
Array.isArray(key) && key[0] === FETCH_TEST_CASES_KEY && key[1] === datasetId,
);
},
useFetchTestCases: (params) => {
}): SWRResponse => {
const { datasetId, limit = 10, offset = 0 } = params;
return useClientDataSWR(
@ -56,7 +52,7 @@ export const createTestCaseSlice: StateCreator<
() => agentEvalService.listTestCases({ datasetId, limit, offset }),
{
onSuccess: (data: any) => {
set(
this.#set(
(state) => ({
loadingTestCaseIds: state.loadingTestCaseIds.filter((id) => id !== datasetId),
testCasesCache: {
@ -74,5 +70,7 @@ export const createTestCaseSlice: StateCreator<
},
},
);
},
});
};
}
export type TestCaseAction = Pick<TestCaseActionImpl, keyof TestCaseActionImpl>;

View file

@ -4,24 +4,27 @@ import type { StateCreator } from 'zustand/vanilla';
import { createDevtools } from '../middleware/createDevtools';
import { expose } from '../middleware/expose';
import { flattenActions } from '../utils/flattenActions';
import { type EvalStoreState, initialState } from './initialState';
import { type BenchmarkAction, createBenchmarkSlice } from './slices/benchmark/action';
import { createDatasetSlice, type DatasetAction } from './slices/dataset/action';
import { createRunSlice, type RunAction } from './slices/run/action';
import { createTestCaseSlice, type TestCaseAction } from './slices/testCase/action';
export type EvalStore = EvalStoreState &
BenchmarkAction &
DatasetAction &
RunAction &
TestCaseAction;
type EvalStoreAction = BenchmarkAction & DatasetAction & RunAction & TestCaseAction;
const createStore: StateCreator<EvalStore, [['zustand/devtools', never]]> = (set, get, store) => ({
export type EvalStore = EvalStoreState & EvalStoreAction;
const createStore: StateCreator<EvalStore, [['zustand/devtools', never]]> = (
...parameters: Parameters<StateCreator<EvalStore, [['zustand/devtools', never]]>>
) => ({
...initialState,
...createBenchmarkSlice(set, get, store),
...createDatasetSlice(set, get, store),
...createRunSlice(set, get, store),
...createTestCaseSlice(set, get, store),
...flattenActions<EvalStoreAction>([
createBenchmarkSlice(...parameters),
createDatasetSlice(...parameters),
createRunSlice(...parameters),
createTestCaseSlice(...parameters),
]),
});
const devtools = createDevtools('eval');

View file

@ -3,6 +3,7 @@ import { type StateCreator } from 'zustand';
import { EDITOR_DEBOUNCE_TIME, EDITOR_MAX_WAIT } from '@/const/index';
import { type StoreSetter } from '@/store/types';
import { flattenActions } from '@/store/utils/flattenActions';
import { type SaveState, type SaveStatus, type State } from './initialState';
import { initialState } from './initialState';
@ -179,7 +180,7 @@ export class ActionImpl {
};
}
export const store: StateCreator<Store> = (set, get, _api) => ({
export const store: StateCreator<Store> = (...parameters: Parameters<StateCreator<Store>>) => ({
...initialState,
...new ActionImpl(set, get, _api),
...flattenActions<Action>([new ActionImpl(...parameters)]),
});

View file

@ -11,10 +11,10 @@ import {
} from '@lobechat/types';
import { produce } from 'immer';
import useSWR, { mutate, type SWRResponse } from 'swr';
import { type StateCreator } from 'zustand/vanilla';
import { useClientDataSWR } from '@/libs/swr';
import { agentSkillService } from '@/services/skill';
import { type StoreSetter } from '@/store/types';
import { setNamespace } from '@/utils/storeDebug';
import { type ToolStore } from '../../store';
@ -27,36 +27,31 @@ export interface AgentSkillDetailData {
skillDetail?: SkillItem;
}
export interface AgentSkillsAction {
createAgentSkill: (params: CreateSkillInput) => Promise<SkillItem | undefined>;
deleteAgentSkill: (id: string) => Promise<void>;
fetchAgentSkillDetail: (id: string) => Promise<SkillItem | undefined>;
importAgentSkillFromGitHub: (params: ImportGitHubInput) => Promise<SkillImportResult | undefined>;
importAgentSkillFromUrl: (params: ImportUrlInput) => Promise<SkillImportResult | undefined>;
importAgentSkillFromZip: (params: ImportZipInput) => Promise<SkillImportResult | undefined>;
refreshAgentSkills: () => Promise<void>;
updateAgentSkill: (params: UpdateSkillInput) => Promise<SkillItem | undefined>;
useFetchAgentSkillDetail: (skillId?: string) => SWRResponse<AgentSkillDetailData>;
useFetchAgentSkills: (enabled: boolean) => SWRResponse<SkillListItem[]>;
}
type Setter = StoreSetter<ToolStore>;
export const createAgentSkillsSlice: StateCreator<
ToolStore,
[['zustand/devtools', never]],
[],
AgentSkillsAction
> = (set, get) => ({
createAgentSkill: async (params) => {
export const createAgentSkillsSlice = (set: Setter, get: () => ToolStore, _api?: unknown) =>
new AgentSkillsActionImpl(set, get, _api);
export class AgentSkillsActionImpl {
readonly #get: () => ToolStore;
readonly #set: Setter;
constructor(set: Setter, get: () => ToolStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
createAgentSkill = async (params: CreateSkillInput): Promise<SkillItem | undefined> => {
const result = await agentSkillService.createSkill(params);
await get().refreshAgentSkills();
await this.#get().refreshAgentSkills();
return result;
},
};
deleteAgentSkill: async (id) => {
deleteAgentSkill = async (id: string): Promise<void> => {
await agentSkillService.deleteSkill(id);
// Clean up detail map
set(
this.#set(
produce((draft: AgentSkillsState) => {
delete draft.agentSkillDetailMap[id];
}),
@ -64,19 +59,18 @@ export const createAgentSkillsSlice: StateCreator<
n('deleteAgentSkill'),
);
// Clear SWR cache
await mutate(['fetchAgentSkillDetail', id].join('-'), undefined, { revalidate: false });
await get().refreshAgentSkills();
},
await this.#get().refreshAgentSkills();
};
fetchAgentSkillDetail: async (id) => {
const cached = get().agentSkillDetailMap[id];
fetchAgentSkillDetail = async (id: string): Promise<SkillItem | undefined> => {
const cached = this.#get().agentSkillDetailMap[id];
if (cached) return cached;
const detail = await agentSkillService.getById(id);
if (detail) {
set(
this.#set(
produce((draft: AgentSkillsState) => {
draft.agentSkillDetailMap[id] = detail;
}),
@ -85,37 +79,42 @@ export const createAgentSkillsSlice: StateCreator<
);
}
return detail;
},
};
importAgentSkillFromGitHub: async (params) => {
importAgentSkillFromGitHub = async (
params: ImportGitHubInput,
): Promise<SkillImportResult | undefined> => {
const result = await agentSkillService.importFromGitHub(params);
await get().refreshAgentSkills();
await this.#get().refreshAgentSkills();
return result;
},
};
importAgentSkillFromUrl: async (params) => {
importAgentSkillFromUrl = async (
params: ImportUrlInput,
): Promise<SkillImportResult | undefined> => {
const result = await agentSkillService.importFromUrl(params);
await get().refreshAgentSkills();
await this.#get().refreshAgentSkills();
return result;
},
};
importAgentSkillFromZip: async (params) => {
importAgentSkillFromZip = async (
params: ImportZipInput,
): Promise<SkillImportResult | undefined> => {
const result = await agentSkillService.importFromZip(params);
await get().refreshAgentSkills();
await this.#get().refreshAgentSkills();
return result;
},
};
refreshAgentSkills: async () => {
refreshAgentSkills = async (): Promise<void> => {
const { data } = await agentSkillService.list();
set({ agentSkills: data }, false, n('refreshAgentSkills'));
},
this.#set({ agentSkills: data }, false, n('refreshAgentSkills'));
};
updateAgentSkill: async (params) => {
updateAgentSkill = async (params: UpdateSkillInput): Promise<SkillItem | undefined> => {
const result = await agentSkillService.updateSkill(params);
// Update detail map if cached
if (result) {
set(
this.#set(
produce((draft: AgentSkillsState) => {
draft.agentSkillDetailMap[params.id] = result;
}),
@ -124,14 +123,13 @@ export const createAgentSkillsSlice: StateCreator<
);
}
// Clear SWR cache so next open refetches instead of showing stale data
await mutate(['fetchAgentSkillDetail', params.id].join('-'), undefined, { revalidate: false });
await get().refreshAgentSkills();
await this.#get().refreshAgentSkills();
return result;
},
};
useFetchAgentSkillDetail: (skillId) =>
useFetchAgentSkillDetail = (skillId?: string): SWRResponse<AgentSkillDetailData> =>
useClientDataSWR<AgentSkillDetailData>(
skillId ? ['fetchAgentSkillDetail', skillId].join('-') : null,
async () => {
@ -141,7 +139,7 @@ export const createAgentSkillsSlice: StateCreator<
]);
if (detail) {
set(
this.#set(
produce((draft: AgentSkillsState) => {
draft.agentSkillDetailMap[skillId!] = detail;
}),
@ -153,9 +151,9 @@ export const createAgentSkillsSlice: StateCreator<
return { resourceTree, skillDetail: detail };
},
{ revalidateOnFocus: false },
),
);
useFetchAgentSkills: (enabled) =>
useFetchAgentSkills = (enabled: boolean): SWRResponse<SkillListItem[]> =>
useSWR<SkillListItem[]>(
enabled ? 'fetchAgentSkills' : null,
async () => {
@ -165,9 +163,11 @@ export const createAgentSkillsSlice: StateCreator<
{
fallbackData: [],
onSuccess: (data) => {
set({ agentSkills: data }, false, n('useFetchAgentSkills'));
this.#set({ agentSkills: data }, false, n('useFetchAgentSkills'));
},
revalidateOnFocus: false,
},
),
});
);
}
export type AgentSkillsAction = Pick<AgentSkillsActionImpl, keyof AgentSkillsActionImpl>;

View file

@ -1,35 +1,35 @@
import { ENABLE_BUSINESS_FEATURES } from '@lobechat/business-const';
import { t } from 'i18next';
import { type StateCreator } from 'zustand';
import { markUserValidAction } from '@/business/client/markUserValidAction';
import { message } from '@/components/AntdStaticMethods';
import { videoService } from '@/services/video';
import { type StoreSetter } from '@/store/types';
import { type VideoStore } from '../../store';
import { generationBatchSelectors } from '../generationBatch/selectors';
import { videoGenerationConfigSelectors } from '../generationConfig/selectors';
import { generationTopicSelectors } from '../generationTopic';
// ====== action interface ====== //
type Setter = StoreSetter<VideoStore>;
export interface CreateVideoAction {
createVideo: () => Promise<void>;
recreateVideo: (generationBatchId: string) => Promise<void>;
}
export const createCreateVideoSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
new CreateVideoActionImpl(set, get, _api);
// ====== action implementation ====== //
export class CreateVideoActionImpl {
readonly #get: () => VideoStore;
readonly #set: Setter;
export const createCreateVideoSlice: StateCreator<
VideoStore,
[['zustand/devtools', never]],
[],
CreateVideoAction
> = (set, get) => ({
async createVideo() {
set({ isCreating: true }, false, 'createVideo/startCreateVideo');
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
const store = get();
createVideo = async (): Promise<void> => {
this.#set({ isCreating: true }, false, 'createVideo/startCreateVideo');
const store = this.#get();
const parameters = videoGenerationConfigSelectors.parameters(store);
const provider = videoGenerationConfigSelectors.provider(store);
const model = videoGenerationConfigSelectors.model(store);
@ -58,7 +58,7 @@ export const createCreateVideoSlice: StateCreator<
content: t('generation.validation.endFrameRequiresStartFrame', { ns: 'video' }),
duration: 3,
});
set({ isCreating: false }, false, 'createVideo/endCreateVideo');
this.#set({ isCreating: false }, false, 'createVideo/endCreateVideo');
return;
}
@ -84,7 +84,11 @@ export const createCreateVideoSlice: StateCreator<
try {
// 3. If it's a new topic, set the creating state after topic creation
if (isNewTopic) {
set({ isCreatingWithNewTopic: true }, false, 'createVideo/startCreateVideoWithNewTopic');
this.#set(
{ isCreatingWithNewTopic: true },
false,
'createVideo/startCreateVideoWithNewTopic',
);
}
if (ENABLE_BUSINESS_FEATURES) {
@ -101,11 +105,11 @@ export const createCreateVideoSlice: StateCreator<
// 5. Refresh generation batches to show the new batch
if (!isNewTopic) {
await get().refreshGenerationBatches();
await this.#get().refreshGenerationBatches();
}
// 6. Clear the prompt input after successful video creation
set(
this.#set(
(state) => ({
parameters: { ...state.parameters, prompt: '' },
}),
@ -115,21 +119,21 @@ export const createCreateVideoSlice: StateCreator<
} finally {
// 7. Reset all creating states
if (isNewTopic) {
set(
this.#set(
{ isCreating: false, isCreatingWithNewTopic: false },
false,
'createVideo/endCreateVideoWithNewTopic',
);
} else {
set({ isCreating: false }, false, 'createVideo/endCreateVideo');
this.#set({ isCreating: false }, false, 'createVideo/endCreateVideo');
}
}
},
};
async recreateVideo(generationBatchId: string) {
set({ isCreating: true }, false, 'recreateVideo/start');
recreateVideo = async (generationBatchId: string): Promise<void> => {
this.#set({ isCreating: true }, false, 'recreateVideo/start');
const store = get();
const store = this.#get();
const activeGenerationTopicId = generationTopicSelectors.activeGenerationTopicId(store);
if (!activeGenerationTopicId) {
throw new Error('No active generation topic');
@ -150,7 +154,9 @@ export const createCreateVideoSlice: StateCreator<
await store.refreshGenerationBatches();
} finally {
set({ isCreating: false }, false, 'recreateVideo/end');
this.#set({ isCreating: false }, false, 'recreateVideo/end');
}
},
});
};
}
export type CreateVideoAction = Pick<CreateVideoActionImpl, keyof CreateVideoActionImpl>;

View file

@ -1,12 +1,12 @@
import { isEqual } from 'es-toolkit/compat';
import { useRef } from 'react';
import type { SWRResponse } from 'swr';
import { type StateCreator } from 'zustand';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { type GetGenerationStatusResult } from '@/server/routers/lambda/generation';
import { generationService } from '@/services/generation';
import { generationBatchService } from '@/services/generationBatch';
import { type StoreSetter } from '@/store/types';
import { AsyncTaskStatus } from '@/types/asyncTask';
import { type GenerationBatch } from '@/types/generation';
import { setNamespace } from '@/utils/storeDebug';
@ -21,44 +21,28 @@ const n = setNamespace('generationBatch');
const SWR_USE_FETCH_GENERATION_BATCHES = 'SWR_USE_FETCH_VIDEO_GENERATION_BATCHES';
const SWR_USE_CHECK_GENERATION_STATUS = 'SWR_USE_CHECK_VIDEO_GENERATION_STATUS';
// ====== action interface ====== //
type Setter = StoreSetter<VideoStore>;
export interface GenerationBatchAction {
internal_deleteGeneration: (generationId: string) => Promise<void>;
internal_deleteGenerationBatch: (batchId: string, topicId: string) => Promise<void>;
internal_dispatchGenerationBatch: (
topicId: string,
payload: GenerationBatchDispatch,
action?: string,
) => void;
refreshGenerationBatches: () => Promise<void>;
removeGeneration: (generationId: string) => Promise<void>;
removeGenerationBatch: (batchId: string, topicId: string) => Promise<void>;
setTopicBatchLoaded: (topicId: string) => void;
useCheckGenerationStatus: (
generationId: string,
asyncTaskId: string,
topicId: string,
enable?: boolean,
) => SWRResponse<GetGenerationStatusResult>;
useFetchGenerationBatches: (topicId?: string | null) => SWRResponse<GenerationBatch[]>;
}
export const createGenerationBatchSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
new GenerationBatchActionImpl(set, get, _api);
// ====== action implementation ====== //
export class GenerationBatchActionImpl {
readonly #get: () => VideoStore;
readonly #set: Setter;
export const createGenerationBatchSlice: StateCreator<
VideoStore,
[['zustand/devtools', never]],
[],
GenerationBatchAction
> = (set, get) => ({
internal_deleteGeneration: async (generationId: string) => {
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
internal_deleteGeneration = async (generationId: string): Promise<void> => {
const { activeGenerationTopicId, refreshGenerationBatches, internal_dispatchGenerationBatch } =
get();
this.#get();
if (!activeGenerationTopicId) return;
const currentBatches = get().generationBatchesMap[activeGenerationTopicId] || [];
const currentBatches = this.#get().generationBatchesMap[activeGenerationTopicId] || [];
const targetBatch = currentBatches.find((batch) =>
batch.generations.some((gen) => gen.id === generationId),
);
@ -74,10 +58,10 @@ export const createGenerationBatchSlice: StateCreator<
await generationService.deleteGeneration(generationId);
await refreshGenerationBatches();
},
};
internal_deleteGenerationBatch: async (batchId: string, topicId: string) => {
const { internal_dispatchGenerationBatch, refreshGenerationBatches } = get();
internal_deleteGenerationBatch = async (batchId: string, topicId: string): Promise<void> => {
const { internal_dispatchGenerationBatch, refreshGenerationBatches } = this.#get();
// Optimistic update
internal_dispatchGenerationBatch(
@ -88,75 +72,84 @@ export const createGenerationBatchSlice: StateCreator<
await generationBatchService.deleteGenerationBatch(batchId);
await refreshGenerationBatches();
},
};
internal_dispatchGenerationBatch: (topicId, payload, action) => {
const currentBatches = get().generationBatchesMap[topicId] || [];
internal_dispatchGenerationBatch = (
topicId: string,
payload: GenerationBatchDispatch,
action?: string,
): void => {
const currentBatches = this.#get().generationBatchesMap[topicId] || [];
const nextBatches = generationBatchReducer(currentBatches, payload);
const nextMap = {
...get().generationBatchesMap,
...this.#get().generationBatchesMap,
[topicId]: nextBatches,
};
if (isEqual(nextMap, get().generationBatchesMap)) return;
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
set(
this.#set(
{
generationBatchesMap: nextMap,
},
false,
action ?? n(`dispatchGenerationBatch/${payload.type}`),
);
},
};
refreshGenerationBatches: async () => {
const { activeGenerationTopicId } = get();
refreshGenerationBatches = async (): Promise<void> => {
const { activeGenerationTopicId } = this.#get();
if (activeGenerationTopicId) {
await mutate([SWR_USE_FETCH_GENERATION_BATCHES, activeGenerationTopicId]);
}
},
};
removeGeneration: async (generationId: string) => {
removeGeneration = async (generationId: string): Promise<void> => {
const { internal_deleteGeneration, activeGenerationTopicId, internal_deleteGenerationBatch } =
get();
this.#get();
await internal_deleteGeneration(generationId);
// Video batch has only 1 generation, so delete the batch directly
if (activeGenerationTopicId) {
const updatedBatches = get().generationBatchesMap[activeGenerationTopicId] || [];
const updatedBatches = this.#get().generationBatchesMap[activeGenerationTopicId] || [];
const emptyBatches = updatedBatches.filter((batch) => batch.generations.length === 0);
for (const emptyBatch of emptyBatches) {
await internal_deleteGenerationBatch(emptyBatch.id, activeGenerationTopicId);
}
}
},
};
removeGenerationBatch: async (batchId: string, topicId: string) => {
const { internal_deleteGenerationBatch } = get();
removeGenerationBatch = async (batchId: string, topicId: string): Promise<void> => {
const { internal_deleteGenerationBatch } = this.#get();
await internal_deleteGenerationBatch(batchId, topicId);
},
};
setTopicBatchLoaded: (topicId: string) => {
setTopicBatchLoaded = (topicId: string): void => {
const nextMap = {
...get().generationBatchesMap,
...this.#get().generationBatchesMap,
[topicId]: [],
};
if (isEqual(nextMap, get().generationBatchesMap)) return;
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
set(
this.#set(
{
generationBatchesMap: nextMap,
},
false,
n('setTopicBatchLoaded'),
);
},
};
useCheckGenerationStatus: (generationId, asyncTaskId, topicId, enable = true) => {
useCheckGenerationStatus = (
generationId: string,
asyncTaskId: string,
topicId: string,
enable = true,
): SWRResponse<GetGenerationStatusResult> => {
const requestCountRef = useRef(0);
const isErrorRef = useRef(false);
@ -178,7 +171,7 @@ export const createGenerationBatchSlice: StateCreator<
isErrorRef.current = false;
const currentBatches = get().generationBatchesMap[topicId] || [];
const currentBatches = this.#get().generationBatchesMap[topicId] || [];
const targetBatch = currentBatches.find((batch) =>
batch.generations.some((gen) => gen.id === generationId),
);
@ -190,7 +183,7 @@ export const createGenerationBatchSlice: StateCreator<
requestCountRef.current = 0;
if (data.generation) {
get().internal_dispatchGenerationBatch(
this.#get().internal_dispatchGenerationBatch(
topicId,
{
batchId: targetBatch.id,
@ -205,11 +198,12 @@ export const createGenerationBatchSlice: StateCreator<
// Update topic cover if generation succeeds and has a thumbnail
if (data.status === AsyncTaskStatus.Success && data.generation.asset?.thumbnailUrl) {
const currentTopic =
generationTopicSelectors.getGenerationTopicById(topicId)(get());
const currentTopic = generationTopicSelectors.getGenerationTopicById(topicId)(
this.#get(),
);
if (currentTopic && !currentTopic.coverUrl) {
await get().updateGenerationTopicCover(
await this.#get().updateGenerationTopicCover(
topicId,
data.generation.asset.thumbnailUrl,
);
@ -217,7 +211,7 @@ export const createGenerationBatchSlice: StateCreator<
}
}
await get().refreshGenerationBatches();
await this.#get().refreshGenerationBatches();
}
},
refreshInterval: (data: GetGenerationStatusResult | undefined) => {
@ -244,9 +238,9 @@ export const createGenerationBatchSlice: StateCreator<
refreshWhenHidden: false,
},
);
},
};
useFetchGenerationBatches: (topicId) =>
useFetchGenerationBatches = (topicId?: string | null): SWRResponse<GenerationBatch[]> =>
useClientDataSWR<GenerationBatch[]>(
topicId ? [SWR_USE_FETCH_GENERATION_BATCHES, topicId] : null,
async ([, topicId]: [string, string]) => {
@ -255,13 +249,13 @@ export const createGenerationBatchSlice: StateCreator<
{
onSuccess: (data) => {
const nextMap = {
...get().generationBatchesMap,
...this.#get().generationBatchesMap,
[topicId!]: data,
};
if (isEqual(nextMap, get().generationBatchesMap)) return;
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
set(
this.#set(
{
generationBatchesMap: nextMap,
},
@ -270,5 +264,10 @@ export const createGenerationBatchSlice: StateCreator<
);
},
},
),
});
);
}
export type GenerationBatchAction = Pick<
GenerationBatchActionImpl,
keyof GenerationBatchActionImpl
>;

View file

@ -5,30 +5,15 @@ import {
type RuntimeVideoGenParamsValue,
type VideoModelParamsSchema,
} from 'model-bank';
import { type StateCreator } from 'zustand/vanilla';
import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
import { useGlobalStore } from '@/store/global';
import { type StoreSetter } from '@/store/types';
import { useUserStore } from '@/store/user';
import { authSelectors } from '@/store/user/selectors';
import type { VideoStore } from '../../store';
export interface GenerationConfigAction {
initializeVideoConfig: (
isLogin?: boolean,
lastSelectedVideoModel?: string,
lastSelectedVideoProvider?: string,
) => void;
setModelAndProviderOnSelect: (model: string, provider: string) => void;
setParamOnInput: <K extends RuntimeVideoGenParamsKeys>(
paramName: K,
value: RuntimeVideoGenParamsValue,
) => void;
}
export function getVideoModelAndDefaults(model: string, provider: string) {
const enabledVideoModelList = aiProviderSelectors.enabledVideoModelList(getAiInfraStoreState());
@ -54,13 +39,25 @@ export function getVideoModelAndDefaults(model: string, provider: string) {
return { activeModel, defaultValues, parametersSchema };
}
export const createGenerationConfigSlice: StateCreator<
VideoStore,
[['zustand/devtools', never]],
[],
GenerationConfigAction
> = (set) => ({
initializeVideoConfig: (isLogin, lastSelectedVideoModel, lastSelectedVideoProvider) => {
type Setter = StoreSetter<VideoStore>;
export const createGenerationConfigSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
new GenerationConfigActionImpl(set, get, _api);
export class GenerationConfigActionImpl {
readonly #set: Setter;
constructor(set: Setter, _get: () => VideoStore, _api?: unknown) {
void _get;
void _api;
this.#set = set;
}
initializeVideoConfig = (
isLogin?: boolean,
lastSelectedVideoModel?: string,
lastSelectedVideoProvider?: string,
): void => {
if (isLogin && lastSelectedVideoModel && lastSelectedVideoProvider) {
try {
const { defaultValues, parametersSchema } = getVideoModelAndDefaults(
@ -68,7 +65,7 @@ export const createGenerationConfigSlice: StateCreator<
lastSelectedVideoProvider,
);
set(
this.#set(
{
isInit: true,
model: lastSelectedVideoModel,
@ -80,17 +77,17 @@ export const createGenerationConfigSlice: StateCreator<
`initializeVideoConfig/${lastSelectedVideoModel}/${lastSelectedVideoProvider}`,
);
} catch {
set({ isInit: true }, false, 'initializeVideoConfig/default');
this.#set({ isInit: true }, false, 'initializeVideoConfig/default');
}
} else {
set({ isInit: true }, false, 'initializeVideoConfig/default');
this.#set({ isInit: true }, false, 'initializeVideoConfig/default');
}
},
};
setModelAndProviderOnSelect: (model, provider) => {
setModelAndProviderOnSelect = (model: string, provider: string): void => {
const { defaultValues, parametersSchema } = getVideoModelAndDefaults(model, provider);
set(
this.#set(
{
model,
parameters: defaultValues,
@ -108,10 +105,13 @@ export const createGenerationConfigSlice: StateCreator<
lastSelectedVideoProvider: provider,
});
}
},
};
setParamOnInput: (paramName, value) => {
set(
setParamOnInput = <K extends RuntimeVideoGenParamsKeys>(
paramName: K,
value: RuntimeVideoGenParamsValue,
): void => {
this.#set(
(state) => {
const { parameters } = state;
return { parameters: { ...parameters, [paramName]: value } };
@ -119,5 +119,10 @@ export const createGenerationConfigSlice: StateCreator<
false,
`setParamOnInput/${paramName}`,
);
},
});
};
}
export type GenerationConfigAction = Pick<
GenerationConfigActionImpl,
keyof GenerationConfigActionImpl
>;

View file

@ -1,13 +1,13 @@
import { chainSummaryGenerationTitle } from '@lobechat/prompts';
import isEqual from 'fast-deep-equal';
import type { SWRResponse } from 'swr';
import { type StateCreator } from 'zustand/vanilla';
import { LOADING_FLAT } from '@/const/message';
import { mutate, useClientDataSWR } from '@/libs/swr';
import { type UpdateTopicValue } from '@/server/routers/lambda/generationTopic';
import { chatService } from '@/services/chat';
import { generationTopicService } from '@/services/generationTopic';
import { type StoreSetter } from '@/store/types';
import { useUserStore } from '@/store/user';
import { systemAgentSelectors, userGeneralSettingsSelectors } from '@/store/user/selectors';
import { type ImageGenerationTopic } from '@/types/generation';
@ -22,104 +22,97 @@ const FETCH_GENERATION_TOPICS_KEY = 'fetchVideoGenerationTopics';
const n = setNamespace('videoGenerationTopic');
export interface GenerationTopicAction {
createGenerationTopic: (prompts: string[]) => Promise<string>;
internal_createGenerationTopic: () => Promise<string>;
internal_dispatchGenerationTopic: (payload: GenerationTopicDispatch, action?: any) => void;
internal_removeGenerationTopic: (id: string) => Promise<void>;
internal_updateGenerationTopic: (id: string, data: UpdateTopicValue) => Promise<void>;
internal_updateGenerationTopicCover: (topicId: string, coverUrl: string) => Promise<void>;
internal_updateGenerationTopicLoading: (id: string, loading: boolean) => void;
internal_updateGenerationTopicTitleInSummary: (id: string, title: string) => void;
type Setter = StoreSetter<VideoStore>;
openNewGenerationTopic: () => void;
refreshGenerationTopics: () => Promise<void>;
removeGenerationTopic: (id: string) => Promise<void>;
summaryGenerationTopicTitle: (topicId: string, prompts: string[]) => Promise<string>;
switchGenerationTopic: (topicId: string) => void;
updateGenerationTopicCover: (topicId: string, imageUrl: string) => Promise<void>;
useFetchGenerationTopics: (enabled: boolean) => SWRResponse<ImageGenerationTopic[]>;
}
export const createGenerationTopicSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
new GenerationTopicActionImpl(set, get, _api);
export const createGenerationTopicSlice: StateCreator<
VideoStore,
[['zustand/devtools', never]],
[],
GenerationTopicAction
> = (set, get) => ({
createGenerationTopic: async (prompts: string[]) => {
export class GenerationTopicActionImpl {
readonly #get: () => VideoStore;
readonly #set: Setter;
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
void _api;
this.#set = set;
this.#get = get;
}
createGenerationTopic = async (prompts: string[]): Promise<string> => {
if (!prompts || prompts.length === 0) {
throw new Error('Prompts cannot be empty when creating a generation topic');
}
const { internal_createGenerationTopic, summaryGenerationTopicTitle } = get();
const { internal_createGenerationTopic, summaryGenerationTopicTitle } = this.#get();
const topicId = await internal_createGenerationTopic();
summaryGenerationTopicTitle(topicId, prompts);
return topicId;
},
};
internal_createGenerationTopic: async () => {
internal_createGenerationTopic = async (): Promise<string> => {
const tmpId = Date.now().toString();
get().internal_dispatchGenerationTopic(
this.#get().internal_dispatchGenerationTopic(
{ type: 'addTopic', value: { id: tmpId, title: '' } },
'internal_createGenerationTopic',
);
get().internal_updateGenerationTopicLoading(tmpId, true);
this.#get().internal_updateGenerationTopicLoading(tmpId, true);
const topicId = await generationTopicService.createTopic('video');
get().internal_updateGenerationTopicLoading(tmpId, false);
this.#get().internal_updateGenerationTopicLoading(tmpId, false);
get().internal_updateGenerationTopicLoading(topicId, true);
await get().refreshGenerationTopics();
get().internal_updateGenerationTopicLoading(topicId, false);
this.#get().internal_updateGenerationTopicLoading(topicId, true);
await this.#get().refreshGenerationTopics();
this.#get().internal_updateGenerationTopicLoading(topicId, false);
return topicId;
},
};
internal_dispatchGenerationTopic: (payload, action) => {
const nextTopics = generationTopicReducer(get().generationTopics, payload);
internal_dispatchGenerationTopic = (payload: GenerationTopicDispatch, action?: any): void => {
const nextTopics = generationTopicReducer(this.#get().generationTopics, payload);
if (isEqual(nextTopics, get().generationTopics)) return;
if (isEqual(nextTopics, this.#get().generationTopics)) return;
set(
this.#set(
{ generationTopics: nextTopics },
false,
action ?? n(`dispatchGenerationTopic/${payload.type}`),
);
},
};
internal_removeGenerationTopic: async (id: string) => {
get().internal_updateGenerationTopicLoading(id, true);
internal_removeGenerationTopic = async (id: string): Promise<void> => {
this.#get().internal_updateGenerationTopicLoading(id, true);
try {
await generationTopicService.deleteTopic(id);
await get().refreshGenerationTopics();
await this.#get().refreshGenerationTopics();
} finally {
get().internal_updateGenerationTopicLoading(id, false);
this.#get().internal_updateGenerationTopicLoading(id, false);
}
},
};
internal_updateGenerationTopic: async (id, data) => {
get().internal_dispatchGenerationTopic({ id, type: 'updateTopic', value: data });
internal_updateGenerationTopic = async (id: string, data: UpdateTopicValue): Promise<void> => {
this.#get().internal_dispatchGenerationTopic({ id, type: 'updateTopic', value: data });
get().internal_updateGenerationTopicLoading(id, true);
this.#get().internal_updateGenerationTopicLoading(id, true);
await generationTopicService.updateTopic(id, data);
await get().refreshGenerationTopics();
get().internal_updateGenerationTopicLoading(id, false);
},
await this.#get().refreshGenerationTopics();
this.#get().internal_updateGenerationTopicLoading(id, false);
};
internal_updateGenerationTopicCover: async (topicId: string, coverUrl: string) => {
internal_updateGenerationTopicCover = async (
topicId: string,
coverUrl: string,
): Promise<void> => {
const {
internal_dispatchGenerationTopic,
internal_updateGenerationTopicLoading,
refreshGenerationTopics,
} = get();
} = this.#get();
internal_dispatchGenerationTopic(
{ id: topicId, type: 'updateTopic', value: { coverUrl } },
@ -135,10 +128,10 @@ export const createGenerationTopicSlice: StateCreator<
} finally {
internal_updateGenerationTopicLoading(topicId, false);
}
},
};
internal_updateGenerationTopicLoading: (id, loading) => {
set(
internal_updateGenerationTopicLoading = (id: string, loading: boolean): void => {
this.#set(
(state) => {
if (loading) return { loadingGenerationTopicIds: [...state.loadingGenerationTopicIds, id] };
@ -149,31 +142,31 @@ export const createGenerationTopicSlice: StateCreator<
false,
n('updateGenerationTopicLoading'),
);
},
};
internal_updateGenerationTopicTitleInSummary: (id, title) => {
get().internal_dispatchGenerationTopic(
internal_updateGenerationTopicTitleInSummary = (id: string, title: string): void => {
this.#get().internal_dispatchGenerationTopic(
{ id, type: 'updateTopic', value: { title } },
'updateGenerationTopicTitleInSummary',
);
},
};
openNewGenerationTopic: () => {
set({ activeGenerationTopicId: null }, false, n('openNewGenerationTopic'));
},
openNewGenerationTopic = (): void => {
this.#set({ activeGenerationTopicId: null }, false, n('openNewGenerationTopic'));
};
refreshGenerationTopics: async () => {
refreshGenerationTopics = async (): Promise<void> => {
await mutate([FETCH_GENERATION_TOPICS_KEY]);
},
};
removeGenerationTopic: async (id: string) => {
removeGenerationTopic = async (id: string): Promise<void> => {
const {
internal_removeGenerationTopic,
generationTopics,
activeGenerationTopicId,
switchGenerationTopic,
openNewGenerationTopic,
} = get();
} = this.#get();
const isRemovingActiveTopic = activeGenerationTopicId === id;
let topicIndexToRemove = -1;
@ -185,7 +178,7 @@ export const createGenerationTopicSlice: StateCreator<
await internal_removeGenerationTopic(id);
if (isRemovingActiveTopic) {
const newTopics = get().generationTopics;
const newTopics = this.#get().generationTopics;
if (newTopics.length > 0) {
const newActiveIndex = Math.min(topicIndexToRemove, newTopics.length - 1);
@ -200,14 +193,14 @@ export const createGenerationTopicSlice: StateCreator<
openNewGenerationTopic();
}
}
},
};
summaryGenerationTopicTitle: async (topicId: string, prompts: string[]) => {
const topic = generationTopicSelectors.getGenerationTopicById(topicId)(get());
summaryGenerationTopicTitle = async (topicId: string, prompts: string[]): Promise<string> => {
const topic = generationTopicSelectors.getGenerationTopicById(topicId)(this.#get());
if (!topic) throw new Error(`Topic ${topicId} not found`);
const { internal_updateGenerationTopicTitleInSummary, internal_updateGenerationTopicLoading } =
get();
this.#get();
internal_updateGenerationTopicLoading(topicId, true);
internal_updateGenerationTopicTitleInSummary(topicId, LOADING_FLAT);
@ -233,10 +226,10 @@ export const createGenerationTopicSlice: StateCreator<
onError: async () => {
const fallbackTitle = generateFallbackTitle();
internal_updateGenerationTopicTitleInSummary(topicId, fallbackTitle);
await get().internal_updateGenerationTopic(topicId, { title: fallbackTitle });
await this.#get().internal_updateGenerationTopic(topicId, { title: fallbackTitle });
},
onFinish: async (text) => {
await get().internal_updateGenerationTopic(topicId, { title: text });
await this.#get().internal_updateGenerationTopic(topicId, { title: text });
},
onLoadingChange: (loading) => {
internal_updateGenerationTopicLoading(topicId, loading);
@ -260,29 +253,34 @@ export const createGenerationTopicSlice: StateCreator<
});
return output;
},
};
switchGenerationTopic: (topicId: string) => {
if (get().activeGenerationTopicId === topicId) return;
switchGenerationTopic = (topicId: string): void => {
if (this.#get().activeGenerationTopicId === topicId) return;
set({ activeGenerationTopicId: topicId }, false, n('switchGenerationTopic'));
},
this.#set({ activeGenerationTopicId: topicId }, false, n('switchGenerationTopic'));
};
updateGenerationTopicCover: async (topicId: string, coverUrl: string) => {
const { internal_updateGenerationTopicCover } = get();
updateGenerationTopicCover = async (topicId: string, coverUrl: string): Promise<void> => {
const { internal_updateGenerationTopicCover } = this.#get();
await internal_updateGenerationTopicCover(topicId, coverUrl);
},
};
useFetchGenerationTopics: (enabled) =>
useFetchGenerationTopics = (enabled: boolean): SWRResponse<ImageGenerationTopic[]> =>
useClientDataSWR<ImageGenerationTopic[]>(
enabled ? [FETCH_GENERATION_TOPICS_KEY] : null,
() => generationTopicService.getAllGenerationTopics('video'),
{
onSuccess: (data) => {
if (isEqual(data, get().generationTopics)) return;
set({ generationTopics: data }, false, n('useFetchGenerationTopics'));
if (isEqual(data, this.#get().generationTopics)) return;
this.#set({ generationTopics: data }, false, n('useFetchGenerationTopics'));
},
suspense: true,
},
),
});
);
}
export type GenerationTopicAction = Pick<
GenerationTopicActionImpl,
keyof GenerationTopicActionImpl
>;

View file

@ -5,6 +5,7 @@ import { type StateCreator } from 'zustand/vanilla';
import { createDevtools } from '../middleware/createDevtools';
import { expose } from '../middleware/expose';
import { flattenActions } from '../utils/flattenActions';
import { initialState, type VideoStoreState } from './initialState';
import { createCreateVideoSlice, type CreateVideoAction } from './slices/createVideo/action';
import {
@ -22,20 +23,23 @@ import {
// =============== aggregate createStoreFn ============ //
export interface VideoStore
extends
GenerationConfigAction,
GenerationTopicAction,
GenerationBatchAction,
CreateVideoAction,
VideoStoreState {}
type VideoStoreAction = GenerationConfigAction &
GenerationTopicAction &
GenerationBatchAction &
CreateVideoAction;
const createStore: StateCreator<VideoStore, [['zustand/devtools', never]]> = (...parameters) => ({
export interface VideoStore extends VideoStoreAction, VideoStoreState {}
const createStore: StateCreator<VideoStore, [['zustand/devtools', never]]> = (
...parameters: Parameters<StateCreator<VideoStore, [['zustand/devtools', never]]>>
) => ({
...initialState,
...createGenerationConfigSlice(...parameters),
...createGenerationTopicSlice(...parameters),
...createGenerationBatchSlice(...parameters),
...createCreateVideoSlice(...parameters),
...flattenActions<VideoStoreAction>([
createGenerationConfigSlice(...parameters),
createGenerationTopicSlice(...parameters),
createGenerationBatchSlice(...parameters),
createCreateVideoSlice(...parameters),
]),
});
// =============== implement useStore ============ //