mirror of
https://github.com/lobehub/lobehub
synced 2026-04-21 09:37:28 +00:00
♻️ refactor(store): class-based Zustand actions with flattenActions (#13383)
♻️ refactor(store): migrate slices to class actions with flattenActions
- Video store: generationConfig/Topic/Batch/createVideo as *ActionImpl; aggregate with flattenActions
- Eval store: benchmark/dataset/run/testCase as classes; top-level flattenActions
- Tool agentSkills: AgentSkillsActionImpl + Pick typing
- groupProfile: flattenActions around ActionImpl instead of spreading instance
- agentGroup: wrap chatGroupAction with flattenActions for consistent aggregation
Made-with: Cursor
This commit is contained in:
parent
6402656ec7
commit
491aba4dbd
13 changed files with 568 additions and 584 deletions
|
|
@ -4,6 +4,7 @@ import { type StateCreator } from 'zustand/vanilla';
|
|||
|
||||
import { createDevtools } from '../middleware/createDevtools';
|
||||
import { expose } from '../middleware/expose';
|
||||
import { flattenActions } from '../utils/flattenActions';
|
||||
import { type ChatGroupAction } from './action';
|
||||
import { chatGroupAction } from './action';
|
||||
import { type ChatGroupState } from './initialState';
|
||||
|
|
@ -15,7 +16,7 @@ const createStore: StateCreator<ChatGroupStore, [['zustand/devtools', never]]> =
|
|||
...params: Parameters<StateCreator<ChatGroupStore, [['zustand/devtools', never]]>>
|
||||
) => ({
|
||||
...initialChatGroupState,
|
||||
...chatGroupAction(...params),
|
||||
...flattenActions<ChatGroupAction>([chatGroupAction(...params)]),
|
||||
});
|
||||
|
||||
const devtools = createDevtools('agentGroup');
|
||||
|
|
|
|||
|
|
@ -1,161 +1,147 @@
|
|||
import isEqual from 'fast-deep-equal';
|
||||
import { type SWRResponse } from 'swr';
|
||||
import { type StateCreator } from 'zustand/vanilla';
|
||||
import { type SWRResponse } from 'swr';
|
||||
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { agentEvalService } from '@/services/agentEval';
|
||||
import { type EvalStore } from '@/store/eval/store';
|
||||
import { type EvalStore } from '@/store/eval/store';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
|
||||
import { type BenchmarkDetailDispatch,benchmarkDetailReducer } from './reducer';
|
||||
import { type BenchmarkDetailDispatch, benchmarkDetailReducer } from './reducer';
|
||||
|
||||
const FETCH_BENCHMARKS_KEY = 'FETCH_BENCHMARKS';
|
||||
const FETCH_BENCHMARK_DETAIL_KEY = 'FETCH_BENCHMARK_DETAIL';
|
||||
|
||||
export interface BenchmarkAction {
|
||||
createBenchmark: (params: {
|
||||
type Setter = StoreSetter<EvalStore>;
|
||||
|
||||
export const createBenchmarkSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
|
||||
new BenchmarkActionImpl(set, get, _api);
|
||||
|
||||
export class BenchmarkActionImpl {
|
||||
readonly #get: () => EvalStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
createBenchmark = async (params: {
|
||||
description?: string;
|
||||
identifier: string;
|
||||
metadata?: Record<string, unknown>;
|
||||
name: string;
|
||||
rubrics?: any[];
|
||||
tags?: string[];
|
||||
}) => Promise<any>;
|
||||
deleteBenchmark: (id: string) => Promise<void>;
|
||||
// Internal methods
|
||||
internal_dispatchBenchmarkDetail: (payload: BenchmarkDetailDispatch) => void;
|
||||
internal_updateBenchmarkDetailLoading: (id: string, loading: boolean) => void;
|
||||
refreshBenchmarkDetail: (id: string) => Promise<void>;
|
||||
refreshBenchmarks: () => Promise<void>;
|
||||
updateBenchmark: (params: {
|
||||
}): Promise<any> => {
|
||||
this.#set({ isCreatingBenchmark: true }, false, 'createBenchmark/start');
|
||||
try {
|
||||
const result = await agentEvalService.createBenchmark({
|
||||
description: params.description,
|
||||
identifier: params.identifier,
|
||||
metadata: params.metadata,
|
||||
name: params.name,
|
||||
rubrics: params.rubrics ?? [],
|
||||
tags: params.tags,
|
||||
});
|
||||
await this.#get().refreshBenchmarks();
|
||||
return result;
|
||||
} finally {
|
||||
this.#set({ isCreatingBenchmark: false }, false, 'createBenchmark/end');
|
||||
}
|
||||
};
|
||||
|
||||
deleteBenchmark = async (id: string): Promise<void> => {
|
||||
this.#set({ isDeletingBenchmark: true }, false, 'deleteBenchmark/start');
|
||||
try {
|
||||
await agentEvalService.deleteBenchmark(id);
|
||||
await this.#get().refreshBenchmarks();
|
||||
} finally {
|
||||
this.#set({ isDeletingBenchmark: false }, false, 'deleteBenchmark/end');
|
||||
}
|
||||
};
|
||||
|
||||
refreshBenchmarkDetail = async (id: string): Promise<void> => {
|
||||
await mutate([FETCH_BENCHMARK_DETAIL_KEY, id]);
|
||||
};
|
||||
|
||||
refreshBenchmarks = async (): Promise<void> => {
|
||||
await mutate(FETCH_BENCHMARKS_KEY);
|
||||
};
|
||||
|
||||
updateBenchmark = async (params: {
|
||||
description?: string;
|
||||
id: string;
|
||||
identifier: string;
|
||||
metadata?: Record<string, unknown>;
|
||||
name: string;
|
||||
tags?: string[];
|
||||
}) => Promise<void>;
|
||||
|
||||
useFetchBenchmarkDetail: (id?: string) => SWRResponse;
|
||||
useFetchBenchmarks: () => SWRResponse;
|
||||
}
|
||||
|
||||
export const createBenchmarkSlice: StateCreator<
|
||||
EvalStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
BenchmarkAction
|
||||
> = (set, get) => ({
|
||||
createBenchmark: async (params) => {
|
||||
set({ isCreatingBenchmark: true }, false, 'createBenchmark/start');
|
||||
try {
|
||||
const result = await agentEvalService.createBenchmark({
|
||||
identifier: params.identifier,
|
||||
name: params.name,
|
||||
description: params.description,
|
||||
metadata: params.metadata,
|
||||
rubrics: params.rubrics ?? [],
|
||||
tags: params.tags,
|
||||
});
|
||||
await get().refreshBenchmarks();
|
||||
return result;
|
||||
} finally {
|
||||
set({ isCreatingBenchmark: false }, false, 'createBenchmark/end');
|
||||
}
|
||||
},
|
||||
|
||||
deleteBenchmark: async (id) => {
|
||||
set({ isDeletingBenchmark: true }, false, 'deleteBenchmark/start');
|
||||
try {
|
||||
await agentEvalService.deleteBenchmark(id);
|
||||
await get().refreshBenchmarks();
|
||||
} finally {
|
||||
set({ isDeletingBenchmark: false }, false, 'deleteBenchmark/end');
|
||||
}
|
||||
},
|
||||
|
||||
refreshBenchmarkDetail: async (id) => {
|
||||
await mutate([FETCH_BENCHMARK_DETAIL_KEY, id]);
|
||||
},
|
||||
|
||||
refreshBenchmarks: async () => {
|
||||
await mutate(FETCH_BENCHMARKS_KEY);
|
||||
},
|
||||
|
||||
updateBenchmark: async (params) => {
|
||||
}): Promise<void> => {
|
||||
const { id } = params;
|
||||
|
||||
// 1. Optimistic update
|
||||
get().internal_dispatchBenchmarkDetail({
|
||||
type: 'updateBenchmarkDetail',
|
||||
this.#get().internal_dispatchBenchmarkDetail({
|
||||
id,
|
||||
type: 'updateBenchmarkDetail',
|
||||
value: params,
|
||||
});
|
||||
|
||||
// 2. Set loading
|
||||
get().internal_updateBenchmarkDetailLoading(id, true);
|
||||
this.#get().internal_updateBenchmarkDetailLoading(id, true);
|
||||
|
||||
try {
|
||||
// 3. Call service
|
||||
await agentEvalService.updateBenchmark({
|
||||
description: params.description,
|
||||
id: params.id,
|
||||
identifier: params.identifier,
|
||||
name: params.name,
|
||||
description: params.description,
|
||||
metadata: params.metadata,
|
||||
name: params.name,
|
||||
tags: params.tags,
|
||||
});
|
||||
|
||||
// 4. Refresh from server
|
||||
await get().refreshBenchmarks();
|
||||
await get().refreshBenchmarkDetail(id);
|
||||
await this.#get().refreshBenchmarks();
|
||||
await this.#get().refreshBenchmarkDetail(id);
|
||||
} finally {
|
||||
get().internal_updateBenchmarkDetailLoading(id, false);
|
||||
this.#get().internal_updateBenchmarkDetailLoading(id, false);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
useFetchBenchmarkDetail: (id) => {
|
||||
return useClientDataSWR(
|
||||
useFetchBenchmarkDetail = (id?: string): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
id ? [FETCH_BENCHMARK_DETAIL_KEY, id] : null,
|
||||
() => agentEvalService.getBenchmark(id!),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
get().internal_dispatchBenchmarkDetail({
|
||||
type: 'setBenchmarkDetail',
|
||||
this.#get().internal_dispatchBenchmarkDetail({
|
||||
id: id!,
|
||||
type: 'setBenchmarkDetail',
|
||||
value: data,
|
||||
});
|
||||
get().internal_updateBenchmarkDetailLoading(id!, false);
|
||||
this.#get().internal_updateBenchmarkDetailLoading(id!, false);
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
useFetchBenchmarks: () => {
|
||||
return useClientDataSWR(FETCH_BENCHMARKS_KEY, () => agentEvalService.listBenchmarks(), {
|
||||
useFetchBenchmarks = (): SWRResponse =>
|
||||
useClientDataSWR(FETCH_BENCHMARKS_KEY, () => agentEvalService.listBenchmarks(), {
|
||||
onSuccess: (data: any) => {
|
||||
set(
|
||||
this.#set(
|
||||
{ benchmarkList: data, benchmarkListInit: true, isLoadingBenchmarkList: false },
|
||||
false,
|
||||
'useFetchBenchmarks/success',
|
||||
);
|
||||
},
|
||||
});
|
||||
},
|
||||
|
||||
// Internal - Dispatch to reducer
|
||||
internal_dispatchBenchmarkDetail: (payload) => {
|
||||
const currentMap = get().benchmarkDetailMap;
|
||||
internal_dispatchBenchmarkDetail = (payload: BenchmarkDetailDispatch): void => {
|
||||
const currentMap = this.#get().benchmarkDetailMap;
|
||||
const nextMap = benchmarkDetailReducer(currentMap, payload);
|
||||
|
||||
// No need to update if map is the same
|
||||
if (isEqual(nextMap, currentMap)) return;
|
||||
|
||||
set({ benchmarkDetailMap: nextMap }, false, `dispatchBenchmarkDetail/${payload.type}`);
|
||||
},
|
||||
this.#set({ benchmarkDetailMap: nextMap }, false, `dispatchBenchmarkDetail/${payload.type}`);
|
||||
};
|
||||
|
||||
// Internal - Update loading state for specific detail
|
||||
internal_updateBenchmarkDetailLoading: (id, loading) => {
|
||||
set(
|
||||
internal_updateBenchmarkDetailLoading = (id: string, loading: boolean): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
if (loading) {
|
||||
return { loadingBenchmarkDetailIds: [...state.loadingBenchmarkDetailIds, id] };
|
||||
|
|
@ -167,5 +153,7 @@ export const createBenchmarkSlice: StateCreator<
|
|||
false,
|
||||
'updateBenchmarkDetailLoading',
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export type BenchmarkAction = Pick<BenchmarkActionImpl, keyof BenchmarkActionImpl>;
|
||||
|
|
|
|||
|
|
@ -1,65 +1,62 @@
|
|||
import isEqual from 'fast-deep-equal';
|
||||
import { type SWRResponse } from 'swr';
|
||||
import { type StateCreator } from 'zustand/vanilla';
|
||||
import { type SWRResponse } from 'swr';
|
||||
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { agentEvalService } from '@/services/agentEval';
|
||||
import { type EvalStore } from '@/store/eval/store';
|
||||
import { type EvalStore } from '@/store/eval/store';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
|
||||
import { type DatasetDetailDispatch,datasetDetailReducer } from './reducer';
|
||||
import { type DatasetDetailDispatch, datasetDetailReducer } from './reducer';
|
||||
|
||||
const FETCH_DATASETS_KEY = 'FETCH_DATASETS';
|
||||
const FETCH_DATASET_DETAIL_KEY = 'FETCH_DATASET_DETAIL';
|
||||
|
||||
export interface DatasetAction {
|
||||
// Internal methods
|
||||
internal_dispatchDatasetDetail: (payload: DatasetDetailDispatch) => void;
|
||||
internal_updateDatasetDetailLoading: (id: string, loading: boolean) => void;
|
||||
refreshDatasetDetail: (id: string) => Promise<void>;
|
||||
refreshDatasets: (benchmarkId: string) => Promise<void>;
|
||||
type Setter = StoreSetter<EvalStore>;
|
||||
|
||||
useFetchDatasetDetail: (id?: string) => SWRResponse;
|
||||
useFetchDatasets: (benchmarkId?: string) => SWRResponse;
|
||||
}
|
||||
export const createDatasetSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
|
||||
new DatasetActionImpl(set, get, _api);
|
||||
|
||||
export const createDatasetSlice: StateCreator<
|
||||
EvalStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
DatasetAction
|
||||
> = (set, get) => ({
|
||||
refreshDatasetDetail: async (id) => {
|
||||
export class DatasetActionImpl {
|
||||
readonly #get: () => EvalStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
refreshDatasetDetail = async (id: string): Promise<void> => {
|
||||
await mutate([FETCH_DATASET_DETAIL_KEY, id]);
|
||||
},
|
||||
};
|
||||
|
||||
refreshDatasets: async (benchmarkId) => {
|
||||
refreshDatasets = async (benchmarkId: string): Promise<void> => {
|
||||
await mutate([FETCH_DATASETS_KEY, benchmarkId]);
|
||||
},
|
||||
};
|
||||
|
||||
useFetchDatasetDetail: (id) => {
|
||||
return useClientDataSWR(
|
||||
useFetchDatasetDetail = (id?: string): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
id ? [FETCH_DATASET_DETAIL_KEY, id] : null,
|
||||
() => agentEvalService.getDataset(id!),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
get().internal_dispatchDatasetDetail({
|
||||
type: 'setDatasetDetail',
|
||||
this.#get().internal_dispatchDatasetDetail({
|
||||
id: id!,
|
||||
type: 'setDatasetDetail',
|
||||
value: data,
|
||||
});
|
||||
get().internal_updateDatasetDetailLoading(id!, false);
|
||||
this.#get().internal_updateDatasetDetailLoading(id!, false);
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
useFetchDatasets: (benchmarkId) => {
|
||||
return useClientDataSWR(
|
||||
useFetchDatasets = (benchmarkId?: string): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
benchmarkId ? [FETCH_DATASETS_KEY, benchmarkId] : null,
|
||||
() => agentEvalService.listDatasets(benchmarkId!),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
datasetList: data,
|
||||
isLoadingDatasets: false,
|
||||
|
|
@ -70,22 +67,18 @@ export const createDatasetSlice: StateCreator<
|
|||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
// Internal - Dispatch to reducer
|
||||
internal_dispatchDatasetDetail: (payload) => {
|
||||
const currentMap = get().datasetDetailMap;
|
||||
internal_dispatchDatasetDetail = (payload: DatasetDetailDispatch): void => {
|
||||
const currentMap = this.#get().datasetDetailMap;
|
||||
const nextMap = datasetDetailReducer(currentMap, payload);
|
||||
|
||||
// No need to update if map is the same
|
||||
if (isEqual(nextMap, currentMap)) return;
|
||||
|
||||
set({ datasetDetailMap: nextMap }, false, `dispatchDatasetDetail/${payload.type}`);
|
||||
},
|
||||
this.#set({ datasetDetailMap: nextMap }, false, `dispatchDatasetDetail/${payload.type}`);
|
||||
};
|
||||
|
||||
// Internal - Update loading state for specific detail
|
||||
internal_updateDatasetDetailLoading: (id, loading) => {
|
||||
set(
|
||||
internal_updateDatasetDetailLoading = (id: string, loading: boolean): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
if (loading) {
|
||||
return { loadingDatasetDetailIds: [...state.loadingDatasetDetailIds, id] };
|
||||
|
|
@ -97,5 +90,7 @@ export const createDatasetSlice: StateCreator<
|
|||
false,
|
||||
'updateDatasetDetailLoading',
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export type DatasetAction = Pick<DatasetActionImpl, keyof DatasetActionImpl>;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import type { EvalRunInputConfig } from '@lobechat/types';
|
||||
import isEqual from 'fast-deep-equal';
|
||||
import type { SWRResponse } from 'swr';
|
||||
import type { StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { agentEvalService } from '@/services/agentEval';
|
||||
import type { EvalStore } from '@/store/eval/store';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
|
||||
import { type RunDetailDispatch, runDetailReducer } from './reducer';
|
||||
|
||||
|
|
@ -14,76 +14,59 @@ const FETCH_DATASET_RUNS_KEY = 'FETCH_EVAL_DATASET_RUNS';
|
|||
const FETCH_RUN_DETAIL_KEY = 'FETCH_EVAL_RUN_DETAIL';
|
||||
const FETCH_RUN_RESULTS_KEY = 'FETCH_EVAL_RUN_RESULTS';
|
||||
|
||||
export interface RunAction {
|
||||
abortRun: (id: string) => Promise<void>;
|
||||
createRun: (params: {
|
||||
type Setter = StoreSetter<EvalStore>;
|
||||
|
||||
export const createRunSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
|
||||
new RunActionImpl(set, get, _api);
|
||||
|
||||
export class RunActionImpl {
|
||||
readonly #get: () => EvalStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
abortRun = async (id: string): Promise<void> => {
|
||||
await agentEvalService.abortRun(id);
|
||||
await this.#get().refreshRunDetail(id);
|
||||
};
|
||||
|
||||
createRun = async (params: {
|
||||
config?: EvalRunInputConfig;
|
||||
datasetId: string;
|
||||
name?: string;
|
||||
targetAgentId?: string;
|
||||
}) => Promise<any>;
|
||||
deleteRun: (id: string) => Promise<void>;
|
||||
internal_dispatchRunDetail: (payload: RunDetailDispatch) => void;
|
||||
internal_updateRunDetailLoading: (id: string, loading: boolean) => void;
|
||||
internal_updateRunResultLoading: (id: string, loading: boolean) => void;
|
||||
refreshDatasetRuns: (datasetId: string) => Promise<void>;
|
||||
refreshRunDetail: (id: string) => Promise<void>;
|
||||
refreshRuns: (benchmarkId?: string) => Promise<void>;
|
||||
retryRunCase: (runId: string, testCaseId: string) => Promise<void>;
|
||||
retryRunErrors: (id: string) => Promise<void>;
|
||||
startRun: (id: string, force?: boolean) => Promise<void>;
|
||||
updateRun: (params: {
|
||||
config?: EvalRunInputConfig;
|
||||
datasetId?: string;
|
||||
id: string;
|
||||
name?: string;
|
||||
targetAgentId?: string | null;
|
||||
}) => Promise<any>;
|
||||
useFetchDatasetRuns: (datasetId?: string) => SWRResponse;
|
||||
useFetchRunDetail: (id: string, config?: { refreshInterval?: number }) => SWRResponse;
|
||||
useFetchRunResults: (id: string, config?: { refreshInterval?: number }) => SWRResponse;
|
||||
useFetchRuns: (benchmarkId?: string) => SWRResponse;
|
||||
}
|
||||
|
||||
export const createRunSlice: StateCreator<
|
||||
EvalStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
RunAction
|
||||
> = (set, get) => ({
|
||||
abortRun: async (id) => {
|
||||
await agentEvalService.abortRun(id);
|
||||
await get().refreshRunDetail(id);
|
||||
},
|
||||
|
||||
createRun: async (params) => {
|
||||
set({ isCreatingRun: true }, false, 'createRun/start');
|
||||
}): Promise<any> => {
|
||||
this.#set({ isCreatingRun: true }, false, 'createRun/start');
|
||||
try {
|
||||
const result = await agentEvalService.createRun(params);
|
||||
await get().refreshRuns();
|
||||
await this.#get().refreshRuns();
|
||||
return result;
|
||||
} finally {
|
||||
set({ isCreatingRun: false }, false, 'createRun/end');
|
||||
this.#set({ isCreatingRun: false }, false, 'createRun/end');
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
deleteRun: async (id) => {
|
||||
deleteRun = async (id: string): Promise<void> => {
|
||||
await agentEvalService.deleteRun(id);
|
||||
get().internal_dispatchRunDetail({ id, type: 'deleteRunDetail' });
|
||||
await get().refreshRuns();
|
||||
},
|
||||
this.#get().internal_dispatchRunDetail({ id, type: 'deleteRunDetail' });
|
||||
await this.#get().refreshRuns();
|
||||
};
|
||||
|
||||
internal_dispatchRunDetail: (payload) => {
|
||||
const currentMap = get().runDetailMap;
|
||||
internal_dispatchRunDetail = (payload: RunDetailDispatch): void => {
|
||||
const currentMap = this.#get().runDetailMap;
|
||||
const nextMap = runDetailReducer(currentMap, payload);
|
||||
|
||||
if (isEqual(nextMap, currentMap)) return;
|
||||
|
||||
set({ runDetailMap: nextMap }, false, `dispatchRunDetail/${payload.type}`);
|
||||
},
|
||||
this.#set({ runDetailMap: nextMap }, false, `dispatchRunDetail/${payload.type}`);
|
||||
};
|
||||
|
||||
internal_updateRunDetailLoading: (id, loading) => {
|
||||
set(
|
||||
internal_updateRunDetailLoading = (id: string, loading: boolean): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
if (loading) {
|
||||
return { loadingRunDetailIds: [...state.loadingRunDetailIds, id] };
|
||||
|
|
@ -95,10 +78,10 @@ export const createRunSlice: StateCreator<
|
|||
false,
|
||||
'updateRunDetailLoading',
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
internal_updateRunResultLoading: (id, loading) => {
|
||||
set(
|
||||
internal_updateRunResultLoading = (id: string, loading: boolean): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
if (loading) {
|
||||
return { loadingRunResultIds: [...state.loadingRunResultIds, id] };
|
||||
|
|
@ -110,92 +93,95 @@ export const createRunSlice: StateCreator<
|
|||
false,
|
||||
'updateRunResultLoading',
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
refreshDatasetRuns: async (datasetId) => {
|
||||
refreshDatasetRuns = async (datasetId: string): Promise<void> => {
|
||||
await mutate([FETCH_DATASET_RUNS_KEY, datasetId]);
|
||||
},
|
||||
};
|
||||
|
||||
refreshRunDetail: async (id) => {
|
||||
refreshRunDetail = async (id: string): Promise<void> => {
|
||||
await mutate([FETCH_RUN_DETAIL_KEY, id]);
|
||||
},
|
||||
};
|
||||
|
||||
refreshRuns: async (benchmarkId) => {
|
||||
refreshRuns = async (benchmarkId?: string): Promise<void> => {
|
||||
if (benchmarkId) {
|
||||
await mutate([FETCH_RUNS_KEY, benchmarkId]);
|
||||
} else {
|
||||
// Revalidate all benchmark-level run list entries
|
||||
await mutate((key) => Array.isArray(key) && key[0] === FETCH_RUNS_KEY);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
retryRunCase: async (runId, testCaseId) => {
|
||||
retryRunCase = async (runId: string, testCaseId: string): Promise<void> => {
|
||||
await agentEvalService.retryRunCase(runId, testCaseId);
|
||||
await get().refreshRunDetail(runId);
|
||||
},
|
||||
await this.#get().refreshRunDetail(runId);
|
||||
};
|
||||
|
||||
retryRunErrors: async (id) => {
|
||||
retryRunErrors = async (id: string): Promise<void> => {
|
||||
await agentEvalService.retryRunErrors(id);
|
||||
await get().refreshRunDetail(id);
|
||||
},
|
||||
await this.#get().refreshRunDetail(id);
|
||||
};
|
||||
|
||||
startRun: async (id, force) => {
|
||||
startRun = async (id: string, force?: boolean): Promise<void> => {
|
||||
await agentEvalService.startRun(id, force);
|
||||
await get().refreshRunDetail(id);
|
||||
},
|
||||
await this.#get().refreshRunDetail(id);
|
||||
};
|
||||
|
||||
updateRun: async (params) => {
|
||||
updateRun = async (params: {
|
||||
config?: EvalRunInputConfig;
|
||||
datasetId?: string;
|
||||
id: string;
|
||||
name?: string;
|
||||
targetAgentId?: string | null;
|
||||
}): Promise<any> => {
|
||||
const result = await agentEvalService.updateRun(params);
|
||||
await get().refreshRunDetail(params.id);
|
||||
await get().refreshRuns();
|
||||
await this.#get().refreshRunDetail(params.id);
|
||||
await this.#get().refreshRuns();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
useFetchRunDetail: (id, config) => {
|
||||
return useClientDataSWR(
|
||||
useFetchRunDetail = (id: string, config?: { refreshInterval?: number }): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
id ? [FETCH_RUN_DETAIL_KEY, id] : null,
|
||||
() => agentEvalService.getRunDetails(id),
|
||||
{
|
||||
...config,
|
||||
onSuccess: (data: any) => {
|
||||
get().internal_dispatchRunDetail({
|
||||
this.#get().internal_dispatchRunDetail({
|
||||
id,
|
||||
type: 'setRunDetail',
|
||||
value: data,
|
||||
});
|
||||
get().internal_updateRunDetailLoading(id, false);
|
||||
this.#get().internal_updateRunDetailLoading(id, false);
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
useFetchRunResults: (id, config) => {
|
||||
return useClientDataSWR(
|
||||
useFetchRunResults = (id: string, config?: { refreshInterval?: number }): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
id ? [FETCH_RUN_RESULTS_KEY, id] : null,
|
||||
() => agentEvalService.getRunResults(id),
|
||||
{
|
||||
...config,
|
||||
onSuccess: (data: any) => {
|
||||
set(
|
||||
this.#set(
|
||||
(state) => ({
|
||||
runResultsMap: { ...state.runResultsMap, [id]: data },
|
||||
}),
|
||||
false,
|
||||
'useFetchRunResults/success',
|
||||
);
|
||||
get().internal_updateRunResultLoading(id, false);
|
||||
this.#get().internal_updateRunResultLoading(id, false);
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
useFetchDatasetRuns: (datasetId) => {
|
||||
return useClientDataSWR(
|
||||
useFetchDatasetRuns = (datasetId?: string): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
datasetId ? [FETCH_DATASET_RUNS_KEY, datasetId] : null,
|
||||
() => agentEvalService.listRuns({ datasetId: datasetId! }),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
set(
|
||||
this.#set(
|
||||
(state) => ({
|
||||
datasetRunListMap: { ...state.datasetRunListMap, [datasetId!]: data.data },
|
||||
}),
|
||||
|
|
@ -205,17 +191,17 @@ export const createRunSlice: StateCreator<
|
|||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
|
||||
useFetchRuns: (benchmarkId) => {
|
||||
return useClientDataSWR(
|
||||
useFetchRuns = (benchmarkId?: string): SWRResponse =>
|
||||
useClientDataSWR(
|
||||
benchmarkId ? [FETCH_RUNS_KEY, benchmarkId] : null,
|
||||
() => agentEvalService.listRuns({ benchmarkId: benchmarkId! }),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
set({ isLoadingRuns: false, runList: data.data }, false, 'useFetchRuns/success');
|
||||
this.#set({ isLoadingRuns: false, runList: data.data }, false, 'useFetchRuns/success');
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export type RunAction = Pick<RunActionImpl, keyof RunActionImpl>;
|
||||
|
|
|
|||
|
|
@ -1,54 +1,50 @@
|
|||
import type { SWRResponse } from 'swr';
|
||||
import type { StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { agentEvalService } from '@/services/agentEval';
|
||||
import type { EvalStore } from '@/store/eval/store';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
|
||||
const FETCH_TEST_CASES_KEY = 'FETCH_TEST_CASES';
|
||||
|
||||
export interface TestCaseAction {
|
||||
getTestCasesByDatasetId: (datasetId: string) => any[];
|
||||
getTestCasesTotalByDatasetId: (datasetId: string) => number;
|
||||
isLoadingTestCases: (datasetId: string) => boolean;
|
||||
refreshTestCases: (datasetId: string) => Promise<void>;
|
||||
useFetchTestCases: (params: {
|
||||
type Setter = StoreSetter<EvalStore>;
|
||||
|
||||
export const createTestCaseSlice = (set: Setter, get: () => EvalStore, _api?: unknown) =>
|
||||
new TestCaseActionImpl(set, get, _api);
|
||||
|
||||
export class TestCaseActionImpl {
|
||||
readonly #get: () => EvalStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => EvalStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
getTestCasesByDatasetId = (datasetId: string): any[] => {
|
||||
return this.#get().testCasesCache[datasetId]?.data || [];
|
||||
};
|
||||
|
||||
getTestCasesTotalByDatasetId = (datasetId: string): number => {
|
||||
return this.#get().testCasesCache[datasetId]?.total || 0;
|
||||
};
|
||||
|
||||
isLoadingTestCases = (datasetId: string): boolean => {
|
||||
return this.#get().loadingTestCaseIds.includes(datasetId);
|
||||
};
|
||||
|
||||
refreshTestCases = async (datasetId: string): Promise<void> => {
|
||||
await mutate(
|
||||
(key) => Array.isArray(key) && key[0] === FETCH_TEST_CASES_KEY && key[1] === datasetId,
|
||||
);
|
||||
};
|
||||
|
||||
useFetchTestCases = (params: {
|
||||
datasetId: string;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
}) => SWRResponse;
|
||||
}
|
||||
|
||||
export const createTestCaseSlice: StateCreator<
|
||||
EvalStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
TestCaseAction
|
||||
> = (set, get) => ({
|
||||
// Get test cases for a specific dataset from cache
|
||||
getTestCasesByDatasetId: (datasetId) => {
|
||||
return get().testCasesCache[datasetId]?.data || [];
|
||||
},
|
||||
|
||||
// Get total count for a specific dataset from cache
|
||||
getTestCasesTotalByDatasetId: (datasetId) => {
|
||||
return get().testCasesCache[datasetId]?.total || 0;
|
||||
},
|
||||
|
||||
// Check if test cases are currently loading for a dataset
|
||||
isLoadingTestCases: (datasetId) => {
|
||||
return get().loadingTestCaseIds.includes(datasetId);
|
||||
},
|
||||
|
||||
refreshTestCases: async (datasetId) => {
|
||||
// Mutate all SWR keys that start with [FETCH_TEST_CASES_KEY, datasetId]
|
||||
await mutate(
|
||||
(key) =>
|
||||
Array.isArray(key) && key[0] === FETCH_TEST_CASES_KEY && key[1] === datasetId,
|
||||
);
|
||||
},
|
||||
|
||||
useFetchTestCases: (params) => {
|
||||
}): SWRResponse => {
|
||||
const { datasetId, limit = 10, offset = 0 } = params;
|
||||
|
||||
return useClientDataSWR(
|
||||
|
|
@ -56,7 +52,7 @@ export const createTestCaseSlice: StateCreator<
|
|||
() => agentEvalService.listTestCases({ datasetId, limit, offset }),
|
||||
{
|
||||
onSuccess: (data: any) => {
|
||||
set(
|
||||
this.#set(
|
||||
(state) => ({
|
||||
loadingTestCaseIds: state.loadingTestCaseIds.filter((id) => id !== datasetId),
|
||||
testCasesCache: {
|
||||
|
|
@ -74,5 +70,7 @@ export const createTestCaseSlice: StateCreator<
|
|||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export type TestCaseAction = Pick<TestCaseActionImpl, keyof TestCaseActionImpl>;
|
||||
|
|
|
|||
|
|
@ -4,24 +4,27 @@ import type { StateCreator } from 'zustand/vanilla';
|
|||
|
||||
import { createDevtools } from '../middleware/createDevtools';
|
||||
import { expose } from '../middleware/expose';
|
||||
import { flattenActions } from '../utils/flattenActions';
|
||||
import { type EvalStoreState, initialState } from './initialState';
|
||||
import { type BenchmarkAction, createBenchmarkSlice } from './slices/benchmark/action';
|
||||
import { createDatasetSlice, type DatasetAction } from './slices/dataset/action';
|
||||
import { createRunSlice, type RunAction } from './slices/run/action';
|
||||
import { createTestCaseSlice, type TestCaseAction } from './slices/testCase/action';
|
||||
|
||||
export type EvalStore = EvalStoreState &
|
||||
BenchmarkAction &
|
||||
DatasetAction &
|
||||
RunAction &
|
||||
TestCaseAction;
|
||||
type EvalStoreAction = BenchmarkAction & DatasetAction & RunAction & TestCaseAction;
|
||||
|
||||
const createStore: StateCreator<EvalStore, [['zustand/devtools', never]]> = (set, get, store) => ({
|
||||
export type EvalStore = EvalStoreState & EvalStoreAction;
|
||||
|
||||
const createStore: StateCreator<EvalStore, [['zustand/devtools', never]]> = (
|
||||
...parameters: Parameters<StateCreator<EvalStore, [['zustand/devtools', never]]>>
|
||||
) => ({
|
||||
...initialState,
|
||||
...createBenchmarkSlice(set, get, store),
|
||||
...createDatasetSlice(set, get, store),
|
||||
...createRunSlice(set, get, store),
|
||||
...createTestCaseSlice(set, get, store),
|
||||
...flattenActions<EvalStoreAction>([
|
||||
createBenchmarkSlice(...parameters),
|
||||
createDatasetSlice(...parameters),
|
||||
createRunSlice(...parameters),
|
||||
createTestCaseSlice(...parameters),
|
||||
]),
|
||||
});
|
||||
|
||||
const devtools = createDevtools('eval');
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import { type StateCreator } from 'zustand';
|
|||
|
||||
import { EDITOR_DEBOUNCE_TIME, EDITOR_MAX_WAIT } from '@/const/index';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
import { flattenActions } from '@/store/utils/flattenActions';
|
||||
|
||||
import { type SaveState, type SaveStatus, type State } from './initialState';
|
||||
import { initialState } from './initialState';
|
||||
|
|
@ -179,7 +180,7 @@ export class ActionImpl {
|
|||
};
|
||||
}
|
||||
|
||||
export const store: StateCreator<Store> = (set, get, _api) => ({
|
||||
export const store: StateCreator<Store> = (...parameters: Parameters<StateCreator<Store>>) => ({
|
||||
...initialState,
|
||||
...new ActionImpl(set, get, _api),
|
||||
...flattenActions<Action>([new ActionImpl(...parameters)]),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ import {
|
|||
} from '@lobechat/types';
|
||||
import { produce } from 'immer';
|
||||
import useSWR, { mutate, type SWRResponse } from 'swr';
|
||||
import { type StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { useClientDataSWR } from '@/libs/swr';
|
||||
import { agentSkillService } from '@/services/skill';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
import { setNamespace } from '@/utils/storeDebug';
|
||||
|
||||
import { type ToolStore } from '../../store';
|
||||
|
|
@ -27,36 +27,31 @@ export interface AgentSkillDetailData {
|
|||
skillDetail?: SkillItem;
|
||||
}
|
||||
|
||||
export interface AgentSkillsAction {
|
||||
createAgentSkill: (params: CreateSkillInput) => Promise<SkillItem | undefined>;
|
||||
deleteAgentSkill: (id: string) => Promise<void>;
|
||||
fetchAgentSkillDetail: (id: string) => Promise<SkillItem | undefined>;
|
||||
importAgentSkillFromGitHub: (params: ImportGitHubInput) => Promise<SkillImportResult | undefined>;
|
||||
importAgentSkillFromUrl: (params: ImportUrlInput) => Promise<SkillImportResult | undefined>;
|
||||
importAgentSkillFromZip: (params: ImportZipInput) => Promise<SkillImportResult | undefined>;
|
||||
refreshAgentSkills: () => Promise<void>;
|
||||
updateAgentSkill: (params: UpdateSkillInput) => Promise<SkillItem | undefined>;
|
||||
useFetchAgentSkillDetail: (skillId?: string) => SWRResponse<AgentSkillDetailData>;
|
||||
useFetchAgentSkills: (enabled: boolean) => SWRResponse<SkillListItem[]>;
|
||||
}
|
||||
type Setter = StoreSetter<ToolStore>;
|
||||
|
||||
export const createAgentSkillsSlice: StateCreator<
|
||||
ToolStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
AgentSkillsAction
|
||||
> = (set, get) => ({
|
||||
createAgentSkill: async (params) => {
|
||||
export const createAgentSkillsSlice = (set: Setter, get: () => ToolStore, _api?: unknown) =>
|
||||
new AgentSkillsActionImpl(set, get, _api);
|
||||
|
||||
export class AgentSkillsActionImpl {
|
||||
readonly #get: () => ToolStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => ToolStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
createAgentSkill = async (params: CreateSkillInput): Promise<SkillItem | undefined> => {
|
||||
const result = await agentSkillService.createSkill(params);
|
||||
await get().refreshAgentSkills();
|
||||
await this.#get().refreshAgentSkills();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
deleteAgentSkill: async (id) => {
|
||||
deleteAgentSkill = async (id: string): Promise<void> => {
|
||||
await agentSkillService.deleteSkill(id);
|
||||
|
||||
// Clean up detail map
|
||||
set(
|
||||
this.#set(
|
||||
produce((draft: AgentSkillsState) => {
|
||||
delete draft.agentSkillDetailMap[id];
|
||||
}),
|
||||
|
|
@ -64,19 +59,18 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
n('deleteAgentSkill'),
|
||||
);
|
||||
|
||||
// Clear SWR cache
|
||||
await mutate(['fetchAgentSkillDetail', id].join('-'), undefined, { revalidate: false });
|
||||
|
||||
await get().refreshAgentSkills();
|
||||
},
|
||||
await this.#get().refreshAgentSkills();
|
||||
};
|
||||
|
||||
fetchAgentSkillDetail: async (id) => {
|
||||
const cached = get().agentSkillDetailMap[id];
|
||||
fetchAgentSkillDetail = async (id: string): Promise<SkillItem | undefined> => {
|
||||
const cached = this.#get().agentSkillDetailMap[id];
|
||||
if (cached) return cached;
|
||||
|
||||
const detail = await agentSkillService.getById(id);
|
||||
if (detail) {
|
||||
set(
|
||||
this.#set(
|
||||
produce((draft: AgentSkillsState) => {
|
||||
draft.agentSkillDetailMap[id] = detail;
|
||||
}),
|
||||
|
|
@ -85,37 +79,42 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
);
|
||||
}
|
||||
return detail;
|
||||
},
|
||||
};
|
||||
|
||||
importAgentSkillFromGitHub: async (params) => {
|
||||
importAgentSkillFromGitHub = async (
|
||||
params: ImportGitHubInput,
|
||||
): Promise<SkillImportResult | undefined> => {
|
||||
const result = await agentSkillService.importFromGitHub(params);
|
||||
await get().refreshAgentSkills();
|
||||
await this.#get().refreshAgentSkills();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
importAgentSkillFromUrl: async (params) => {
|
||||
importAgentSkillFromUrl = async (
|
||||
params: ImportUrlInput,
|
||||
): Promise<SkillImportResult | undefined> => {
|
||||
const result = await agentSkillService.importFromUrl(params);
|
||||
await get().refreshAgentSkills();
|
||||
await this.#get().refreshAgentSkills();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
importAgentSkillFromZip: async (params) => {
|
||||
importAgentSkillFromZip = async (
|
||||
params: ImportZipInput,
|
||||
): Promise<SkillImportResult | undefined> => {
|
||||
const result = await agentSkillService.importFromZip(params);
|
||||
await get().refreshAgentSkills();
|
||||
await this.#get().refreshAgentSkills();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
refreshAgentSkills: async () => {
|
||||
refreshAgentSkills = async (): Promise<void> => {
|
||||
const { data } = await agentSkillService.list();
|
||||
set({ agentSkills: data }, false, n('refreshAgentSkills'));
|
||||
},
|
||||
this.#set({ agentSkills: data }, false, n('refreshAgentSkills'));
|
||||
};
|
||||
|
||||
updateAgentSkill: async (params) => {
|
||||
updateAgentSkill = async (params: UpdateSkillInput): Promise<SkillItem | undefined> => {
|
||||
const result = await agentSkillService.updateSkill(params);
|
||||
|
||||
// Update detail map if cached
|
||||
if (result) {
|
||||
set(
|
||||
this.#set(
|
||||
produce((draft: AgentSkillsState) => {
|
||||
draft.agentSkillDetailMap[params.id] = result;
|
||||
}),
|
||||
|
|
@ -124,14 +123,13 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
);
|
||||
}
|
||||
|
||||
// Clear SWR cache so next open refetches instead of showing stale data
|
||||
await mutate(['fetchAgentSkillDetail', params.id].join('-'), undefined, { revalidate: false });
|
||||
|
||||
await get().refreshAgentSkills();
|
||||
await this.#get().refreshAgentSkills();
|
||||
return result;
|
||||
},
|
||||
};
|
||||
|
||||
useFetchAgentSkillDetail: (skillId) =>
|
||||
useFetchAgentSkillDetail = (skillId?: string): SWRResponse<AgentSkillDetailData> =>
|
||||
useClientDataSWR<AgentSkillDetailData>(
|
||||
skillId ? ['fetchAgentSkillDetail', skillId].join('-') : null,
|
||||
async () => {
|
||||
|
|
@ -141,7 +139,7 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
]);
|
||||
|
||||
if (detail) {
|
||||
set(
|
||||
this.#set(
|
||||
produce((draft: AgentSkillsState) => {
|
||||
draft.agentSkillDetailMap[skillId!] = detail;
|
||||
}),
|
||||
|
|
@ -153,9 +151,9 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
return { resourceTree, skillDetail: detail };
|
||||
},
|
||||
{ revalidateOnFocus: false },
|
||||
),
|
||||
);
|
||||
|
||||
useFetchAgentSkills: (enabled) =>
|
||||
useFetchAgentSkills = (enabled: boolean): SWRResponse<SkillListItem[]> =>
|
||||
useSWR<SkillListItem[]>(
|
||||
enabled ? 'fetchAgentSkills' : null,
|
||||
async () => {
|
||||
|
|
@ -165,9 +163,11 @@ export const createAgentSkillsSlice: StateCreator<
|
|||
{
|
||||
fallbackData: [],
|
||||
onSuccess: (data) => {
|
||||
set({ agentSkills: data }, false, n('useFetchAgentSkills'));
|
||||
this.#set({ agentSkills: data }, false, n('useFetchAgentSkills'));
|
||||
},
|
||||
revalidateOnFocus: false,
|
||||
},
|
||||
),
|
||||
});
|
||||
);
|
||||
}
|
||||
|
||||
export type AgentSkillsAction = Pick<AgentSkillsActionImpl, keyof AgentSkillsActionImpl>;
|
||||
|
|
|
|||
|
|
@ -1,35 +1,35 @@
|
|||
import { ENABLE_BUSINESS_FEATURES } from '@lobechat/business-const';
|
||||
import { t } from 'i18next';
|
||||
import { type StateCreator } from 'zustand';
|
||||
|
||||
import { markUserValidAction } from '@/business/client/markUserValidAction';
|
||||
import { message } from '@/components/AntdStaticMethods';
|
||||
import { videoService } from '@/services/video';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
|
||||
import { type VideoStore } from '../../store';
|
||||
import { generationBatchSelectors } from '../generationBatch/selectors';
|
||||
import { videoGenerationConfigSelectors } from '../generationConfig/selectors';
|
||||
import { generationTopicSelectors } from '../generationTopic';
|
||||
|
||||
// ====== action interface ====== //
|
||||
type Setter = StoreSetter<VideoStore>;
|
||||
|
||||
export interface CreateVideoAction {
|
||||
createVideo: () => Promise<void>;
|
||||
recreateVideo: (generationBatchId: string) => Promise<void>;
|
||||
}
|
||||
export const createCreateVideoSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
|
||||
new CreateVideoActionImpl(set, get, _api);
|
||||
|
||||
// ====== action implementation ====== //
|
||||
export class CreateVideoActionImpl {
|
||||
readonly #get: () => VideoStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
export const createCreateVideoSlice: StateCreator<
|
||||
VideoStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
CreateVideoAction
|
||||
> = (set, get) => ({
|
||||
async createVideo() {
|
||||
set({ isCreating: true }, false, 'createVideo/startCreateVideo');
|
||||
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
const store = get();
|
||||
createVideo = async (): Promise<void> => {
|
||||
this.#set({ isCreating: true }, false, 'createVideo/startCreateVideo');
|
||||
|
||||
const store = this.#get();
|
||||
const parameters = videoGenerationConfigSelectors.parameters(store);
|
||||
const provider = videoGenerationConfigSelectors.provider(store);
|
||||
const model = videoGenerationConfigSelectors.model(store);
|
||||
|
|
@ -58,7 +58,7 @@ export const createCreateVideoSlice: StateCreator<
|
|||
content: t('generation.validation.endFrameRequiresStartFrame', { ns: 'video' }),
|
||||
duration: 3,
|
||||
});
|
||||
set({ isCreating: false }, false, 'createVideo/endCreateVideo');
|
||||
this.#set({ isCreating: false }, false, 'createVideo/endCreateVideo');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -84,7 +84,11 @@ export const createCreateVideoSlice: StateCreator<
|
|||
try {
|
||||
// 3. If it's a new topic, set the creating state after topic creation
|
||||
if (isNewTopic) {
|
||||
set({ isCreatingWithNewTopic: true }, false, 'createVideo/startCreateVideoWithNewTopic');
|
||||
this.#set(
|
||||
{ isCreatingWithNewTopic: true },
|
||||
false,
|
||||
'createVideo/startCreateVideoWithNewTopic',
|
||||
);
|
||||
}
|
||||
|
||||
if (ENABLE_BUSINESS_FEATURES) {
|
||||
|
|
@ -101,11 +105,11 @@ export const createCreateVideoSlice: StateCreator<
|
|||
|
||||
// 5. Refresh generation batches to show the new batch
|
||||
if (!isNewTopic) {
|
||||
await get().refreshGenerationBatches();
|
||||
await this.#get().refreshGenerationBatches();
|
||||
}
|
||||
|
||||
// 6. Clear the prompt input after successful video creation
|
||||
set(
|
||||
this.#set(
|
||||
(state) => ({
|
||||
parameters: { ...state.parameters, prompt: '' },
|
||||
}),
|
||||
|
|
@ -115,21 +119,21 @@ export const createCreateVideoSlice: StateCreator<
|
|||
} finally {
|
||||
// 7. Reset all creating states
|
||||
if (isNewTopic) {
|
||||
set(
|
||||
this.#set(
|
||||
{ isCreating: false, isCreatingWithNewTopic: false },
|
||||
false,
|
||||
'createVideo/endCreateVideoWithNewTopic',
|
||||
);
|
||||
} else {
|
||||
set({ isCreating: false }, false, 'createVideo/endCreateVideo');
|
||||
this.#set({ isCreating: false }, false, 'createVideo/endCreateVideo');
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
async recreateVideo(generationBatchId: string) {
|
||||
set({ isCreating: true }, false, 'recreateVideo/start');
|
||||
recreateVideo = async (generationBatchId: string): Promise<void> => {
|
||||
this.#set({ isCreating: true }, false, 'recreateVideo/start');
|
||||
|
||||
const store = get();
|
||||
const store = this.#get();
|
||||
const activeGenerationTopicId = generationTopicSelectors.activeGenerationTopicId(store);
|
||||
if (!activeGenerationTopicId) {
|
||||
throw new Error('No active generation topic');
|
||||
|
|
@ -150,7 +154,9 @@ export const createCreateVideoSlice: StateCreator<
|
|||
|
||||
await store.refreshGenerationBatches();
|
||||
} finally {
|
||||
set({ isCreating: false }, false, 'recreateVideo/end');
|
||||
this.#set({ isCreating: false }, false, 'recreateVideo/end');
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export type CreateVideoAction = Pick<CreateVideoActionImpl, keyof CreateVideoActionImpl>;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import { isEqual } from 'es-toolkit/compat';
|
||||
import { useRef } from 'react';
|
||||
import type { SWRResponse } from 'swr';
|
||||
import { type StateCreator } from 'zustand';
|
||||
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { type GetGenerationStatusResult } from '@/server/routers/lambda/generation';
|
||||
import { generationService } from '@/services/generation';
|
||||
import { generationBatchService } from '@/services/generationBatch';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
import { AsyncTaskStatus } from '@/types/asyncTask';
|
||||
import { type GenerationBatch } from '@/types/generation';
|
||||
import { setNamespace } from '@/utils/storeDebug';
|
||||
|
|
@ -21,44 +21,28 @@ const n = setNamespace('generationBatch');
|
|||
const SWR_USE_FETCH_GENERATION_BATCHES = 'SWR_USE_FETCH_VIDEO_GENERATION_BATCHES';
|
||||
const SWR_USE_CHECK_GENERATION_STATUS = 'SWR_USE_CHECK_VIDEO_GENERATION_STATUS';
|
||||
|
||||
// ====== action interface ====== //
|
||||
type Setter = StoreSetter<VideoStore>;
|
||||
|
||||
export interface GenerationBatchAction {
|
||||
internal_deleteGeneration: (generationId: string) => Promise<void>;
|
||||
internal_deleteGenerationBatch: (batchId: string, topicId: string) => Promise<void>;
|
||||
internal_dispatchGenerationBatch: (
|
||||
topicId: string,
|
||||
payload: GenerationBatchDispatch,
|
||||
action?: string,
|
||||
) => void;
|
||||
refreshGenerationBatches: () => Promise<void>;
|
||||
removeGeneration: (generationId: string) => Promise<void>;
|
||||
removeGenerationBatch: (batchId: string, topicId: string) => Promise<void>;
|
||||
setTopicBatchLoaded: (topicId: string) => void;
|
||||
useCheckGenerationStatus: (
|
||||
generationId: string,
|
||||
asyncTaskId: string,
|
||||
topicId: string,
|
||||
enable?: boolean,
|
||||
) => SWRResponse<GetGenerationStatusResult>;
|
||||
useFetchGenerationBatches: (topicId?: string | null) => SWRResponse<GenerationBatch[]>;
|
||||
}
|
||||
export const createGenerationBatchSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
|
||||
new GenerationBatchActionImpl(set, get, _api);
|
||||
|
||||
// ====== action implementation ====== //
|
||||
export class GenerationBatchActionImpl {
|
||||
readonly #get: () => VideoStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
export const createGenerationBatchSlice: StateCreator<
|
||||
VideoStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
GenerationBatchAction
|
||||
> = (set, get) => ({
|
||||
internal_deleteGeneration: async (generationId: string) => {
|
||||
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
internal_deleteGeneration = async (generationId: string): Promise<void> => {
|
||||
const { activeGenerationTopicId, refreshGenerationBatches, internal_dispatchGenerationBatch } =
|
||||
get();
|
||||
this.#get();
|
||||
|
||||
if (!activeGenerationTopicId) return;
|
||||
|
||||
const currentBatches = get().generationBatchesMap[activeGenerationTopicId] || [];
|
||||
const currentBatches = this.#get().generationBatchesMap[activeGenerationTopicId] || [];
|
||||
const targetBatch = currentBatches.find((batch) =>
|
||||
batch.generations.some((gen) => gen.id === generationId),
|
||||
);
|
||||
|
|
@ -74,10 +58,10 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
|
||||
await generationService.deleteGeneration(generationId);
|
||||
await refreshGenerationBatches();
|
||||
},
|
||||
};
|
||||
|
||||
internal_deleteGenerationBatch: async (batchId: string, topicId: string) => {
|
||||
const { internal_dispatchGenerationBatch, refreshGenerationBatches } = get();
|
||||
internal_deleteGenerationBatch = async (batchId: string, topicId: string): Promise<void> => {
|
||||
const { internal_dispatchGenerationBatch, refreshGenerationBatches } = this.#get();
|
||||
|
||||
// Optimistic update
|
||||
internal_dispatchGenerationBatch(
|
||||
|
|
@ -88,75 +72,84 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
|
||||
await generationBatchService.deleteGenerationBatch(batchId);
|
||||
await refreshGenerationBatches();
|
||||
},
|
||||
};
|
||||
|
||||
internal_dispatchGenerationBatch: (topicId, payload, action) => {
|
||||
const currentBatches = get().generationBatchesMap[topicId] || [];
|
||||
internal_dispatchGenerationBatch = (
|
||||
topicId: string,
|
||||
payload: GenerationBatchDispatch,
|
||||
action?: string,
|
||||
): void => {
|
||||
const currentBatches = this.#get().generationBatchesMap[topicId] || [];
|
||||
const nextBatches = generationBatchReducer(currentBatches, payload);
|
||||
|
||||
const nextMap = {
|
||||
...get().generationBatchesMap,
|
||||
...this.#get().generationBatchesMap,
|
||||
[topicId]: nextBatches,
|
||||
};
|
||||
|
||||
if (isEqual(nextMap, get().generationBatchesMap)) return;
|
||||
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
generationBatchesMap: nextMap,
|
||||
},
|
||||
false,
|
||||
action ?? n(`dispatchGenerationBatch/${payload.type}`),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
refreshGenerationBatches: async () => {
|
||||
const { activeGenerationTopicId } = get();
|
||||
refreshGenerationBatches = async (): Promise<void> => {
|
||||
const { activeGenerationTopicId } = this.#get();
|
||||
if (activeGenerationTopicId) {
|
||||
await mutate([SWR_USE_FETCH_GENERATION_BATCHES, activeGenerationTopicId]);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
removeGeneration: async (generationId: string) => {
|
||||
removeGeneration = async (generationId: string): Promise<void> => {
|
||||
const { internal_deleteGeneration, activeGenerationTopicId, internal_deleteGenerationBatch } =
|
||||
get();
|
||||
this.#get();
|
||||
|
||||
await internal_deleteGeneration(generationId);
|
||||
|
||||
// Video batch has only 1 generation, so delete the batch directly
|
||||
if (activeGenerationTopicId) {
|
||||
const updatedBatches = get().generationBatchesMap[activeGenerationTopicId] || [];
|
||||
const updatedBatches = this.#get().generationBatchesMap[activeGenerationTopicId] || [];
|
||||
const emptyBatches = updatedBatches.filter((batch) => batch.generations.length === 0);
|
||||
|
||||
for (const emptyBatch of emptyBatches) {
|
||||
await internal_deleteGenerationBatch(emptyBatch.id, activeGenerationTopicId);
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
removeGenerationBatch: async (batchId: string, topicId: string) => {
|
||||
const { internal_deleteGenerationBatch } = get();
|
||||
removeGenerationBatch = async (batchId: string, topicId: string): Promise<void> => {
|
||||
const { internal_deleteGenerationBatch } = this.#get();
|
||||
await internal_deleteGenerationBatch(batchId, topicId);
|
||||
},
|
||||
};
|
||||
|
||||
setTopicBatchLoaded: (topicId: string) => {
|
||||
setTopicBatchLoaded = (topicId: string): void => {
|
||||
const nextMap = {
|
||||
...get().generationBatchesMap,
|
||||
...this.#get().generationBatchesMap,
|
||||
[topicId]: [],
|
||||
};
|
||||
|
||||
if (isEqual(nextMap, get().generationBatchesMap)) return;
|
||||
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
generationBatchesMap: nextMap,
|
||||
},
|
||||
false,
|
||||
n('setTopicBatchLoaded'),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
useCheckGenerationStatus: (generationId, asyncTaskId, topicId, enable = true) => {
|
||||
useCheckGenerationStatus = (
|
||||
generationId: string,
|
||||
asyncTaskId: string,
|
||||
topicId: string,
|
||||
enable = true,
|
||||
): SWRResponse<GetGenerationStatusResult> => {
|
||||
const requestCountRef = useRef(0);
|
||||
const isErrorRef = useRef(false);
|
||||
|
||||
|
|
@ -178,7 +171,7 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
|
||||
isErrorRef.current = false;
|
||||
|
||||
const currentBatches = get().generationBatchesMap[topicId] || [];
|
||||
const currentBatches = this.#get().generationBatchesMap[topicId] || [];
|
||||
const targetBatch = currentBatches.find((batch) =>
|
||||
batch.generations.some((gen) => gen.id === generationId),
|
||||
);
|
||||
|
|
@ -190,7 +183,7 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
requestCountRef.current = 0;
|
||||
|
||||
if (data.generation) {
|
||||
get().internal_dispatchGenerationBatch(
|
||||
this.#get().internal_dispatchGenerationBatch(
|
||||
topicId,
|
||||
{
|
||||
batchId: targetBatch.id,
|
||||
|
|
@ -205,11 +198,12 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
|
||||
// Update topic cover if generation succeeds and has a thumbnail
|
||||
if (data.status === AsyncTaskStatus.Success && data.generation.asset?.thumbnailUrl) {
|
||||
const currentTopic =
|
||||
generationTopicSelectors.getGenerationTopicById(topicId)(get());
|
||||
const currentTopic = generationTopicSelectors.getGenerationTopicById(topicId)(
|
||||
this.#get(),
|
||||
);
|
||||
|
||||
if (currentTopic && !currentTopic.coverUrl) {
|
||||
await get().updateGenerationTopicCover(
|
||||
await this.#get().updateGenerationTopicCover(
|
||||
topicId,
|
||||
data.generation.asset.thumbnailUrl,
|
||||
);
|
||||
|
|
@ -217,7 +211,7 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
}
|
||||
}
|
||||
|
||||
await get().refreshGenerationBatches();
|
||||
await this.#get().refreshGenerationBatches();
|
||||
}
|
||||
},
|
||||
refreshInterval: (data: GetGenerationStatusResult | undefined) => {
|
||||
|
|
@ -244,9 +238,9 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
refreshWhenHidden: false,
|
||||
},
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
useFetchGenerationBatches: (topicId) =>
|
||||
useFetchGenerationBatches = (topicId?: string | null): SWRResponse<GenerationBatch[]> =>
|
||||
useClientDataSWR<GenerationBatch[]>(
|
||||
topicId ? [SWR_USE_FETCH_GENERATION_BATCHES, topicId] : null,
|
||||
async ([, topicId]: [string, string]) => {
|
||||
|
|
@ -255,13 +249,13 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
{
|
||||
onSuccess: (data) => {
|
||||
const nextMap = {
|
||||
...get().generationBatchesMap,
|
||||
...this.#get().generationBatchesMap,
|
||||
[topicId!]: data,
|
||||
};
|
||||
|
||||
if (isEqual(nextMap, get().generationBatchesMap)) return;
|
||||
if (isEqual(nextMap, this.#get().generationBatchesMap)) return;
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
generationBatchesMap: nextMap,
|
||||
},
|
||||
|
|
@ -270,5 +264,10 @@ export const createGenerationBatchSlice: StateCreator<
|
|||
);
|
||||
},
|
||||
},
|
||||
),
|
||||
});
|
||||
);
|
||||
}
|
||||
|
||||
export type GenerationBatchAction = Pick<
|
||||
GenerationBatchActionImpl,
|
||||
keyof GenerationBatchActionImpl
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -5,30 +5,15 @@ import {
|
|||
type RuntimeVideoGenParamsValue,
|
||||
type VideoModelParamsSchema,
|
||||
} from 'model-bank';
|
||||
import { type StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { authSelectors } from '@/store/user/selectors';
|
||||
|
||||
import type { VideoStore } from '../../store';
|
||||
|
||||
export interface GenerationConfigAction {
|
||||
initializeVideoConfig: (
|
||||
isLogin?: boolean,
|
||||
lastSelectedVideoModel?: string,
|
||||
lastSelectedVideoProvider?: string,
|
||||
) => void;
|
||||
|
||||
setModelAndProviderOnSelect: (model: string, provider: string) => void;
|
||||
|
||||
setParamOnInput: <K extends RuntimeVideoGenParamsKeys>(
|
||||
paramName: K,
|
||||
value: RuntimeVideoGenParamsValue,
|
||||
) => void;
|
||||
}
|
||||
|
||||
export function getVideoModelAndDefaults(model: string, provider: string) {
|
||||
const enabledVideoModelList = aiProviderSelectors.enabledVideoModelList(getAiInfraStoreState());
|
||||
|
||||
|
|
@ -54,13 +39,25 @@ export function getVideoModelAndDefaults(model: string, provider: string) {
|
|||
return { activeModel, defaultValues, parametersSchema };
|
||||
}
|
||||
|
||||
export const createGenerationConfigSlice: StateCreator<
|
||||
VideoStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
GenerationConfigAction
|
||||
> = (set) => ({
|
||||
initializeVideoConfig: (isLogin, lastSelectedVideoModel, lastSelectedVideoProvider) => {
|
||||
type Setter = StoreSetter<VideoStore>;
|
||||
|
||||
export const createGenerationConfigSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
|
||||
new GenerationConfigActionImpl(set, get, _api);
|
||||
|
||||
export class GenerationConfigActionImpl {
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, _get: () => VideoStore, _api?: unknown) {
|
||||
void _get;
|
||||
void _api;
|
||||
this.#set = set;
|
||||
}
|
||||
|
||||
initializeVideoConfig = (
|
||||
isLogin?: boolean,
|
||||
lastSelectedVideoModel?: string,
|
||||
lastSelectedVideoProvider?: string,
|
||||
): void => {
|
||||
if (isLogin && lastSelectedVideoModel && lastSelectedVideoProvider) {
|
||||
try {
|
||||
const { defaultValues, parametersSchema } = getVideoModelAndDefaults(
|
||||
|
|
@ -68,7 +65,7 @@ export const createGenerationConfigSlice: StateCreator<
|
|||
lastSelectedVideoProvider,
|
||||
);
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
isInit: true,
|
||||
model: lastSelectedVideoModel,
|
||||
|
|
@ -80,17 +77,17 @@ export const createGenerationConfigSlice: StateCreator<
|
|||
`initializeVideoConfig/${lastSelectedVideoModel}/${lastSelectedVideoProvider}`,
|
||||
);
|
||||
} catch {
|
||||
set({ isInit: true }, false, 'initializeVideoConfig/default');
|
||||
this.#set({ isInit: true }, false, 'initializeVideoConfig/default');
|
||||
}
|
||||
} else {
|
||||
set({ isInit: true }, false, 'initializeVideoConfig/default');
|
||||
this.#set({ isInit: true }, false, 'initializeVideoConfig/default');
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
setModelAndProviderOnSelect: (model, provider) => {
|
||||
setModelAndProviderOnSelect = (model: string, provider: string): void => {
|
||||
const { defaultValues, parametersSchema } = getVideoModelAndDefaults(model, provider);
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{
|
||||
model,
|
||||
parameters: defaultValues,
|
||||
|
|
@ -108,10 +105,13 @@ export const createGenerationConfigSlice: StateCreator<
|
|||
lastSelectedVideoProvider: provider,
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
setParamOnInput: (paramName, value) => {
|
||||
set(
|
||||
setParamOnInput = <K extends RuntimeVideoGenParamsKeys>(
|
||||
paramName: K,
|
||||
value: RuntimeVideoGenParamsValue,
|
||||
): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
const { parameters } = state;
|
||||
return { parameters: { ...parameters, [paramName]: value } };
|
||||
|
|
@ -119,5 +119,10 @@ export const createGenerationConfigSlice: StateCreator<
|
|||
false,
|
||||
`setParamOnInput/${paramName}`,
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export type GenerationConfigAction = Pick<
|
||||
GenerationConfigActionImpl,
|
||||
keyof GenerationConfigActionImpl
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import { chainSummaryGenerationTitle } from '@lobechat/prompts';
|
||||
import isEqual from 'fast-deep-equal';
|
||||
import type { SWRResponse } from 'swr';
|
||||
import { type StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { LOADING_FLAT } from '@/const/message';
|
||||
import { mutate, useClientDataSWR } from '@/libs/swr';
|
||||
import { type UpdateTopicValue } from '@/server/routers/lambda/generationTopic';
|
||||
import { chatService } from '@/services/chat';
|
||||
import { generationTopicService } from '@/services/generationTopic';
|
||||
import { type StoreSetter } from '@/store/types';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { systemAgentSelectors, userGeneralSettingsSelectors } from '@/store/user/selectors';
|
||||
import { type ImageGenerationTopic } from '@/types/generation';
|
||||
|
|
@ -22,104 +22,97 @@ const FETCH_GENERATION_TOPICS_KEY = 'fetchVideoGenerationTopics';
|
|||
|
||||
const n = setNamespace('videoGenerationTopic');
|
||||
|
||||
export interface GenerationTopicAction {
|
||||
createGenerationTopic: (prompts: string[]) => Promise<string>;
|
||||
internal_createGenerationTopic: () => Promise<string>;
|
||||
internal_dispatchGenerationTopic: (payload: GenerationTopicDispatch, action?: any) => void;
|
||||
internal_removeGenerationTopic: (id: string) => Promise<void>;
|
||||
internal_updateGenerationTopic: (id: string, data: UpdateTopicValue) => Promise<void>;
|
||||
internal_updateGenerationTopicCover: (topicId: string, coverUrl: string) => Promise<void>;
|
||||
internal_updateGenerationTopicLoading: (id: string, loading: boolean) => void;
|
||||
internal_updateGenerationTopicTitleInSummary: (id: string, title: string) => void;
|
||||
type Setter = StoreSetter<VideoStore>;
|
||||
|
||||
openNewGenerationTopic: () => void;
|
||||
refreshGenerationTopics: () => Promise<void>;
|
||||
removeGenerationTopic: (id: string) => Promise<void>;
|
||||
summaryGenerationTopicTitle: (topicId: string, prompts: string[]) => Promise<string>;
|
||||
switchGenerationTopic: (topicId: string) => void;
|
||||
updateGenerationTopicCover: (topicId: string, imageUrl: string) => Promise<void>;
|
||||
useFetchGenerationTopics: (enabled: boolean) => SWRResponse<ImageGenerationTopic[]>;
|
||||
}
|
||||
export const createGenerationTopicSlice = (set: Setter, get: () => VideoStore, _api?: unknown) =>
|
||||
new GenerationTopicActionImpl(set, get, _api);
|
||||
|
||||
export const createGenerationTopicSlice: StateCreator<
|
||||
VideoStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
GenerationTopicAction
|
||||
> = (set, get) => ({
|
||||
createGenerationTopic: async (prompts: string[]) => {
|
||||
export class GenerationTopicActionImpl {
|
||||
readonly #get: () => VideoStore;
|
||||
readonly #set: Setter;
|
||||
|
||||
constructor(set: Setter, get: () => VideoStore, _api?: unknown) {
|
||||
void _api;
|
||||
this.#set = set;
|
||||
this.#get = get;
|
||||
}
|
||||
|
||||
createGenerationTopic = async (prompts: string[]): Promise<string> => {
|
||||
if (!prompts || prompts.length === 0) {
|
||||
throw new Error('Prompts cannot be empty when creating a generation topic');
|
||||
}
|
||||
|
||||
const { internal_createGenerationTopic, summaryGenerationTopicTitle } = get();
|
||||
const { internal_createGenerationTopic, summaryGenerationTopicTitle } = this.#get();
|
||||
|
||||
const topicId = await internal_createGenerationTopic();
|
||||
|
||||
summaryGenerationTopicTitle(topicId, prompts);
|
||||
|
||||
return topicId;
|
||||
},
|
||||
};
|
||||
|
||||
internal_createGenerationTopic: async () => {
|
||||
internal_createGenerationTopic = async (): Promise<string> => {
|
||||
const tmpId = Date.now().toString();
|
||||
|
||||
get().internal_dispatchGenerationTopic(
|
||||
this.#get().internal_dispatchGenerationTopic(
|
||||
{ type: 'addTopic', value: { id: tmpId, title: '' } },
|
||||
'internal_createGenerationTopic',
|
||||
);
|
||||
|
||||
get().internal_updateGenerationTopicLoading(tmpId, true);
|
||||
this.#get().internal_updateGenerationTopicLoading(tmpId, true);
|
||||
|
||||
const topicId = await generationTopicService.createTopic('video');
|
||||
get().internal_updateGenerationTopicLoading(tmpId, false);
|
||||
this.#get().internal_updateGenerationTopicLoading(tmpId, false);
|
||||
|
||||
get().internal_updateGenerationTopicLoading(topicId, true);
|
||||
await get().refreshGenerationTopics();
|
||||
get().internal_updateGenerationTopicLoading(topicId, false);
|
||||
this.#get().internal_updateGenerationTopicLoading(topicId, true);
|
||||
await this.#get().refreshGenerationTopics();
|
||||
this.#get().internal_updateGenerationTopicLoading(topicId, false);
|
||||
|
||||
return topicId;
|
||||
},
|
||||
};
|
||||
|
||||
internal_dispatchGenerationTopic: (payload, action) => {
|
||||
const nextTopics = generationTopicReducer(get().generationTopics, payload);
|
||||
internal_dispatchGenerationTopic = (payload: GenerationTopicDispatch, action?: any): void => {
|
||||
const nextTopics = generationTopicReducer(this.#get().generationTopics, payload);
|
||||
|
||||
if (isEqual(nextTopics, get().generationTopics)) return;
|
||||
if (isEqual(nextTopics, this.#get().generationTopics)) return;
|
||||
|
||||
set(
|
||||
this.#set(
|
||||
{ generationTopics: nextTopics },
|
||||
false,
|
||||
action ?? n(`dispatchGenerationTopic/${payload.type}`),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
internal_removeGenerationTopic: async (id: string) => {
|
||||
get().internal_updateGenerationTopicLoading(id, true);
|
||||
internal_removeGenerationTopic = async (id: string): Promise<void> => {
|
||||
this.#get().internal_updateGenerationTopicLoading(id, true);
|
||||
try {
|
||||
await generationTopicService.deleteTopic(id);
|
||||
await get().refreshGenerationTopics();
|
||||
await this.#get().refreshGenerationTopics();
|
||||
} finally {
|
||||
get().internal_updateGenerationTopicLoading(id, false);
|
||||
this.#get().internal_updateGenerationTopicLoading(id, false);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
internal_updateGenerationTopic: async (id, data) => {
|
||||
get().internal_dispatchGenerationTopic({ id, type: 'updateTopic', value: data });
|
||||
internal_updateGenerationTopic = async (id: string, data: UpdateTopicValue): Promise<void> => {
|
||||
this.#get().internal_dispatchGenerationTopic({ id, type: 'updateTopic', value: data });
|
||||
|
||||
get().internal_updateGenerationTopicLoading(id, true);
|
||||
this.#get().internal_updateGenerationTopicLoading(id, true);
|
||||
|
||||
await generationTopicService.updateTopic(id, data);
|
||||
|
||||
await get().refreshGenerationTopics();
|
||||
get().internal_updateGenerationTopicLoading(id, false);
|
||||
},
|
||||
await this.#get().refreshGenerationTopics();
|
||||
this.#get().internal_updateGenerationTopicLoading(id, false);
|
||||
};
|
||||
|
||||
internal_updateGenerationTopicCover: async (topicId: string, coverUrl: string) => {
|
||||
internal_updateGenerationTopicCover = async (
|
||||
topicId: string,
|
||||
coverUrl: string,
|
||||
): Promise<void> => {
|
||||
const {
|
||||
internal_dispatchGenerationTopic,
|
||||
internal_updateGenerationTopicLoading,
|
||||
refreshGenerationTopics,
|
||||
} = get();
|
||||
} = this.#get();
|
||||
|
||||
internal_dispatchGenerationTopic(
|
||||
{ id: topicId, type: 'updateTopic', value: { coverUrl } },
|
||||
|
|
@ -135,10 +128,10 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
} finally {
|
||||
internal_updateGenerationTopicLoading(topicId, false);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
internal_updateGenerationTopicLoading: (id, loading) => {
|
||||
set(
|
||||
internal_updateGenerationTopicLoading = (id: string, loading: boolean): void => {
|
||||
this.#set(
|
||||
(state) => {
|
||||
if (loading) return { loadingGenerationTopicIds: [...state.loadingGenerationTopicIds, id] };
|
||||
|
||||
|
|
@ -149,31 +142,31 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
false,
|
||||
n('updateGenerationTopicLoading'),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
internal_updateGenerationTopicTitleInSummary: (id, title) => {
|
||||
get().internal_dispatchGenerationTopic(
|
||||
internal_updateGenerationTopicTitleInSummary = (id: string, title: string): void => {
|
||||
this.#get().internal_dispatchGenerationTopic(
|
||||
{ id, type: 'updateTopic', value: { title } },
|
||||
'updateGenerationTopicTitleInSummary',
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
openNewGenerationTopic: () => {
|
||||
set({ activeGenerationTopicId: null }, false, n('openNewGenerationTopic'));
|
||||
},
|
||||
openNewGenerationTopic = (): void => {
|
||||
this.#set({ activeGenerationTopicId: null }, false, n('openNewGenerationTopic'));
|
||||
};
|
||||
|
||||
refreshGenerationTopics: async () => {
|
||||
refreshGenerationTopics = async (): Promise<void> => {
|
||||
await mutate([FETCH_GENERATION_TOPICS_KEY]);
|
||||
},
|
||||
};
|
||||
|
||||
removeGenerationTopic: async (id: string) => {
|
||||
removeGenerationTopic = async (id: string): Promise<void> => {
|
||||
const {
|
||||
internal_removeGenerationTopic,
|
||||
generationTopics,
|
||||
activeGenerationTopicId,
|
||||
switchGenerationTopic,
|
||||
openNewGenerationTopic,
|
||||
} = get();
|
||||
} = this.#get();
|
||||
|
||||
const isRemovingActiveTopic = activeGenerationTopicId === id;
|
||||
let topicIndexToRemove = -1;
|
||||
|
|
@ -185,7 +178,7 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
await internal_removeGenerationTopic(id);
|
||||
|
||||
if (isRemovingActiveTopic) {
|
||||
const newTopics = get().generationTopics;
|
||||
const newTopics = this.#get().generationTopics;
|
||||
|
||||
if (newTopics.length > 0) {
|
||||
const newActiveIndex = Math.min(topicIndexToRemove, newTopics.length - 1);
|
||||
|
|
@ -200,14 +193,14 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
openNewGenerationTopic();
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
summaryGenerationTopicTitle: async (topicId: string, prompts: string[]) => {
|
||||
const topic = generationTopicSelectors.getGenerationTopicById(topicId)(get());
|
||||
summaryGenerationTopicTitle = async (topicId: string, prompts: string[]): Promise<string> => {
|
||||
const topic = generationTopicSelectors.getGenerationTopicById(topicId)(this.#get());
|
||||
if (!topic) throw new Error(`Topic ${topicId} not found`);
|
||||
|
||||
const { internal_updateGenerationTopicTitleInSummary, internal_updateGenerationTopicLoading } =
|
||||
get();
|
||||
this.#get();
|
||||
|
||||
internal_updateGenerationTopicLoading(topicId, true);
|
||||
internal_updateGenerationTopicTitleInSummary(topicId, LOADING_FLAT);
|
||||
|
|
@ -233,10 +226,10 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
onError: async () => {
|
||||
const fallbackTitle = generateFallbackTitle();
|
||||
internal_updateGenerationTopicTitleInSummary(topicId, fallbackTitle);
|
||||
await get().internal_updateGenerationTopic(topicId, { title: fallbackTitle });
|
||||
await this.#get().internal_updateGenerationTopic(topicId, { title: fallbackTitle });
|
||||
},
|
||||
onFinish: async (text) => {
|
||||
await get().internal_updateGenerationTopic(topicId, { title: text });
|
||||
await this.#get().internal_updateGenerationTopic(topicId, { title: text });
|
||||
},
|
||||
onLoadingChange: (loading) => {
|
||||
internal_updateGenerationTopicLoading(topicId, loading);
|
||||
|
|
@ -260,29 +253,34 @@ export const createGenerationTopicSlice: StateCreator<
|
|||
});
|
||||
|
||||
return output;
|
||||
},
|
||||
};
|
||||
|
||||
switchGenerationTopic: (topicId: string) => {
|
||||
if (get().activeGenerationTopicId === topicId) return;
|
||||
switchGenerationTopic = (topicId: string): void => {
|
||||
if (this.#get().activeGenerationTopicId === topicId) return;
|
||||
|
||||
set({ activeGenerationTopicId: topicId }, false, n('switchGenerationTopic'));
|
||||
},
|
||||
this.#set({ activeGenerationTopicId: topicId }, false, n('switchGenerationTopic'));
|
||||
};
|
||||
|
||||
updateGenerationTopicCover: async (topicId: string, coverUrl: string) => {
|
||||
const { internal_updateGenerationTopicCover } = get();
|
||||
updateGenerationTopicCover = async (topicId: string, coverUrl: string): Promise<void> => {
|
||||
const { internal_updateGenerationTopicCover } = this.#get();
|
||||
await internal_updateGenerationTopicCover(topicId, coverUrl);
|
||||
},
|
||||
};
|
||||
|
||||
useFetchGenerationTopics: (enabled) =>
|
||||
useFetchGenerationTopics = (enabled: boolean): SWRResponse<ImageGenerationTopic[]> =>
|
||||
useClientDataSWR<ImageGenerationTopic[]>(
|
||||
enabled ? [FETCH_GENERATION_TOPICS_KEY] : null,
|
||||
() => generationTopicService.getAllGenerationTopics('video'),
|
||||
{
|
||||
onSuccess: (data) => {
|
||||
if (isEqual(data, get().generationTopics)) return;
|
||||
set({ generationTopics: data }, false, n('useFetchGenerationTopics'));
|
||||
if (isEqual(data, this.#get().generationTopics)) return;
|
||||
this.#set({ generationTopics: data }, false, n('useFetchGenerationTopics'));
|
||||
},
|
||||
suspense: true,
|
||||
},
|
||||
),
|
||||
});
|
||||
);
|
||||
}
|
||||
|
||||
export type GenerationTopicAction = Pick<
|
||||
GenerationTopicActionImpl,
|
||||
keyof GenerationTopicActionImpl
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { type StateCreator } from 'zustand/vanilla';
|
|||
|
||||
import { createDevtools } from '../middleware/createDevtools';
|
||||
import { expose } from '../middleware/expose';
|
||||
import { flattenActions } from '../utils/flattenActions';
|
||||
import { initialState, type VideoStoreState } from './initialState';
|
||||
import { createCreateVideoSlice, type CreateVideoAction } from './slices/createVideo/action';
|
||||
import {
|
||||
|
|
@ -22,20 +23,23 @@ import {
|
|||
|
||||
// =============== aggregate createStoreFn ============ //
|
||||
|
||||
export interface VideoStore
|
||||
extends
|
||||
GenerationConfigAction,
|
||||
GenerationTopicAction,
|
||||
GenerationBatchAction,
|
||||
CreateVideoAction,
|
||||
VideoStoreState {}
|
||||
type VideoStoreAction = GenerationConfigAction &
|
||||
GenerationTopicAction &
|
||||
GenerationBatchAction &
|
||||
CreateVideoAction;
|
||||
|
||||
const createStore: StateCreator<VideoStore, [['zustand/devtools', never]]> = (...parameters) => ({
|
||||
export interface VideoStore extends VideoStoreAction, VideoStoreState {}
|
||||
|
||||
const createStore: StateCreator<VideoStore, [['zustand/devtools', never]]> = (
|
||||
...parameters: Parameters<StateCreator<VideoStore, [['zustand/devtools', never]]>>
|
||||
) => ({
|
||||
...initialState,
|
||||
...createGenerationConfigSlice(...parameters),
|
||||
...createGenerationTopicSlice(...parameters),
|
||||
...createGenerationBatchSlice(...parameters),
|
||||
...createCreateVideoSlice(...parameters),
|
||||
...flattenActions<VideoStoreAction>([
|
||||
createGenerationConfigSlice(...parameters),
|
||||
createGenerationTopicSlice(...parameters),
|
||||
createGenerationBatchSlice(...parameters),
|
||||
createCreateVideoSlice(...parameters),
|
||||
]),
|
||||
});
|
||||
|
||||
// =============== implement useStore ============ //
|
||||
|
|
|
|||
Loading…
Reference in a new issue