mirror of
https://github.com/lobehub/lobehub
synced 2026-04-21 17:47:27 +00:00
✨ feat: support server db mode with Postgres / Drizzle ORM / tRPC (#2556)
This commit is contained in:
parent
7789340ead
commit
b26afbff7f
80 changed files with 14644 additions and 233 deletions
39
.github/workflows/test.yml
vendored
39
.github/workflows/test.yml
vendored
|
|
@ -1,9 +1,23 @@
|
|||
name: Test CI
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:16
|
||||
env:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
|
||||
|
||||
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
|
|
@ -18,10 +32,27 @@ jobs:
|
|||
- name: Lint
|
||||
run: bun run lint
|
||||
|
||||
- name: Test and coverage
|
||||
run: bun run test:coverage
|
||||
- name: Test Server Coverage
|
||||
run: bun run test-server:coverage
|
||||
env:
|
||||
DATABASE_TEST_URL: postgresql://postgres:postgres@localhost:5432/postgres
|
||||
DATABASE_DRIVER: node
|
||||
NEXT_PUBLIC_SERVICE_MODE: server
|
||||
KEY_VAULTS_SECRET: LA7n9k3JdEcbSgml2sxfw+4TV1AzaaFU5+R176aQz4s=
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
- name: Upload Server coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }} # required
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage/server/lcov.info
|
||||
flags: server
|
||||
|
||||
- name: Test App Coverage
|
||||
run: bun run test-app:coverage
|
||||
|
||||
- name: Upload App Coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage/app/lcov.info
|
||||
flags: app
|
||||
|
|
|
|||
11
codecov.yml
Normal file
11
codecov.yml
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
coverage:
|
||||
status:
|
||||
project:
|
||||
default: off
|
||||
server:
|
||||
flags:
|
||||
- server
|
||||
app:
|
||||
flags:
|
||||
- app
|
||||
patch: off
|
||||
29
drizzle.config.ts
Normal file
29
drizzle.config.ts
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import * as dotenv from 'dotenv';
|
||||
import type { Config } from 'drizzle-kit';
|
||||
|
||||
// Read the .env file if it exists, or a file specified by the
|
||||
|
||||
// dotenv_config_path parameter that's passed to Node.js
|
||||
|
||||
dotenv.config();
|
||||
|
||||
let connectionString = process.env.DATABASE_URL;
|
||||
|
||||
if (process.env.NODE_ENV === 'test') {
|
||||
console.log('current ENV:', process.env.NODE_ENV);
|
||||
connectionString = process.env.DATABASE_TEST_URL;
|
||||
}
|
||||
|
||||
if (!connectionString)
|
||||
throw new Error('`DATABASE_URL` or `DATABASE_TEST_URL` not found in environment');
|
||||
|
||||
export default {
|
||||
dbCredentials: {
|
||||
url: connectionString,
|
||||
},
|
||||
dialect: 'postgresql',
|
||||
out: './src/database/server/migrations',
|
||||
|
||||
schema: './src/database/server/schemas/lobechat.ts',
|
||||
strict: true,
|
||||
} satisfies Config;
|
||||
|
|
@ -60,6 +60,9 @@ const nextConfig = {
|
|||
},
|
||||
});
|
||||
|
||||
// https://github.com/pinojs/pino/issues/688#issuecomment-637763276
|
||||
config.externals.push('pino-pretty');
|
||||
|
||||
return config;
|
||||
},
|
||||
};
|
||||
|
|
|
|||
26
package.json
26
package.json
|
|
@ -27,10 +27,17 @@
|
|||
"sideEffects": false,
|
||||
"scripts": {
|
||||
"build": "next build",
|
||||
"postbuild": "npm run build-sitemap",
|
||||
"postbuild": "npm run build-sitemap && npm run build-migrate-db",
|
||||
"build-migrate-db": "bun run db:migrate",
|
||||
"build-sitemap": "next-sitemap --config next-sitemap.config.mjs",
|
||||
"build:analyze": "ANALYZE=true next build",
|
||||
"build:docker": "DOCKER=true next build && npm run build-sitemap",
|
||||
"db:generate": "drizzle-kit generate -- dotenv_config_path='.env'",
|
||||
"db:migrate": "MIGRATION_DB=1 tsx scripts/migrateServerDB/index.ts",
|
||||
"db:push": "drizzle-kit push -- dotenv_config_path='.env'",
|
||||
"db:push-test": "NODE_ENV=test drizzle-kit push -- dotenv_config_path='.env'",
|
||||
"db:studio": "drizzle-kit studio",
|
||||
"db:z-pull": "drizzle-kit introspect -- dotenv_config_path='.env'",
|
||||
"dev": "next dev -p 3010",
|
||||
"dev:clerk-proxy": "ngrok http http://localhost:3011",
|
||||
"docs:i18n": "lobe-i18n md && npm run workflow:docs && npm run lint:mdx",
|
||||
|
|
@ -48,8 +55,11 @@
|
|||
"release": "semantic-release",
|
||||
"start": "next start",
|
||||
"stylelint": "stylelint \"src/**/*.{js,jsx,ts,tsx}\" --fix",
|
||||
"test": "vitest",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test": "npm run test-app && npm run test-server",
|
||||
"test-app": "vitest run --config vitest.config.ts",
|
||||
"test-app:coverage": "vitest run --config vitest.config.ts --coverage",
|
||||
"test-server": "vitest run --config vitest.server.config.ts",
|
||||
"test-server:coverage": "vitest run --config vitest.server.config.ts --coverage",
|
||||
"test:update": "vitest -u",
|
||||
"type-check": "tsc --noEmit",
|
||||
"workflow:docs": "tsx scripts/docsWorkflow/index.ts",
|
||||
|
|
@ -101,6 +111,7 @@
|
|||
"@lobehub/tts": "^1.24.1",
|
||||
"@lobehub/ui": "^1.141.2",
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@neondatabase/serverless": "^0.9.3",
|
||||
"@next/third-parties": "^14.2.3",
|
||||
"@sentry/nextjs": "^7.116.0",
|
||||
"@t3-oss/env-nextjs": "^0.10.1",
|
||||
|
|
@ -119,6 +130,8 @@
|
|||
"debug": "^4.3.4",
|
||||
"dexie": "^3.2.7",
|
||||
"diff": "^5.2.0",
|
||||
"drizzle-orm": "^0.30.10",
|
||||
"drizzle-zod": "^0.5.1",
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
"gpt-tokenizer": "^2.1.2",
|
||||
"i18next": "^23.11.5",
|
||||
|
|
@ -141,6 +154,7 @@
|
|||
"nuqs": "^1.17.4",
|
||||
"ollama": "^0.5.1",
|
||||
"openai": "^4.47.1",
|
||||
"pg": "^8.11.5",
|
||||
"pino": "^9.1.0",
|
||||
"polished": "^4.3.1",
|
||||
"posthog-js": "^1.135.2",
|
||||
|
|
@ -163,6 +177,7 @@
|
|||
"semver": "^7.6.2",
|
||||
"sharp": "^0.33.4",
|
||||
"superjson": "^2.2.1",
|
||||
"svix": "^1.24.0",
|
||||
"swr": "^2.2.5",
|
||||
"systemjs": "^6.15.1",
|
||||
"ts-md5": "^1.3.1",
|
||||
|
|
@ -171,6 +186,7 @@
|
|||
"use-merge-value": "^1.2.0",
|
||||
"utility-types": "^3.11.0",
|
||||
"uuid": "^10.0.0",
|
||||
"ws": "^8.17.0",
|
||||
"y-protocols": "^1.0.6",
|
||||
"y-webrtc": "^10.3.0",
|
||||
"yaml": "^2.4.2",
|
||||
|
|
@ -200,6 +216,7 @@
|
|||
"@types/lodash-es": "^4.17.12",
|
||||
"@types/node": "^20.12.12",
|
||||
"@types/numeral": "^2.0.5",
|
||||
"@types/pg": "^8.11.6",
|
||||
"@types/react": "^18.3.3",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/rtl-detect": "^1.0.3",
|
||||
|
|
@ -207,12 +224,15 @@
|
|||
"@types/systemjs": "^6.13.5",
|
||||
"@types/ua-parser-js": "^0.7.39",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@types/ws": "^8.5.10",
|
||||
"@umijs/lint": "^4.2.5",
|
||||
"@vitest/coverage-v8": "~1.2.2",
|
||||
"ajv-keywords": "^5.1.0",
|
||||
"commitlint": "^19.3.0",
|
||||
"consola": "^3.2.3",
|
||||
"dotenv": "^16.4.5",
|
||||
"dpdm": "^3.14.0",
|
||||
"drizzle-kit": "^0.21.1",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-plugin-mdx": "^2.3.4",
|
||||
"fake-indexeddb": "^6.0.0",
|
||||
|
|
|
|||
30
scripts/migrateServerDB/index.ts
Normal file
30
scripts/migrateServerDB/index.ts
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import * as dotenv from 'dotenv';
|
||||
import * as migrator from 'drizzle-orm/neon-serverless/migrator';
|
||||
import { join } from 'node:path';
|
||||
|
||||
import { serverDB } from '../../src/database/server/core/db';
|
||||
|
||||
// Read the `.env` file if it exists, or a file specified by the
|
||||
// dotenv_config_path parameter that's passed to Node.js
|
||||
dotenv.config();
|
||||
|
||||
const runMigrations = async () => {
|
||||
await migrator.migrate(serverDB, {
|
||||
migrationsFolder: join(__dirname, '../../src/database/server/migrations'),
|
||||
});
|
||||
console.log('✅ database migration pass.');
|
||||
// eslint-disable-next-line unicorn/no-process-exit
|
||||
process.exit(0);
|
||||
};
|
||||
|
||||
let connectionString = process.env.DATABASE_URL;
|
||||
|
||||
// only migrate database if the connection string is available
|
||||
if (connectionString) {
|
||||
// eslint-disable-next-line unicorn/prefer-top-level-await
|
||||
runMigrations().catch((err) => {
|
||||
console.error('❌ Database migrate failed:', err);
|
||||
// eslint-disable-next-line unicorn/no-process-exit
|
||||
process.exit(1);
|
||||
});
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { CellProps } from '@/components/Cell';
|
||||
import { enableAuth } from '@/const/auth';
|
||||
import { DISCORD, DOCUMENTS, FEEDBACK } from '@/const/url';
|
||||
import { isServerMode } from '@/const/version';
|
||||
import { usePWAInstall } from '@/hooks/usePWAInstall';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { authSelectors } from '@/store/user/slices/auth/selectors';
|
||||
|
|
@ -109,7 +110,7 @@ export const useCategory = () => {
|
|||
|
||||
/* ↑ cloud slot ↑ */
|
||||
...(canInstall ? pwa : []),
|
||||
...(isLogin ? data : []),
|
||||
...(isLogin && !isServerMode ? data : []),
|
||||
...helps,
|
||||
].filter(Boolean) as CellProps[];
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { ActionIcon, Icon } from '@lobehub/ui';
|
||||
import { App, Dropdown, type MenuProps } from 'antd';
|
||||
import { App, Dropdown } from 'antd';
|
||||
import { createStyles } from 'antd-style';
|
||||
import { ItemType } from 'antd/es/menu/interface';
|
||||
import isEqual from 'fast-deep-equal';
|
||||
import {
|
||||
Check,
|
||||
|
|
@ -16,6 +17,7 @@ import {
|
|||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { isServerMode } from '@/const/version';
|
||||
import { configService } from '@/services/config';
|
||||
import { useSessionStore } from '@/store/session';
|
||||
import { sessionHelpers } from '@/store/session/helpers';
|
||||
|
|
@ -58,108 +60,113 @@ const Actions = memo<ActionProps>(({ group, id, openCreateGroupModal, setOpen })
|
|||
const isDefault = group === SessionDefaultGroup.Default;
|
||||
// const hasDivider = !isDefault || Object.keys(sessionByGroup).length > 0;
|
||||
|
||||
const items: MenuProps['items'] = useMemo(
|
||||
() => [
|
||||
{
|
||||
icon: <Icon icon={pin ? PinOff : Pin} />,
|
||||
key: 'pin',
|
||||
label: t(pin ? 'pinOff' : 'pin'),
|
||||
onClick: () => {
|
||||
pinSession(id, !pin);
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: <Icon icon={LucideCopy} />,
|
||||
key: 'duplicate',
|
||||
label: t('duplicate', { ns: 'common' }),
|
||||
onClick: ({ domEvent }) => {
|
||||
domEvent.stopPropagation();
|
||||
|
||||
duplicateSession(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
children: [
|
||||
...sessionCustomGroups.map(({ id: groupId, name }) => ({
|
||||
icon: group === groupId ? <Icon icon={Check} /> : <div />,
|
||||
key: groupId,
|
||||
label: name,
|
||||
onClick: () => {
|
||||
updateSessionGroup(id, groupId);
|
||||
},
|
||||
})),
|
||||
const items = useMemo(
|
||||
() =>
|
||||
(
|
||||
[
|
||||
{
|
||||
icon: isDefault ? <Icon icon={Check} /> : <div />,
|
||||
key: 'defaultList',
|
||||
label: t('defaultList'),
|
||||
icon: <Icon icon={pin ? PinOff : Pin} />,
|
||||
key: 'pin',
|
||||
label: t(pin ? 'pinOff' : 'pin'),
|
||||
onClick: () => {
|
||||
updateSessionGroup(id, SessionDefaultGroup.Default);
|
||||
pinSession(id, !pin);
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: <Icon icon={LucideCopy} />,
|
||||
key: 'duplicate',
|
||||
label: t('duplicate', { ns: 'common' }),
|
||||
onClick: ({ domEvent }) => {
|
||||
domEvent.stopPropagation();
|
||||
|
||||
duplicateSession(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
icon: <Icon icon={LucidePlus} />,
|
||||
key: 'createGroup',
|
||||
label: <div>{t('sessionGroup.createGroup')}</div>,
|
||||
children: [
|
||||
...sessionCustomGroups.map(({ id: groupId, name }) => ({
|
||||
icon: group === groupId ? <Icon icon={Check} /> : <div />,
|
||||
key: groupId,
|
||||
label: name,
|
||||
onClick: () => {
|
||||
updateSessionGroup(id, groupId);
|
||||
},
|
||||
})),
|
||||
{
|
||||
icon: isDefault ? <Icon icon={Check} /> : <div />,
|
||||
key: 'defaultList',
|
||||
label: t('defaultList'),
|
||||
onClick: () => {
|
||||
updateSessionGroup(id, SessionDefaultGroup.Default);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
icon: <Icon icon={LucidePlus} />,
|
||||
key: 'createGroup',
|
||||
label: <div>{t('sessionGroup.createGroup')}</div>,
|
||||
onClick: ({ domEvent }) => {
|
||||
domEvent.stopPropagation();
|
||||
openCreateGroupModal();
|
||||
},
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={ListTree} />,
|
||||
key: 'moveGroup',
|
||||
label: t('sessionGroup.moveGroup'),
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
isServerMode
|
||||
? undefined
|
||||
: {
|
||||
children: [
|
||||
{
|
||||
key: 'agent',
|
||||
label: t('exportType.agent', { ns: 'common' }),
|
||||
onClick: () => {
|
||||
configService.exportSingleAgent(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'agentWithMessage',
|
||||
label: t('exportType.agentWithMessage', { ns: 'common' }),
|
||||
onClick: () => {
|
||||
configService.exportSingleSession(id);
|
||||
},
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={HardDriveDownload} />,
|
||||
key: 'export',
|
||||
label: t('export', { ns: 'common' }),
|
||||
},
|
||||
{
|
||||
danger: true,
|
||||
icon: <Icon icon={Trash} />,
|
||||
key: 'delete',
|
||||
label: t('delete', { ns: 'common' }),
|
||||
onClick: ({ domEvent }) => {
|
||||
domEvent.stopPropagation();
|
||||
openCreateGroupModal();
|
||||
modal.confirm({
|
||||
centered: true,
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
await removeSession(id);
|
||||
message.success(t('confirmRemoveSessionSuccess'));
|
||||
},
|
||||
rootClassName: styles.modalRoot,
|
||||
title: t('confirmRemoveSessionItemAlert'),
|
||||
});
|
||||
},
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={ListTree} />,
|
||||
key: 'moveGroup',
|
||||
label: t('sessionGroup.moveGroup'),
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
children: [
|
||||
{
|
||||
key: 'agent',
|
||||
label: t('exportType.agent', { ns: 'common' }),
|
||||
onClick: () => {
|
||||
configService.exportSingleAgent(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'agentWithMessage',
|
||||
label: t('exportType.agentWithMessage', { ns: 'common' }),
|
||||
onClick: () => {
|
||||
configService.exportSingleSession(id);
|
||||
},
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={HardDriveDownload} />,
|
||||
key: 'export',
|
||||
label: t('export', { ns: 'common' }),
|
||||
},
|
||||
{
|
||||
danger: true,
|
||||
icon: <Icon icon={Trash} />,
|
||||
key: 'delete',
|
||||
label: t('delete', { ns: 'common' }),
|
||||
onClick: ({ domEvent }) => {
|
||||
domEvent.stopPropagation();
|
||||
modal.confirm({
|
||||
centered: true,
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
await removeSession(id);
|
||||
message.success(t('confirmRemoveSessionSuccess'));
|
||||
},
|
||||
rootClassName: styles.modalRoot,
|
||||
title: t('confirmRemoveSessionItemAlert'),
|
||||
});
|
||||
},
|
||||
},
|
||||
],
|
||||
] as ItemType[]
|
||||
).filter(Boolean),
|
||||
[id, pin],
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { memo, useMemo } from 'react';
|
|||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { HEADER_ICON_SIZE } from '@/const/layoutTokens';
|
||||
import { isServerMode } from '@/const/version';
|
||||
import { configService } from '@/services/config';
|
||||
import { useServerConfigStore } from '@/store/serverConfig';
|
||||
import { useSessionStore } from '@/store/session';
|
||||
|
|
@ -18,45 +19,50 @@ export const HeaderContent = memo<{ mobile?: boolean; modal?: boolean }>(({ moda
|
|||
const mobile = useServerConfigStore((s) => s.isMobile);
|
||||
|
||||
const items = useMemo<MenuProps['items']>(
|
||||
() => [
|
||||
{
|
||||
key: 'agent',
|
||||
label: <div>{t('exportType.agent', { ns: 'common' })}</div>,
|
||||
onClick: () => {
|
||||
if (!id) return;
|
||||
() =>
|
||||
isServerMode
|
||||
? []
|
||||
: [
|
||||
{
|
||||
key: 'agent',
|
||||
label: <div>{t('exportType.agent', { ns: 'common' })}</div>,
|
||||
onClick: () => {
|
||||
if (!id) return;
|
||||
|
||||
configService.exportSingleAgent(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'agentWithMessage',
|
||||
label: <div>{t('exportType.agentWithMessage', { ns: 'common' })}</div>,
|
||||
onClick: () => {
|
||||
if (!id) return;
|
||||
configService.exportSingleAgent(id);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'agentWithMessage',
|
||||
label: <div>{t('exportType.agentWithMessage', { ns: 'common' })}</div>,
|
||||
onClick: () => {
|
||||
if (!id) return;
|
||||
|
||||
configService.exportSingleSession(id);
|
||||
},
|
||||
},
|
||||
],
|
||||
configService.exportSingleSession(id);
|
||||
},
|
||||
},
|
||||
],
|
||||
[],
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<SubmitAgentButton modal={modal} />
|
||||
<Dropdown arrow={false} menu={{ items }} trigger={['click']}>
|
||||
{modal ? (
|
||||
<Button block icon={<Icon icon={HardDriveDownload} />}>
|
||||
{t('export', { ns: 'common' })}
|
||||
</Button>
|
||||
) : (
|
||||
<ActionIcon
|
||||
icon={HardDriveDownload}
|
||||
size={HEADER_ICON_SIZE(mobile)}
|
||||
title={t('export', { ns: 'common' })}
|
||||
/>
|
||||
)}
|
||||
</Dropdown>
|
||||
{!isServerMode && (
|
||||
<Dropdown arrow={false} menu={{ items }} trigger={['click']}>
|
||||
{modal ? (
|
||||
<Button block icon={<Icon icon={HardDriveDownload} />}>
|
||||
{t('export', { ns: 'common' })}
|
||||
</Button>
|
||||
) : (
|
||||
<ActionIcon
|
||||
icon={HardDriveDownload}
|
||||
size={HEADER_ICON_SIZE(mobile)}
|
||||
title={t('export', { ns: 'common' })}
|
||||
/>
|
||||
)}
|
||||
</Dropdown>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
{
|
||||
"backup_code_enabled": false,
|
||||
"banned": false,
|
||||
"create_organization_enabled": true,
|
||||
"created_at": 1713709987911,
|
||||
"delete_self_enabled": true,
|
||||
"email_addresses": [
|
||||
{
|
||||
"created_at": 1713709977919,
|
||||
"email_address": "arvinx@foxmail.com",
|
||||
"id": "idn_2fPkD9X1lfzSn5lJVDGyochYq8k",
|
||||
"linked_to": [],
|
||||
"object": "email_address",
|
||||
"reserved": false,
|
||||
"updated_at": 1713709987951,
|
||||
"verification": []
|
||||
}
|
||||
],
|
||||
"external_accounts": [
|
||||
{
|
||||
"approved_scopes": "read:user user:email",
|
||||
"avatar_url": "https://avatars.githubusercontent.com/u/28616219?v=4",
|
||||
"created_at": 1713709542104,
|
||||
"email_address": "arvinx@foxmail.com",
|
||||
"first_name": "Arvin",
|
||||
"id": "eac_2fPjKROeJ1bBs8Uxa6RFMxKogTB",
|
||||
"identification_id": "idn_2fPjyV3sqtQJZUbEzdK2y23a1bq",
|
||||
"image_url": "https://img.clerk.com/eyJ0eXBlIjoicHJveHkiLCJzcmMiOiJodHRwczovL2F2YXRhcnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tL3UvMjg2MTYyMTk/dj00IiwicyI6IkhCeHE5NmdlRk85ekRxMjJlR05EalUrbVFBbmVDZjRVQkpwNGYxcW5JajQifQ",
|
||||
"label": null,
|
||||
"last_name": "Xu",
|
||||
"object": "external_account",
|
||||
"provider": "oauth_github",
|
||||
"provider_user_id": "28616219",
|
||||
"public_metadata": {},
|
||||
"updated_at": 1713709542104,
|
||||
"username": "arvinxx",
|
||||
"verification": {
|
||||
"attempts": null,
|
||||
"expire_at": 1713710140131,
|
||||
"status": "verified",
|
||||
"strategy": "oauth_github"
|
||||
}
|
||||
}
|
||||
],
|
||||
"external_id": null,
|
||||
"first_name": "Arvin",
|
||||
"has_image": true,
|
||||
"id": "user_2fPkELglwI48WpZVwwdAxBKBPK6",
|
||||
"image_url": "https://img.clerk.com/eyJ0eXBlIjoicHJveHkiLCJzcmMiOiJodHRwczovL2ltYWdlcy5jbGVyay5kZXYvb2F1dGhfZ2l0aHViL2ltZ18yZlBrRU1adVpwdlpvZFBHcVREdHJnTzJJM3cifQ",
|
||||
"last_active_at": 1713709987902,
|
||||
"last_name": "Xu",
|
||||
"last_sign_in_at": null,
|
||||
"locked": false,
|
||||
"lockout_expires_in_seconds": null,
|
||||
"object": "user",
|
||||
"passkeys": [],
|
||||
"password_enabled": false,
|
||||
"phone_numbers": [],
|
||||
"primary_email_address_id": "idn_2fPkD9X1lfzSn5lJVDGyochYq8k",
|
||||
"primary_phone_number_id": null,
|
||||
"primary_web3_wallet_id": null,
|
||||
"private_metadata": {},
|
||||
"profile_image_url": "https://images.clerk.dev/oauth_github/img_2fPkEMZuZpvZodPGqTDtrgO2I3w",
|
||||
"public_metadata": {},
|
||||
"saml_accounts": [],
|
||||
"totp_enabled": false,
|
||||
"two_factor_enabled": false,
|
||||
"unsafe_metadata": {},
|
||||
"updated_at": 1713709987972,
|
||||
"username": "arvinxx",
|
||||
"verification_attempts_remaining": 100,
|
||||
"web3_wallets": []
|
||||
}
|
||||
159
src/app/api/webhooks/clerk/route.ts
Normal file
159
src/app/api/webhooks/clerk/route.ts
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
import { UserJSON } from '@clerk/backend';
|
||||
import { NextResponse } from 'next/server';
|
||||
|
||||
import { authEnv } from '@/config/auth';
|
||||
import { isServerMode } from '@/const/version';
|
||||
import { UserModel } from '@/database/server/models/user';
|
||||
import { pino } from '@/libs/logger';
|
||||
|
||||
import { validateRequest } from './validateRequest';
|
||||
|
||||
if (authEnv.NEXT_PUBLIC_ENABLE_CLERK_AUTH && isServerMode && !authEnv.CLERK_WEBHOOK_SECRET) {
|
||||
throw new Error('`CLERK_WEBHOOK_SECRET` environment variable is missing');
|
||||
}
|
||||
|
||||
const createUser = async (id: string, params: UserJSON) => {
|
||||
pino.info('creating user due to clerk webhook');
|
||||
|
||||
const userModel = new UserModel();
|
||||
|
||||
// Check if user already exists
|
||||
const res = await userModel.findById(id);
|
||||
|
||||
// If user already exists, skip creating a new user
|
||||
if (res)
|
||||
return NextResponse.json(
|
||||
{ message: 'user not created due to user already existing in the database', success: false },
|
||||
{ status: 200 },
|
||||
);
|
||||
|
||||
const email = params.email_addresses.find((e) => e.id === params.primary_email_address_id);
|
||||
const phone = params.phone_numbers.find((e) => e.id === params.primary_phone_number_id);
|
||||
|
||||
await userModel.createUser({
|
||||
avatar: params.image_url,
|
||||
clerkCreatedAt: new Date(params.created_at),
|
||||
email: email?.email_address,
|
||||
firstName: params.first_name,
|
||||
id,
|
||||
lastName: params.last_name,
|
||||
phone: phone?.phone_number,
|
||||
username: params.username,
|
||||
});
|
||||
|
||||
return NextResponse.json({ message: 'user created', success: true }, { status: 200 });
|
||||
};
|
||||
|
||||
const deleteUser = async (id?: string) => {
|
||||
if (id) {
|
||||
pino.info('delete user due to clerk webhook');
|
||||
const userModel = new UserModel();
|
||||
|
||||
await userModel.deleteUser(id);
|
||||
|
||||
return NextResponse.json({ message: 'user deleted' }, { status: 200 });
|
||||
} else {
|
||||
pino.warn('clerk sent a delete user request, but no user ID was included in the payload');
|
||||
return NextResponse.json({ message: 'ok' }, { status: 200 });
|
||||
}
|
||||
};
|
||||
|
||||
const updateUser = async (id: string, params: UserJSON) => {
|
||||
pino.info('updating user due to clerk webhook');
|
||||
|
||||
const userModel = new UserModel();
|
||||
|
||||
// Check if user already exists
|
||||
const res = await userModel.findById(id);
|
||||
|
||||
// If user not exists, skip update the user
|
||||
if (!res)
|
||||
return NextResponse.json(
|
||||
{
|
||||
message: "user not updated due to the user don't existing in the database",
|
||||
success: false,
|
||||
},
|
||||
{ status: 200 },
|
||||
);
|
||||
|
||||
const email = params.email_addresses.find((e) => e.id === params.primary_email_address_id);
|
||||
const phone = params.phone_numbers.find((e) => e.id === params.primary_phone_number_id);
|
||||
|
||||
await userModel.updateUser(id, {
|
||||
avatar: params.image_url,
|
||||
email: email?.email_address,
|
||||
firstName: params.first_name,
|
||||
id,
|
||||
lastName: params.last_name,
|
||||
phone: phone?.phone_number,
|
||||
username: params.username,
|
||||
});
|
||||
|
||||
return NextResponse.json({ message: 'user updated', success: true }, { status: 200 });
|
||||
};
|
||||
|
||||
export const POST = async (req: Request): Promise<NextResponse> => {
|
||||
const payload = await validateRequest(req, authEnv.CLERK_WEBHOOK_SECRET!);
|
||||
|
||||
if (!payload) {
|
||||
return NextResponse.json(
|
||||
{ error: 'webhook verification failed or payload was malformed' },
|
||||
{ status: 400 },
|
||||
);
|
||||
}
|
||||
|
||||
const { type, data } = payload;
|
||||
|
||||
pino.trace(`clerk webhook payload: ${{ data, type }}`);
|
||||
|
||||
switch (type) {
|
||||
case 'user.created': {
|
||||
return createUser(data.id, data);
|
||||
}
|
||||
case 'user.deleted': {
|
||||
return deleteUser(data.id);
|
||||
}
|
||||
case 'user.updated': {
|
||||
return updateUser(data.id, data);
|
||||
}
|
||||
|
||||
default: {
|
||||
pino.warn(
|
||||
`${req.url} received event type "${type}", but no handler is defined for this type`,
|
||||
);
|
||||
return NextResponse.json({ error: `uncreognised payload type: ${type}` }, { status: 400 });
|
||||
}
|
||||
// case 'user.updated':
|
||||
// break;
|
||||
// case 'session.created':
|
||||
// break;
|
||||
// case 'session.ended':
|
||||
// break;
|
||||
// case 'session.removed':
|
||||
// break;
|
||||
// case 'session.revoked':
|
||||
// break;
|
||||
// case 'email.created':
|
||||
// break;
|
||||
// case 'sms.created':
|
||||
// break;
|
||||
// case 'organization.created':
|
||||
// break;
|
||||
// case 'organization.updated':
|
||||
// break;
|
||||
// case 'organization.deleted':
|
||||
// break;
|
||||
// case 'organizationMembership.created':
|
||||
// break;
|
||||
// case 'organizationMembership.deleted':
|
||||
// break;
|
||||
// case 'organizationMembership.updated':
|
||||
// break;
|
||||
// case 'organizationInvitation.accepted':
|
||||
// break;
|
||||
// case 'organizationInvitation.created':
|
||||
// break;
|
||||
// case 'organizationInvitation.revoked':
|
||||
// break;
|
||||
}
|
||||
};
|
||||
22
src/app/api/webhooks/clerk/validateRequest.ts
Normal file
22
src/app/api/webhooks/clerk/validateRequest.ts
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import { WebhookEvent } from '@clerk/nextjs/server';
|
||||
import { headers } from 'next/headers';
|
||||
import { Webhook } from 'svix';
|
||||
|
||||
export const validateRequest = async (request: Request, secret: string) => {
|
||||
const payloadString = await request.text();
|
||||
const headerPayload = headers();
|
||||
|
||||
const svixHeaders = {
|
||||
'svix-id': headerPayload.get('svix-id')!,
|
||||
'svix-signature': headerPayload.get('svix-signature')!,
|
||||
'svix-timestamp': headerPayload.get('svix-timestamp')!,
|
||||
};
|
||||
const wh = new Webhook(secret);
|
||||
|
||||
try {
|
||||
return wh.verify(payloadString, svixHeaders) as WebhookEvent;
|
||||
} catch {
|
||||
console.error('incoming webhook failed verification');
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
@ -3,7 +3,7 @@ import type { NextRequest } from 'next/server';
|
|||
|
||||
import { pino } from '@/libs/logger';
|
||||
import { createContext } from '@/server/context';
|
||||
import { edgeRouter } from '@/server/routers';
|
||||
import { edgeRouter } from '@/server/routers/edge';
|
||||
|
||||
export const runtime = 'edge';
|
||||
|
||||
|
|
|
|||
26
src/app/trpc/lambda/[trpc]/route.ts
Normal file
26
src/app/trpc/lambda/[trpc]/route.ts
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import { fetchRequestHandler } from '@trpc/server/adapters/fetch';
|
||||
import type { NextRequest } from 'next/server';
|
||||
|
||||
import { pino } from '@/libs/logger';
|
||||
import { createContext } from '@/server/context';
|
||||
import { lambdaRouter } from '@/server/routers/lambda';
|
||||
|
||||
const handler = (req: NextRequest) =>
|
||||
fetchRequestHandler({
|
||||
/**
|
||||
* @link https://trpc.io/docs/v11/context
|
||||
*/
|
||||
createContext: () => createContext(req),
|
||||
|
||||
endpoint: '/trpc/lambda',
|
||||
|
||||
onError: ({ error, path, type }) => {
|
||||
pino.info(`Error in tRPC handler (lambda) on path: ${path}, type: ${type}`);
|
||||
console.error(error);
|
||||
},
|
||||
|
||||
req,
|
||||
router: lambdaRouter,
|
||||
});
|
||||
|
||||
export { handler as GET, handler as POST };
|
||||
|
|
@ -75,6 +75,7 @@ export const getAuthConfig = () => {
|
|||
server: {
|
||||
// Clerk
|
||||
CLERK_SECRET_KEY: z.string().optional(),
|
||||
CLERK_WEBHOOK_SECRET: z.string().optional(),
|
||||
|
||||
// NEXT-AUTH
|
||||
NEXT_AUTH_SECRET: z.string().optional(),
|
||||
|
|
@ -110,6 +111,7 @@ export const getAuthConfig = () => {
|
|||
NEXT_PUBLIC_ENABLE_CLERK_AUTH: !!process.env.NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY,
|
||||
NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY: process.env.NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY,
|
||||
CLERK_SECRET_KEY: process.env.CLERK_SECRET_KEY,
|
||||
CLERK_WEBHOOK_SECRET: process.env.CLERK_WEBHOOK_SECRET,
|
||||
|
||||
// Next Auth
|
||||
NEXT_PUBLIC_ENABLE_NEXT_AUTH:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
/* eslint-disable sort-keys-fix/sort-keys-fix , typescript-sort-keys/interface */
|
||||
import { createEnv } from '@t3-oss/env-nextjs';
|
||||
import { z } from 'zod';
|
||||
|
||||
|
|
@ -8,8 +7,21 @@ export const getServerDBConfig = () => {
|
|||
NEXT_PUBLIC_ENABLED_SERVER_SERVICE: z.boolean(),
|
||||
},
|
||||
runtimeEnv: {
|
||||
DATABASE_DRIVER: process.env.DATABASE_DRIVER || 'neon',
|
||||
DATABASE_TEST_URL: process.env.DATABASE_TEST_URL,
|
||||
DATABASE_URL: process.env.DATABASE_URL,
|
||||
|
||||
KEY_VAULTS_SECRET: process.env.KEY_VAULTS_SECRET,
|
||||
|
||||
NEXT_PUBLIC_ENABLED_SERVER_SERVICE: process.env.NEXT_PUBLIC_SERVICE_MODE === 'server',
|
||||
},
|
||||
server: {
|
||||
DATABASE_DRIVER: z.enum(['neon', 'node']),
|
||||
DATABASE_TEST_URL: z.string().optional(),
|
||||
DATABASE_URL: z.string().optional(),
|
||||
|
||||
KEY_VAULTS_SECRET: z.string().optional(),
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
|
|
|
|||
44
src/database/server/core/db.ts
Normal file
44
src/database/server/core/db.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import { Pool as NeonPool, neonConfig } from '@neondatabase/serverless';
|
||||
import { NeonDatabase, drizzle as neonDrizzle } from 'drizzle-orm/neon-serverless';
|
||||
import { drizzle as nodeDrizzle } from 'drizzle-orm/node-postgres';
|
||||
import { Pool as NodePool } from 'pg';
|
||||
import ws from 'ws';
|
||||
|
||||
import { serverDBEnv } from '@/config/db';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
import * as schema from '../schemas/lobechat';
|
||||
|
||||
export const getDBInstance = (): NeonDatabase<typeof schema> => {
|
||||
if (!isServerMode) return {} as any;
|
||||
|
||||
if (!serverDBEnv.KEY_VAULTS_SECRET) {
|
||||
throw new Error(
|
||||
` \`KEY_VAULTS_SECRET\` is not set, please set it in your environment variables.
|
||||
|
||||
If you don't have it, please run \`openssl rand -base64 32\` to create one.
|
||||
`,
|
||||
);
|
||||
}
|
||||
|
||||
let connectionString = serverDBEnv.DATABASE_URL;
|
||||
|
||||
if (!connectionString) {
|
||||
throw new Error(`You are try to use database, but "DATABASE_URL" is not set correctly`);
|
||||
}
|
||||
|
||||
if (serverDBEnv.DATABASE_DRIVER === 'node') {
|
||||
const client = new NodePool({ connectionString });
|
||||
return nodeDrizzle(client, { schema });
|
||||
}
|
||||
|
||||
if (process.env.MIGRATION_DB === '1') {
|
||||
// https://github.com/neondatabase/serverless/blob/main/CONFIG.md#websocketconstructor-typeof-websocket--undefined
|
||||
neonConfig.webSocketConstructor = ws;
|
||||
}
|
||||
|
||||
const client = new NeonPool({ connectionString });
|
||||
return neonDrizzle(client, { schema });
|
||||
};
|
||||
|
||||
export const serverDB = getDBInstance();
|
||||
45
src/database/server/core/dbForTest.ts
Normal file
45
src/database/server/core/dbForTest.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import { Pool as NeonPool, neonConfig } from '@neondatabase/serverless';
|
||||
import { drizzle as neonDrizzle } from 'drizzle-orm/neon-serverless';
|
||||
import * as migrator from 'drizzle-orm/neon-serverless/migrator';
|
||||
import { drizzle as nodeDrizzle } from 'drizzle-orm/node-postgres';
|
||||
import * as nodeMigrator from 'drizzle-orm/node-postgres/migrator';
|
||||
import { join } from 'node:path';
|
||||
import { Pool as NodePool } from 'pg';
|
||||
import ws from 'ws';
|
||||
|
||||
import { serverDBEnv } from '@/config/db';
|
||||
|
||||
import * as schema from '../schemas/lobechat';
|
||||
|
||||
export const getTestDBInstance = async () => {
|
||||
let connectionString = serverDBEnv.DATABASE_TEST_URL;
|
||||
|
||||
if (!connectionString) {
|
||||
throw new Error(`You are try to use database, but "DATABASE_TEST_URL" is not set correctly`);
|
||||
}
|
||||
|
||||
if (serverDBEnv.DATABASE_DRIVER === 'node') {
|
||||
const client = new NodePool({ connectionString });
|
||||
|
||||
const db = nodeDrizzle(client, { schema });
|
||||
|
||||
await nodeMigrator.migrate(db, {
|
||||
migrationsFolder: join(__dirname, '../migrations'),
|
||||
});
|
||||
|
||||
return db;
|
||||
}
|
||||
|
||||
// https://github.com/neondatabase/serverless/blob/main/CONFIG.md#websocketconstructor-typeof-websocket--undefined
|
||||
neonConfig.webSocketConstructor = ws;
|
||||
|
||||
const client = new NeonPool({ connectionString });
|
||||
|
||||
const db = neonDrizzle(client, { schema });
|
||||
|
||||
await migrator.migrate(db, {
|
||||
migrationsFolder: join(__dirname, '../migrations'),
|
||||
});
|
||||
|
||||
return db;
|
||||
};
|
||||
1
src/database/server/index.ts
Normal file
1
src/database/server/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export { serverDB } from './core/db';
|
||||
439
src/database/server/migrations/0000_init.sql
Normal file
439
src/database/server/migrations/0000_init.sql
Normal file
|
|
@ -0,0 +1,439 @@
|
|||
CREATE TABLE IF NOT EXISTS "agents" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"slug" varchar(100),
|
||||
"title" text,
|
||||
"description" text,
|
||||
"tags" jsonb DEFAULT '[]'::jsonb,
|
||||
"avatar" text,
|
||||
"background_color" text,
|
||||
"plugins" jsonb DEFAULT '[]'::jsonb,
|
||||
"user_id" text NOT NULL,
|
||||
"chat_config" jsonb,
|
||||
"few_shots" jsonb,
|
||||
"model" text,
|
||||
"params" jsonb DEFAULT '{}'::jsonb,
|
||||
"provider" text,
|
||||
"system_role" text,
|
||||
"tts" jsonb,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT "agents_slug_unique" UNIQUE("slug")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "agents_tags" (
|
||||
"agent_id" text NOT NULL,
|
||||
"tag_id" integer NOT NULL,
|
||||
CONSTRAINT "agents_tags_agent_id_tag_id_pk" PRIMARY KEY("agent_id","tag_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "agents_to_sessions" (
|
||||
"agent_id" text NOT NULL,
|
||||
"session_id" text NOT NULL,
|
||||
CONSTRAINT "agents_to_sessions_agent_id_session_id_pk" PRIMARY KEY("agent_id","session_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "files" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"user_id" text NOT NULL,
|
||||
"file_type" varchar(255) NOT NULL,
|
||||
"name" text NOT NULL,
|
||||
"size" integer NOT NULL,
|
||||
"url" text NOT NULL,
|
||||
"metadata" jsonb,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "files_to_agents" (
|
||||
"file_id" text NOT NULL,
|
||||
"agent_id" text NOT NULL,
|
||||
CONSTRAINT "files_to_agents_file_id_agent_id_pk" PRIMARY KEY("file_id","agent_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "files_to_messages" (
|
||||
"file_id" text NOT NULL,
|
||||
"message_id" text NOT NULL,
|
||||
CONSTRAINT "files_to_messages_file_id_message_id_pk" PRIMARY KEY("file_id","message_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "files_to_sessions" (
|
||||
"file_id" text NOT NULL,
|
||||
"session_id" text NOT NULL,
|
||||
CONSTRAINT "files_to_sessions_file_id_session_id_pk" PRIMARY KEY("file_id","session_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "user_installed_plugins" (
|
||||
"user_id" text NOT NULL,
|
||||
"identifier" text NOT NULL,
|
||||
"type" text NOT NULL,
|
||||
"manifest" jsonb,
|
||||
"settings" jsonb,
|
||||
"custom_params" jsonb,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT "user_installed_plugins_user_id_identifier_pk" PRIMARY KEY("user_id","identifier")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "market" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"agent_id" text,
|
||||
"plugin_id" integer,
|
||||
"type" text NOT NULL,
|
||||
"view" integer DEFAULT 0,
|
||||
"like" integer DEFAULT 0,
|
||||
"used" integer DEFAULT 0,
|
||||
"user_id" text NOT NULL,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "message_plugins" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"tool_call_id" text,
|
||||
"type" text DEFAULT 'default',
|
||||
"api_name" text,
|
||||
"arguments" text,
|
||||
"identifier" text,
|
||||
"state" jsonb,
|
||||
"error" jsonb
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "message_tts" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"content_md5" text,
|
||||
"file_id" text,
|
||||
"voice" text
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "message_translates" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"content" text,
|
||||
"from" text,
|
||||
"to" text
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "messages" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"role" text NOT NULL,
|
||||
"content" text,
|
||||
"model" text,
|
||||
"provider" text,
|
||||
"favorite" boolean DEFAULT false,
|
||||
"error" jsonb,
|
||||
"tools" jsonb,
|
||||
"trace_id" text,
|
||||
"observation_id" text,
|
||||
"user_id" text NOT NULL,
|
||||
"session_id" text,
|
||||
"topic_id" text,
|
||||
"parent_id" text,
|
||||
"quota_id" text,
|
||||
"agent_id" text,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "plugins" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"identifier" text NOT NULL,
|
||||
"title" text NOT NULL,
|
||||
"description" text,
|
||||
"avatar" text,
|
||||
"author" text,
|
||||
"manifest" text NOT NULL,
|
||||
"locale" text NOT NULL,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT "plugins_identifier_unique" UNIQUE("identifier")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "plugins_tags" (
|
||||
"plugin_id" integer NOT NULL,
|
||||
"tag_id" integer NOT NULL,
|
||||
CONSTRAINT "plugins_tags_plugin_id_tag_id_pk" PRIMARY KEY("plugin_id","tag_id")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "session_groups" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"name" text NOT NULL,
|
||||
"sort" integer,
|
||||
"user_id" text NOT NULL,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "sessions" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"slug" varchar(100) NOT NULL,
|
||||
"title" text,
|
||||
"description" text,
|
||||
"avatar" text,
|
||||
"background_color" text,
|
||||
"type" text DEFAULT 'agent',
|
||||
"user_id" text NOT NULL,
|
||||
"group_id" text,
|
||||
"pinned" boolean DEFAULT false,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "tags" (
|
||||
"id" serial PRIMARY KEY NOT NULL,
|
||||
"slug" text NOT NULL,
|
||||
"name" text,
|
||||
"user_id" text NOT NULL,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT "tags_slug_unique" UNIQUE("slug")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "topics" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"session_id" text,
|
||||
"user_id" text NOT NULL,
|
||||
"favorite" boolean DEFAULT false,
|
||||
"title" text,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "user_settings" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"tts" jsonb,
|
||||
"key_vaults" text,
|
||||
"general" jsonb,
|
||||
"language_model" jsonb,
|
||||
"system_agent" jsonb,
|
||||
"default_agent" jsonb,
|
||||
"tool" jsonb
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE IF NOT EXISTS "users" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"username" text,
|
||||
"email" text,
|
||||
"avatar" text,
|
||||
"phone" text,
|
||||
"first_name" text,
|
||||
"last_name" text,
|
||||
"is_onboarded" boolean DEFAULT false,
|
||||
"clerk_created_at" timestamp with time zone,
|
||||
"preference" jsonb DEFAULT '{"guide":{"moveSettingsToAvatar":true,"topic":true},"telemetry":null,"useCmdEnterToSend":false}'::jsonb,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"key" text,
|
||||
CONSTRAINT "users_username_unique" UNIQUE("username")
|
||||
);
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "agents" ADD CONSTRAINT "agents_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "agents_tags" ADD CONSTRAINT "agents_tags_agent_id_agents_id_fk" FOREIGN KEY ("agent_id") REFERENCES "public"."agents"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "agents_tags" ADD CONSTRAINT "agents_tags_tag_id_tags_id_fk" FOREIGN KEY ("tag_id") REFERENCES "public"."tags"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "agents_to_sessions" ADD CONSTRAINT "agents_to_sessions_agent_id_agents_id_fk" FOREIGN KEY ("agent_id") REFERENCES "public"."agents"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "agents_to_sessions" ADD CONSTRAINT "agents_to_sessions_session_id_sessions_id_fk" FOREIGN KEY ("session_id") REFERENCES "public"."sessions"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files" ADD CONSTRAINT "files_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_agents" ADD CONSTRAINT "files_to_agents_file_id_files_id_fk" FOREIGN KEY ("file_id") REFERENCES "public"."files"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_agents" ADD CONSTRAINT "files_to_agents_agent_id_agents_id_fk" FOREIGN KEY ("agent_id") REFERENCES "public"."agents"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_messages" ADD CONSTRAINT "files_to_messages_file_id_files_id_fk" FOREIGN KEY ("file_id") REFERENCES "public"."files"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_messages" ADD CONSTRAINT "files_to_messages_message_id_messages_id_fk" FOREIGN KEY ("message_id") REFERENCES "public"."messages"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_sessions" ADD CONSTRAINT "files_to_sessions_file_id_files_id_fk" FOREIGN KEY ("file_id") REFERENCES "public"."files"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "files_to_sessions" ADD CONSTRAINT "files_to_sessions_session_id_sessions_id_fk" FOREIGN KEY ("session_id") REFERENCES "public"."sessions"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "user_installed_plugins" ADD CONSTRAINT "user_installed_plugins_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "market" ADD CONSTRAINT "market_agent_id_agents_id_fk" FOREIGN KEY ("agent_id") REFERENCES "public"."agents"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "market" ADD CONSTRAINT "market_plugin_id_plugins_id_fk" FOREIGN KEY ("plugin_id") REFERENCES "public"."plugins"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "market" ADD CONSTRAINT "market_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "message_plugins" ADD CONSTRAINT "message_plugins_id_messages_id_fk" FOREIGN KEY ("id") REFERENCES "public"."messages"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "message_tts" ADD CONSTRAINT "message_tts_id_messages_id_fk" FOREIGN KEY ("id") REFERENCES "public"."messages"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "message_tts" ADD CONSTRAINT "message_tts_file_id_files_id_fk" FOREIGN KEY ("file_id") REFERENCES "public"."files"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "message_translates" ADD CONSTRAINT "message_translates_id_messages_id_fk" FOREIGN KEY ("id") REFERENCES "public"."messages"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_session_id_sessions_id_fk" FOREIGN KEY ("session_id") REFERENCES "public"."sessions"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_topic_id_topics_id_fk" FOREIGN KEY ("topic_id") REFERENCES "public"."topics"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_parent_id_messages_id_fk" FOREIGN KEY ("parent_id") REFERENCES "public"."messages"("id") ON DELETE set null ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_quota_id_messages_id_fk" FOREIGN KEY ("quota_id") REFERENCES "public"."messages"("id") ON DELETE set null ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_agent_id_agents_id_fk" FOREIGN KEY ("agent_id") REFERENCES "public"."agents"("id") ON DELETE set null ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "plugins_tags" ADD CONSTRAINT "plugins_tags_plugin_id_plugins_id_fk" FOREIGN KEY ("plugin_id") REFERENCES "public"."plugins"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "plugins_tags" ADD CONSTRAINT "plugins_tags_tag_id_tags_id_fk" FOREIGN KEY ("tag_id") REFERENCES "public"."tags"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "session_groups" ADD CONSTRAINT "session_groups_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "sessions" ADD CONSTRAINT "sessions_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "sessions" ADD CONSTRAINT "sessions_group_id_session_groups_id_fk" FOREIGN KEY ("group_id") REFERENCES "public"."session_groups"("id") ON DELETE set null ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "tags" ADD CONSTRAINT "tags_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "topics" ADD CONSTRAINT "topics_session_id_sessions_id_fk" FOREIGN KEY ("session_id") REFERENCES "public"."sessions"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "topics" ADD CONSTRAINT "topics_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
DO $$ BEGIN
|
||||
ALTER TABLE "user_settings" ADD CONSTRAINT "user_settings_id_users_id_fk" FOREIGN KEY ("id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
--> statement-breakpoint
|
||||
CREATE INDEX IF NOT EXISTS "messages_created_at_idx" ON "messages" ("created_at");--> statement-breakpoint
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "slug_user_id_unique" ON "sessions" ("slug","user_id");
|
||||
9
src/database/server/migrations/0001_add_client_id.sql
Normal file
9
src/database/server/migrations/0001_add_client_id.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
ALTER TABLE "messages" ADD COLUMN "client_id" text;--> statement-breakpoint
|
||||
ALTER TABLE "session_groups" ADD COLUMN "client_id" text;--> statement-breakpoint
|
||||
ALTER TABLE "sessions" ADD COLUMN "client_id" text;--> statement-breakpoint
|
||||
ALTER TABLE "topics" ADD COLUMN "client_id" text;--> statement-breakpoint
|
||||
CREATE INDEX IF NOT EXISTS "messages_client_id_idx" ON "messages" ("client_id");--> statement-breakpoint
|
||||
ALTER TABLE "messages" ADD CONSTRAINT "messages_client_id_unique" UNIQUE("client_id");--> statement-breakpoint
|
||||
ALTER TABLE "session_groups" ADD CONSTRAINT "session_groups_client_id_unique" UNIQUE("client_id");--> statement-breakpoint
|
||||
ALTER TABLE "sessions" ADD CONSTRAINT "sessions_client_id_unique" UNIQUE("client_id");--> statement-breakpoint
|
||||
ALTER TABLE "topics" ADD CONSTRAINT "topics_client_id_unique" UNIQUE("client_id");
|
||||
9
src/database/server/migrations/0002_amusing_puma.sql
Normal file
9
src/database/server/migrations/0002_amusing_puma.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
ALTER TABLE "messages" DROP CONSTRAINT "messages_client_id_unique";--> statement-breakpoint
|
||||
ALTER TABLE "session_groups" DROP CONSTRAINT "session_groups_client_id_unique";--> statement-breakpoint
|
||||
ALTER TABLE "sessions" DROP CONSTRAINT "sessions_client_id_unique";--> statement-breakpoint
|
||||
ALTER TABLE "topics" DROP CONSTRAINT "topics_client_id_unique";--> statement-breakpoint
|
||||
DROP INDEX IF EXISTS "messages_client_id_idx";--> statement-breakpoint
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "message_client_id_user_unique" ON "messages" ("client_id","user_id");--> statement-breakpoint
|
||||
ALTER TABLE "session_groups" ADD CONSTRAINT "session_group_client_id_user_unique" UNIQUE("client_id","user_id");--> statement-breakpoint
|
||||
ALTER TABLE "sessions" ADD CONSTRAINT "sessions_client_id_user_id_unique" UNIQUE("client_id","user_id");--> statement-breakpoint
|
||||
ALTER TABLE "topics" ADD CONSTRAINT "topic_client_id_user_id_unique" UNIQUE("client_id","user_id");
|
||||
1583
src/database/server/migrations/meta/0000_snapshot.json
Normal file
1583
src/database/server/migrations/meta/0000_snapshot.json
Normal file
File diff suppressed because it is too large
Load diff
1636
src/database/server/migrations/meta/0001_snapshot.json
Normal file
1636
src/database/server/migrations/meta/0001_snapshot.json
Normal file
File diff suppressed because it is too large
Load diff
1630
src/database/server/migrations/meta/0002_snapshot.json
Normal file
1630
src/database/server/migrations/meta/0002_snapshot.json
Normal file
File diff suppressed because it is too large
Load diff
27
src/database/server/migrations/meta/_journal.json
Normal file
27
src/database/server/migrations/meta/_journal.json
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"dialect": "postgresql",
|
||||
"entries": [
|
||||
{
|
||||
"idx": 0,
|
||||
"version": "6",
|
||||
"when": 1716982944425,
|
||||
"tag": "0000_init",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 1,
|
||||
"version": "6",
|
||||
"when": 1717153686544,
|
||||
"tag": "0001_add_client_id",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 2,
|
||||
"version": "6",
|
||||
"when": 1717587734458,
|
||||
"tag": "0002_amusing_puma",
|
||||
"breakpoints": true
|
||||
}
|
||||
],
|
||||
"version": "6"
|
||||
}
|
||||
140
src/database/server/models/__tests__/file.test.ts
Normal file
140
src/database/server/models/__tests__/file.test.ts
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
// @vitest-environment node
|
||||
import { eq } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
|
||||
import { files, users } from '../../schemas/lobechat';
|
||||
import { FileModel } from '../file';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'file-model-test-user-id';
|
||||
const fileModel = new FileModel(userId);
|
||||
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(users).where(eq(users.id, userId));
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await serverDB.delete(users).where(eq(users.id, userId));
|
||||
await serverDB.delete(files).where(eq(files.userId, userId));
|
||||
});
|
||||
|
||||
describe('FileModel', () => {
|
||||
it('should create a new file', async () => {
|
||||
const params = {
|
||||
name: 'test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
};
|
||||
|
||||
const { id } = await fileModel.create(params);
|
||||
expect(id).toBeDefined();
|
||||
|
||||
const file = await serverDB.query.files.findFirst({ where: eq(files.id, id) });
|
||||
expect(file).toMatchObject({ ...params, userId });
|
||||
});
|
||||
|
||||
it('should delete a file by id', async () => {
|
||||
const { id } = await fileModel.create({
|
||||
name: 'test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
|
||||
await fileModel.delete(id);
|
||||
|
||||
const file = await serverDB.query.files.findFirst({ where: eq(files.id, id) });
|
||||
expect(file).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should clear all files for the user', async () => {
|
||||
await fileModel.create({
|
||||
name: 'test-file-1.txt',
|
||||
url: 'https://example.com/test-file-1.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
await fileModel.create({
|
||||
name: 'test-file-2.txt',
|
||||
url: 'https://example.com/test-file-2.txt',
|
||||
size: 200,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
|
||||
await fileModel.clear();
|
||||
|
||||
const userFiles = await serverDB.query.files.findMany({ where: eq(files.userId, userId) });
|
||||
expect(userFiles).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should query files for the user', async () => {
|
||||
await fileModel.create({
|
||||
name: 'test-file-1.txt',
|
||||
url: 'https://example.com/test-file-1.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
await fileModel.create({
|
||||
name: 'test-file-2.txt',
|
||||
url: 'https://example.com/test-file-2.txt',
|
||||
size: 200,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
|
||||
const userFiles = await fileModel.query();
|
||||
expect(userFiles).toHaveLength(2);
|
||||
expect(userFiles[0].name).toBe('test-file-2.txt');
|
||||
expect(userFiles[1].name).toBe('test-file-1.txt');
|
||||
});
|
||||
|
||||
it('should find a file by id', async () => {
|
||||
const { id } = await fileModel.create({
|
||||
name: 'test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
|
||||
const file = await fileModel.findById(id);
|
||||
expect(file).toMatchObject({
|
||||
id,
|
||||
name: 'test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
userId,
|
||||
});
|
||||
});
|
||||
|
||||
it('should update a file', async () => {
|
||||
const { id } = await fileModel.create({
|
||||
name: 'test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 100,
|
||||
fileType: 'text/plain',
|
||||
});
|
||||
|
||||
await fileModel.update(id, { name: 'updated-test-file.txt', size: 200 });
|
||||
|
||||
const updatedFile = await serverDB.query.files.findFirst({ where: eq(files.id, id) });
|
||||
expect(updatedFile).toMatchObject({
|
||||
id,
|
||||
name: 'updated-test-file.txt',
|
||||
url: 'https://example.com/test-file.txt',
|
||||
size: 200,
|
||||
fileType: 'text/plain',
|
||||
userId,
|
||||
});
|
||||
});
|
||||
});
|
||||
847
src/database/server/models/__tests__/message.test.ts
Normal file
847
src/database/server/models/__tests__/message.test.ts
Normal file
|
|
@ -0,0 +1,847 @@
|
|||
import { eq } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
|
||||
import {
|
||||
files,
|
||||
filesToMessages,
|
||||
messagePlugins,
|
||||
messageTTS,
|
||||
messageTranslates,
|
||||
messages,
|
||||
sessions,
|
||||
topics,
|
||||
users,
|
||||
} from '../../schemas/lobechat';
|
||||
import { MessageModel } from '../message';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'message-db';
|
||||
const messageModel = new MessageModel(userId);
|
||||
|
||||
beforeEach(async () => {
|
||||
// 在每个测试用例之前,清空表
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.delete(users);
|
||||
await trx.insert(users).values([{ id: userId }, { id: '456' }]);
|
||||
|
||||
await trx.insert(sessions).values([
|
||||
// { id: 'session1', userId },
|
||||
// { id: 'session2', userId },
|
||||
{ id: '1', userId },
|
||||
]);
|
||||
await trx.insert(files).values({
|
||||
id: 'f1',
|
||||
userId: userId,
|
||||
url: 'abc',
|
||||
name: 'file-1',
|
||||
fileType: 'image/png',
|
||||
size: 1000,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// 在每个测试用例之后,清空表
|
||||
await serverDB.delete(users);
|
||||
});
|
||||
|
||||
describe('MessageModel', () => {
|
||||
describe('query', () => {
|
||||
it('should query messages by user ID', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1', createdAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId, role: 'user', content: 'message 2', createdAt: new Date('2023-02-01') },
|
||||
{
|
||||
id: '3',
|
||||
userId: '456',
|
||||
role: 'user',
|
||||
content: 'message 3',
|
||||
createdAt: new Date('2023-03-01'),
|
||||
},
|
||||
]);
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await messageModel.query();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[1].id).toBe('2');
|
||||
});
|
||||
|
||||
it('should return empty messages if not match the user ID', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId: '456', role: 'user', content: '1', createdAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId: '456', role: 'user', content: '2', createdAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId: '456', role: 'user', content: '3', createdAt: new Date('2023-03-01') },
|
||||
]);
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await messageModel.query();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should query messages with pagination', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1', createdAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId, role: 'user', content: 'message 2', createdAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId, role: 'user', content: 'message 3', createdAt: new Date('2023-03-01') },
|
||||
]);
|
||||
|
||||
// 测试分页
|
||||
const result1 = await messageModel.query({ current: 0, pageSize: 2 });
|
||||
expect(result1).toHaveLength(2);
|
||||
|
||||
const result2 = await messageModel.query({ current: 1, pageSize: 1 });
|
||||
expect(result2).toHaveLength(1);
|
||||
expect(result2[0].id).toBe('2');
|
||||
});
|
||||
|
||||
it('should filter messages by sessionId', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId },
|
||||
]);
|
||||
await serverDB.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
sessionId: 'session1',
|
||||
content: 'message 1',
|
||||
createdAt: new Date('2022-02-01'),
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
role: 'user',
|
||||
sessionId: 'session1',
|
||||
content: 'message 2',
|
||||
createdAt: new Date('2023-02-02'),
|
||||
},
|
||||
{ id: '3', userId, role: 'user', sessionId: 'session2', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 测试根据 sessionId 过滤
|
||||
const result = await messageModel.query({ sessionId: 'session1' });
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[1].id).toBe('2');
|
||||
});
|
||||
|
||||
it('should filter messages by topicId', async () => {
|
||||
// 创建测试数据
|
||||
const sessionId = 'session1';
|
||||
await serverDB.insert(sessions).values([{ id: sessionId, userId }]);
|
||||
const topicId = 'topic1';
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: topicId, sessionId, userId },
|
||||
{ id: 'topic2', sessionId, userId },
|
||||
]);
|
||||
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', topicId, content: '1', createdAt: new Date('2022-04-01') },
|
||||
{ id: '2', userId, role: 'user', topicId, content: '2', createdAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId, role: 'user', topicId: 'topic2', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 测试根据 topicId 过滤
|
||||
const result = await messageModel.query({ topicId });
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[1].id).toBe('2');
|
||||
});
|
||||
|
||||
it('should query messages with join', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 1',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 2',
|
||||
createdAt: new Date('2023-02-01'),
|
||||
},
|
||||
{
|
||||
id: '3',
|
||||
userId: '456',
|
||||
role: 'user',
|
||||
content: 'message 3',
|
||||
createdAt: new Date('2023-03-01'),
|
||||
},
|
||||
]);
|
||||
await trx.insert(files).values([
|
||||
{ id: 'f-0', url: 'abc', name: 'file-1', userId, fileType: 'image/png', size: 1000 },
|
||||
{ id: 'f-1', url: 'abc', name: 'file-1', userId, fileType: 'image/png', size: 100 },
|
||||
{ id: 'f-3', url: 'abc', name: 'file-3', userId, fileType: 'image/png', size: 400 },
|
||||
]);
|
||||
await trx
|
||||
.insert(messageTTS)
|
||||
.values([{ id: '1' }, { id: '2', voice: 'a', fileId: 'f-1', contentMd5: 'abc' }]);
|
||||
|
||||
await trx.insert(filesToMessages).values([
|
||||
{ fileId: 'f-0', messageId: '1' },
|
||||
{ fileId: 'f-3', messageId: '1' },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await messageModel.query();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[0].files).toEqual(['f-0', 'f-3']);
|
||||
|
||||
expect(result[1].id).toBe('2');
|
||||
expect(result[1].files).toEqual([]);
|
||||
});
|
||||
|
||||
it('should include translate, tts and other extra fields in query result', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 1',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
},
|
||||
]);
|
||||
await trx
|
||||
.insert(messageTranslates)
|
||||
.values([{ id: '1', content: 'translated', from: 'en', to: 'zh' }]);
|
||||
await trx
|
||||
.insert(messageTTS)
|
||||
.values([{ id: '1', voice: 'voice1', fileId: 'f1', contentMd5: 'md5' }]);
|
||||
});
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await messageModel.query();
|
||||
|
||||
// 断言结果
|
||||
expect(result[0].extra.translate).toEqual({ content: 'translated', from: 'en', to: 'zh' });
|
||||
// TODO: 确认是否需要包含 tts 字段
|
||||
expect(result[0].extra.tts).toEqual({
|
||||
// contentMd5: 'md5', file: 'f1', voice: 'voice1'
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle edge cases of pagination parameters', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1' },
|
||||
{ id: '2', userId, role: 'user', content: 'message 2' },
|
||||
{ id: '3', userId, role: 'user', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 测试 current 和 pageSize 的边界情况
|
||||
const result1 = await messageModel.query({ current: 0, pageSize: 2 });
|
||||
expect(result1).toHaveLength(2);
|
||||
|
||||
const result2 = await messageModel.query({ current: 1, pageSize: 2 });
|
||||
expect(result2).toHaveLength(1);
|
||||
|
||||
const result3 = await messageModel.query({ current: 2, pageSize: 2 });
|
||||
expect(result3).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryAll', () => {
|
||||
it('should return all messages belonging to the user in ascending order', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 1',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 2',
|
||||
createdAt: new Date('2023-02-01'),
|
||||
},
|
||||
{
|
||||
id: '3',
|
||||
userId: '456',
|
||||
role: 'user',
|
||||
content: 'message 3',
|
||||
createdAt: new Date('2023-03-01'),
|
||||
},
|
||||
]);
|
||||
|
||||
// 调用 queryAll 方法
|
||||
const result = await messageModel.queryAll();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[1].id).toBe('2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('findById', () => {
|
||||
it('should find message by ID', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1' },
|
||||
{ id: '2', userId: '456', role: 'user', content: 'message 2' },
|
||||
]);
|
||||
|
||||
// 调用 findById 方法
|
||||
const result = await messageModel.findById('1');
|
||||
|
||||
// 断言结果
|
||||
expect(result?.id).toBe('1');
|
||||
expect(result?.content).toBe('message 1');
|
||||
});
|
||||
|
||||
it('should return undefined if message does not belong to user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId: '456', role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 findById 方法
|
||||
const result = await messageModel.findById('1');
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryBySessionId', () => {
|
||||
it('should query messages by sessionId', async () => {
|
||||
// 创建测试数据
|
||||
const sessionId = 'session1';
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId },
|
||||
]);
|
||||
await serverDB.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
sessionId,
|
||||
content: 'message 1',
|
||||
createdAt: new Date('2022-01-01'),
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
role: 'user',
|
||||
sessionId,
|
||||
content: 'message 2',
|
||||
createdAt: new Date('2023-02-01'),
|
||||
},
|
||||
{ id: '3', userId, role: 'user', sessionId: 'session2', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 调用 queryBySessionId 方法
|
||||
const result = await messageModel.queryBySessionId(sessionId);
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('1');
|
||||
expect(result[1].id).toBe('2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryByKeyWord', () => {
|
||||
it('should query messages by keyword', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'apple', createdAt: new Date('2022-02-01') },
|
||||
{ id: '2', userId, role: 'user', content: 'banana' },
|
||||
{ id: '3', userId, role: 'user', content: 'pear' },
|
||||
{ id: '4', userId, role: 'user', content: 'apple pie', createdAt: new Date('2024-02-01') },
|
||||
]);
|
||||
|
||||
// 测试查询包含特定关键字的消息
|
||||
const result = await messageModel.queryByKeyword('apple');
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('4');
|
||||
expect(result[1].id).toBe('1');
|
||||
});
|
||||
|
||||
it('should return empty array when keyword is empty', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'apple' },
|
||||
{ id: '2', userId, role: 'user', content: 'banana' },
|
||||
{ id: '3', userId, role: 'user', content: 'pear' },
|
||||
{ id: '4', userId, role: 'user', content: 'apple pie' },
|
||||
]);
|
||||
|
||||
// 测试当关键字为空时返回空数组
|
||||
const result = await messageModel.queryByKeyword('');
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createMessage', () => {
|
||||
it('should create a new message', async () => {
|
||||
// 调用 createMessage 方法
|
||||
await messageModel.create({ role: 'user', content: 'new message', sessionId: '1' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, userId))
|
||||
.execute();
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('new message');
|
||||
});
|
||||
|
||||
it('should create a message', async () => {
|
||||
const sessionId = 'session1';
|
||||
await serverDB.insert(sessions).values([{ id: sessionId, userId }]);
|
||||
|
||||
const result = await messageModel.create({
|
||||
content: 'message 1',
|
||||
role: 'user',
|
||||
sessionId: 'session1',
|
||||
});
|
||||
|
||||
expect(result.id).toBeDefined();
|
||||
expect(result.content).toBe('message 1');
|
||||
expect(result.role).toBe('user');
|
||||
expect(result.sessionId).toBe('session1');
|
||||
expect(result.userId).toBe(userId);
|
||||
});
|
||||
|
||||
it('should generate message ID automatically', async () => {
|
||||
// 调用 createMessage 方法
|
||||
await messageModel.create({
|
||||
role: 'user',
|
||||
content: 'new message',
|
||||
sessionId: '1',
|
||||
});
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, userId))
|
||||
.execute();
|
||||
expect(result[0].id).toBeDefined();
|
||||
expect(result[0].id).toHaveLength(18);
|
||||
});
|
||||
|
||||
it('should create a tool message and insert into messagePlugins table', async () => {
|
||||
// 调用 create 方法
|
||||
const result = await messageModel.create({
|
||||
content: 'message 1',
|
||||
role: 'tool',
|
||||
sessionId: '1',
|
||||
tool_call_id: 'tool1',
|
||||
plugin: {
|
||||
apiName: 'api1',
|
||||
arguments: 'arg1',
|
||||
identifier: 'plugin1',
|
||||
type: 'default',
|
||||
},
|
||||
});
|
||||
|
||||
// 断言结果
|
||||
expect(result.id).toBeDefined();
|
||||
expect(result.content).toBe('message 1');
|
||||
expect(result.role).toBe('tool');
|
||||
expect(result.sessionId).toBe('1');
|
||||
|
||||
const pluginResult = await serverDB
|
||||
.select()
|
||||
.from(messagePlugins)
|
||||
.where(eq(messagePlugins.id, result.id))
|
||||
.execute();
|
||||
expect(pluginResult).toHaveLength(1);
|
||||
expect(pluginResult[0].identifier).toBe('plugin1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('batchCreateMessages', () => {
|
||||
it('should batch create messages', async () => {
|
||||
// 准备测试数据
|
||||
const newMessages = [
|
||||
{ id: '1', role: 'user', content: 'message 1' },
|
||||
{ id: '2', role: 'assistant', content: 'message 2' },
|
||||
];
|
||||
|
||||
// 调用 batchCreateMessages 方法
|
||||
await messageModel.batchCreate(newMessages);
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, userId))
|
||||
.execute();
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].content).toBe('message 1');
|
||||
expect(result[1].content).toBe('message 2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessage', () => {
|
||||
it('should update message content', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 updateMessage 方法
|
||||
await messageModel.update('1', { content: 'updated message' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute();
|
||||
expect(result[0].content).toBe('updated message');
|
||||
});
|
||||
|
||||
it('should only update messages belonging to the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId: '456', role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 updateMessage 方法
|
||||
await messageModel.update('1', { content: 'updated message' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute();
|
||||
expect(result[0].content).toBe('message 1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessage', () => {
|
||||
it('should delete a message', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 deleteMessage 方法
|
||||
await messageModel.deleteMessage('1');
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute();
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should delete a message with tool calls', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1', tools: [{ id: 'tool1' }] },
|
||||
{ id: '2', userId, role: 'tool', content: 'message 1' },
|
||||
]);
|
||||
await trx
|
||||
.insert(messagePlugins)
|
||||
.values([{ id: '2', toolCallId: 'tool1', identifier: 'plugin-1' }]);
|
||||
});
|
||||
|
||||
// 调用 deleteMessage 方法
|
||||
await messageModel.deleteMessage('1');
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute();
|
||||
expect(result).toHaveLength(0);
|
||||
|
||||
const result2 = await serverDB
|
||||
.select()
|
||||
.from(messagePlugins)
|
||||
.where(eq(messagePlugins.id, '2'))
|
||||
.execute();
|
||||
|
||||
expect(result2).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should only delete messages belonging to the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId: '456', role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 deleteMessage 方法
|
||||
await messageModel.deleteMessage('1');
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute();
|
||||
expect(result).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteAllMessages', () => {
|
||||
it('should delete all messages belonging to the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1' },
|
||||
{ id: '2', userId, role: 'user', content: 'message 2' },
|
||||
{ id: '3', userId: '456', role: 'user', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 调用 deleteAllMessages 方法
|
||||
await messageModel.deleteAllMessages();
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, userId))
|
||||
.execute();
|
||||
expect(result).toHaveLength(0);
|
||||
|
||||
const otherResult = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, '456'))
|
||||
.execute();
|
||||
expect(otherResult).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updatePluginState', () => {
|
||||
it('should update the state field in messagePlugins table', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values({ id: '1', content: 'abc', role: 'user', userId });
|
||||
await serverDB
|
||||
.insert(messagePlugins)
|
||||
.values([
|
||||
{ id: '1', toolCallId: 'tool1', identifier: 'plugin1', state: { key1: 'value1' } },
|
||||
]);
|
||||
|
||||
// 调用 updatePluginState 方法
|
||||
await messageModel.updatePluginState('1', { key2: 'value2' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messagePlugins)
|
||||
.where(eq(messagePlugins.id, '1'))
|
||||
.execute();
|
||||
expect(result[0].state).toEqual({ key1: 'value1', key2: 'value2' });
|
||||
});
|
||||
|
||||
it('should throw an error if plugin does not exist', async () => {
|
||||
// 调用 updatePluginState 方法
|
||||
await expect(messageModel.updatePluginState('1', { key: 'value' })).rejects.toThrowError(
|
||||
'Plugin not found',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateTranslate', () => {
|
||||
it('should insert a new record if message does not exist in messageTranslates table', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 updateTranslate 方法
|
||||
await messageModel.updateTranslate('1', {
|
||||
content: 'translated message 1',
|
||||
from: 'en',
|
||||
to: 'zh',
|
||||
});
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTranslates)
|
||||
.where(eq(messageTranslates.id, '1'))
|
||||
.execute();
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('translated message 1');
|
||||
});
|
||||
|
||||
it('should update the corresponding fields if message exists in messageTranslates table', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
await trx
|
||||
.insert(messageTranslates)
|
||||
.values([{ id: '1', content: 'translated message 1', from: 'en', to: 'zh' }]);
|
||||
});
|
||||
|
||||
// 调用 updateTranslate 方法
|
||||
await messageModel.updateTranslate('1', { content: 'updated translated message 1' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTranslates)
|
||||
.where(eq(messageTranslates.id, '1'))
|
||||
.execute();
|
||||
expect(result[0].content).toBe('updated translated message 1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateTTS', () => {
|
||||
it('should insert a new record if message does not exist in messageTTS table', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
|
||||
// 调用 updateTTS 方法
|
||||
await messageModel.updateTTS('1', { contentMd5: 'md5', file: 'f1', voice: 'voice1' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTTS)
|
||||
.where(eq(messageTTS.id, '1'))
|
||||
.execute();
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].voice).toBe('voice1');
|
||||
});
|
||||
|
||||
it('should update the corresponding fields if message exists in messageTTS table', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx
|
||||
.insert(messages)
|
||||
.values([{ id: '1', userId, role: 'user', content: 'message 1' }]);
|
||||
await trx
|
||||
.insert(messageTTS)
|
||||
.values([{ id: '1', contentMd5: 'md5', fileId: 'f1', voice: 'voice1' }]);
|
||||
});
|
||||
|
||||
// 调用 updateTTS 方法
|
||||
await messageModel.updateTTS('1', { voice: 'updated voice1' });
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTTS)
|
||||
.where(eq(messageTTS.id, '1'))
|
||||
.execute();
|
||||
expect(result[0].voice).toBe('updated voice1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessageTranslate', () => {
|
||||
it('should delete the message translate record', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([{ id: '1', role: 'abc', userId }]);
|
||||
await serverDB.insert(messageTranslates).values([{ id: '1' }]);
|
||||
|
||||
// 调用 deleteMessageTranslate 方法
|
||||
await messageModel.deleteMessageTranslate('1');
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTranslates)
|
||||
.where(eq(messageTranslates.id, '1'))
|
||||
.execute();
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessageTTS', () => {
|
||||
it('should delete the message TTS record', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([{ id: '1', role: 'abc', userId }]);
|
||||
await serverDB.insert(messageTTS).values([{ id: '1' }]);
|
||||
|
||||
// 调用 deleteMessageTTS 方法
|
||||
await messageModel.deleteMessageTTS('1');
|
||||
|
||||
// 断言结果
|
||||
const result = await serverDB
|
||||
.select()
|
||||
.from(messageTTS)
|
||||
.where(eq(messageTTS.id, '1'))
|
||||
.execute();
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('count', () => {
|
||||
it('should return the count of messages belonging to the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', userId, role: 'user', content: 'message 1' },
|
||||
{ id: '2', userId, role: 'user', content: 'message 2' },
|
||||
{ id: '3', userId: '456', role: 'user', content: 'message 3' },
|
||||
]);
|
||||
|
||||
// 调用 count 方法
|
||||
const result = await messageModel.count();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('countToday', () => {
|
||||
it('should return the count of messages created today', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(messages).values([
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 1',
|
||||
createdAt: new Date(),
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 2',
|
||||
createdAt: new Date(),
|
||||
},
|
||||
{
|
||||
id: '3',
|
||||
userId,
|
||||
role: 'user',
|
||||
content: 'message 3',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
},
|
||||
]);
|
||||
|
||||
// 调用 countToday 方法
|
||||
const result = await messageModel.countToday();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBe(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
172
src/database/server/models/__tests__/plugin.test.ts
Normal file
172
src/database/server/models/__tests__/plugin.test.ts
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
// @vitest-environment node
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
|
||||
import { NewInstalledPlugin, installedPlugins, users } from '../../schemas/lobechat';
|
||||
import { PluginModel } from '../plugin';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'plugin-db';
|
||||
const pluginModel = new PluginModel(userId);
|
||||
|
||||
beforeEach(async () => {
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.delete(users);
|
||||
await trx.insert(users).values([{ id: userId }, { id: '456' }]);
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
});
|
||||
|
||||
describe('PluginModel', () => {
|
||||
describe('create', () => {
|
||||
it('should create a new installed plugin', async () => {
|
||||
const params = {
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin',
|
||||
manifest: { identifier: 'Test Plugin' },
|
||||
customParams: { manifestUrl: 'abc123' },
|
||||
} as NewInstalledPlugin;
|
||||
|
||||
const result = await pluginModel.create(params);
|
||||
|
||||
expect(result.userId).toBe(userId);
|
||||
expect(result.type).toBe(params.type);
|
||||
expect(result.identifier).toBe(params.identifier);
|
||||
expect(result.manifest).toEqual(params.manifest);
|
||||
expect(result.customParams).toEqual(params.customParams);
|
||||
});
|
||||
});
|
||||
|
||||
describe('delete', () => {
|
||||
it('should delete an installed plugin by identifier', async () => {
|
||||
await serverDB.insert(installedPlugins).values({
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin',
|
||||
manifest: { name: 'Test Plugin' },
|
||||
} as unknown as NewInstalledPlugin);
|
||||
|
||||
await pluginModel.delete('test-plugin');
|
||||
|
||||
const result = await serverDB.select().from(installedPlugins);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteAll', () => {
|
||||
it('should delete all installed plugins for the user', async () => {
|
||||
await serverDB.insert(installedPlugins).values([
|
||||
{
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-1',
|
||||
manifest: { name: 'Test Plugin 1' },
|
||||
},
|
||||
{
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-2',
|
||||
manifest: { name: 'Test Plugin 2' },
|
||||
},
|
||||
{
|
||||
userId: '456',
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-3',
|
||||
manifest: { name: 'Test Plugin 3' },
|
||||
},
|
||||
] as unknown as NewInstalledPlugin[]);
|
||||
|
||||
await pluginModel.deleteAll();
|
||||
|
||||
const result = await serverDB.select().from(installedPlugins);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].userId).toBe('456');
|
||||
});
|
||||
});
|
||||
|
||||
describe('query', () => {
|
||||
it('should query installed plugins for the user', async () => {
|
||||
await serverDB.insert(installedPlugins).values([
|
||||
{
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-1',
|
||||
manifest: { name: 'Test Plugin 1' },
|
||||
createdAt: new Date('2023-01-01'),
|
||||
},
|
||||
{
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-2',
|
||||
manifest: { name: 'Test Plugin 2' },
|
||||
createdAt: new Date('2023-02-01'),
|
||||
},
|
||||
{
|
||||
userId: '456',
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-3',
|
||||
manifest: { name: 'Test Plugin 3' },
|
||||
createdAt: new Date('2023-03-01'),
|
||||
},
|
||||
] as unknown as NewInstalledPlugin[]);
|
||||
|
||||
const result = await pluginModel.query();
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].identifier).toBe('test-plugin-2');
|
||||
expect(result[1].identifier).toBe('test-plugin-1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('findById', () => {
|
||||
it('should find an installed plugin by identifier', async () => {
|
||||
await serverDB.insert(installedPlugins).values([
|
||||
{
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-1',
|
||||
manifest: { name: 'Test Plugin 1' },
|
||||
},
|
||||
{
|
||||
userId: '456',
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin-2',
|
||||
manifest: { name: 'Test Plugin 2' },
|
||||
},
|
||||
] as unknown as NewInstalledPlugin[]);
|
||||
|
||||
const result = await pluginModel.findById('test-plugin-1');
|
||||
|
||||
expect(result?.userId).toBe(userId);
|
||||
expect(result?.identifier).toBe('test-plugin-1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('update', () => {
|
||||
it('should update an installed plugin', async () => {
|
||||
await serverDB.insert(installedPlugins).values({
|
||||
userId,
|
||||
type: 'plugin',
|
||||
identifier: 'test-plugin',
|
||||
manifest: {},
|
||||
settings: { enabled: true },
|
||||
} as unknown as NewInstalledPlugin);
|
||||
|
||||
await pluginModel.update('test-plugin', { settings: { enabled: false } });
|
||||
|
||||
const result = await pluginModel.findById('test-plugin');
|
||||
expect(result?.settings).toEqual({ enabled: false });
|
||||
});
|
||||
});
|
||||
});
|
||||
595
src/database/server/models/__tests__/session.test.ts
Normal file
595
src/database/server/models/__tests__/session.test.ts
Normal file
|
|
@ -0,0 +1,595 @@
|
|||
import { eq, inArray } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
|
||||
import {
|
||||
NewSession,
|
||||
SessionItem,
|
||||
agents,
|
||||
agentsToSessions,
|
||||
messages,
|
||||
plugins,
|
||||
sessionGroups,
|
||||
sessions,
|
||||
topics,
|
||||
users,
|
||||
} from '../../schemas/lobechat';
|
||||
import { idGenerator } from '../../utils/idGenerator';
|
||||
import { SessionModel } from '../session';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'session-user';
|
||||
const sessionModel = new SessionModel(userId);
|
||||
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(plugins);
|
||||
await serverDB.delete(users);
|
||||
// 并创建初始用户
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// 在每个测试用例之后, 清空用户表 (应该会自动级联删除所有数据)
|
||||
await serverDB.delete(users);
|
||||
});
|
||||
|
||||
describe('SessionModel', () => {
|
||||
describe('query', () => {
|
||||
it('should query sessions by user ID', async () => {
|
||||
// 创建一些测试数据
|
||||
await serverDB.insert(users).values([{ id: '456' }]);
|
||||
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId, updatedAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId, updatedAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId: '456', updatedAt: new Date('2023-03-01') },
|
||||
]);
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await sessionModel.query();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('2');
|
||||
expect(result[1].id).toBe('1');
|
||||
});
|
||||
|
||||
it('should query sessions with pagination', async () => {
|
||||
// create test data
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId, updatedAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId, updatedAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId, updatedAt: new Date('2023-03-01') },
|
||||
]);
|
||||
|
||||
// should return 2 sessions
|
||||
const result1 = await sessionModel.query({ current: 0, pageSize: 2 });
|
||||
expect(result1).toHaveLength(2);
|
||||
|
||||
// should return only 1 session and it's the 2nd one
|
||||
const result2 = await sessionModel.query({ current: 1, pageSize: 1 });
|
||||
expect(result2).toHaveLength(1);
|
||||
expect(result2[0].id).toBe('2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryWithGroups', () => {
|
||||
it('should return sessions grouped by group', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.insert(users).values([{ id: '456' }]);
|
||||
await trx.insert(sessionGroups).values([
|
||||
{ userId, name: 'Group 1', id: 'group1' },
|
||||
{ userId, name: 'Group 2', id: 'group2' },
|
||||
]);
|
||||
await trx.insert(sessions).values([
|
||||
{ id: '1', userId, groupId: 'group1' },
|
||||
{ id: '2', userId, groupId: 'group1' },
|
||||
{ id: '23', userId, groupId: 'group1', pinned: true },
|
||||
{ id: '3', userId, groupId: 'group2' },
|
||||
{ id: '4', userId },
|
||||
{ id: '5', userId, pinned: true },
|
||||
{ id: '7', userId: '456' },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 queryWithGroups 方法
|
||||
const result = await sessionModel.queryWithGroups();
|
||||
|
||||
// 断言结果
|
||||
expect(result.sessions).toHaveLength(6);
|
||||
expect(result.sessionGroups).toHaveLength(2);
|
||||
expect(result.sessionGroups[0].id).toBe('group1');
|
||||
expect(result.sessionGroups[0].name).toBe('Group 1');
|
||||
|
||||
expect(result.sessionGroups[1].id).toBe('group2');
|
||||
});
|
||||
|
||||
it('should return empty groups if no sessions', async () => {
|
||||
// 调用 queryWithGroups 方法
|
||||
const result = await sessionModel.queryWithGroups();
|
||||
|
||||
// 断言结果
|
||||
expect(result.sessions).toHaveLength(0);
|
||||
expect(result.sessionGroups).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('findById', () => {
|
||||
it('should find session by ID', async () => {
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
]);
|
||||
|
||||
const result = await sessionModel.findByIdOrSlug('1');
|
||||
expect(result?.id).toBe('1');
|
||||
});
|
||||
|
||||
it('should return undefined if session not found', async () => {
|
||||
await serverDB.insert(sessions).values([{ id: '1', userId }]);
|
||||
|
||||
const result = await sessionModel.findByIdOrSlug('2');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should find with agents', async () => {
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
]);
|
||||
await trx.insert(agents).values([
|
||||
{ id: 'a1', title: 'Agent1', userId },
|
||||
{ id: 'a2', title: 'Agent2', userId },
|
||||
]);
|
||||
|
||||
// @ts-ignore
|
||||
await trx.insert(agentsToSessions).values([
|
||||
{ sessionId: '1', agentId: 'a1', userId },
|
||||
{ sessionId: '2', agentId: 'a2', userId },
|
||||
]);
|
||||
});
|
||||
|
||||
const result = await sessionModel.findByIdOrSlug('2');
|
||||
|
||||
expect(result?.agent).toBeDefined();
|
||||
expect(result?.agent.id).toEqual('a2');
|
||||
});
|
||||
});
|
||||
|
||||
// describe('getAgentConfigById', () => {
|
||||
// it('should return agent config by id', async () => {
|
||||
// await serverDB.transaction(async (trx) => {
|
||||
// await trx.insert(agents).values([
|
||||
// { id: '1', userId, model: 'gpt-3.5-turbo' },
|
||||
// { id: '2', userId, model: 'gpt-3.5' },
|
||||
// ]);
|
||||
//
|
||||
// // @ts-ignore
|
||||
// await trx.insert(plugins).values([
|
||||
// { id: 1, userId, identifier: 'abc', title: 'A1', locale: 'en-US', manifest: {} },
|
||||
// { id: 2, userId, identifier: 'b2', title: 'A2', locale: 'en-US', manifest: {} },
|
||||
// ]);
|
||||
//
|
||||
// await trx.insert(agentsPlugins).values([
|
||||
// { agentId: '1', pluginId: 1 },
|
||||
// { agentId: '2', pluginId: 2 },
|
||||
// { agentId: '1', pluginId: 2 },
|
||||
// ]);
|
||||
// });
|
||||
//
|
||||
// const result = await sessionModel.getAgentConfigById('1');
|
||||
//
|
||||
// expect(result?.id).toBe('1');
|
||||
// expect(result?.plugins).toBe(['abc', 'b2']);
|
||||
// expect(result?.model).toEqual('gpt-3.5-turbo');
|
||||
// expect(result?.chatConfig).toBeDefined();
|
||||
// });
|
||||
// });
|
||||
describe('count', () => {
|
||||
it('should return the count of sessions for the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(users).values([{ id: '456' }]);
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
{ id: '3', userId: '456' },
|
||||
]);
|
||||
|
||||
// 调用 count 方法
|
||||
const result = await sessionModel.count();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBe(2);
|
||||
});
|
||||
|
||||
it('should return 0 if no sessions exist for the user', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(users).values([{ id: '456' }]);
|
||||
await serverDB.insert(sessions).values([{ id: '3', userId: '456' }]);
|
||||
|
||||
// 调用 count 方法
|
||||
const result = await sessionModel.count();
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryByKeyword', () => {
|
||||
it('should return an empty array if keyword is empty', async () => {
|
||||
const result = await sessionModel.queryByKeyword('');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return sessions with matching title', async () => {
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId, title: 'Hello World', description: 'Some description' },
|
||||
{ id: '2', userId, title: 'Another Session', description: 'Another description' },
|
||||
]);
|
||||
|
||||
const result = await sessionModel.queryByKeyword('hello');
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('1');
|
||||
});
|
||||
|
||||
it('should return sessions with matching description', async () => {
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId, title: 'Session 1', description: 'Description with keyword' },
|
||||
{ id: '2', userId, title: 'Session 2', description: 'Another description' },
|
||||
]);
|
||||
|
||||
const result = await sessionModel.queryByKeyword('keyword');
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('1');
|
||||
});
|
||||
|
||||
it('should return sessions with matching title or description', async () => {
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId, title: 'Title with keyword', description: 'Some description' },
|
||||
{ id: '2', userId, title: 'Another Session', description: 'Description with keyword' },
|
||||
{ id: '3', userId, title: 'Third Session', description: 'Third description' },
|
||||
]);
|
||||
|
||||
const result = await sessionModel.queryByKeyword('keyword');
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result.map((s) => s.id)).toEqual(['1', '2']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('create', () => {
|
||||
it('should create a new session', async () => {
|
||||
// 调用 create 方法
|
||||
const result = await sessionModel.create({
|
||||
type: 'agent',
|
||||
session: {
|
||||
title: 'New Session',
|
||||
},
|
||||
config: { model: 'gpt-3.5-turbo' },
|
||||
});
|
||||
|
||||
// 断言结果
|
||||
const sessionId = result.id;
|
||||
expect(sessionId).toBeDefined();
|
||||
expect(sessionId.startsWith('ssn_')).toBeTruthy();
|
||||
expect(result.userId).toBe(userId);
|
||||
expect(result.type).toBe('agent');
|
||||
|
||||
const session = await sessionModel.findByIdOrSlug(sessionId);
|
||||
expect(session).toBeDefined();
|
||||
expect(session?.title).toEqual('New Session');
|
||||
expect(session?.pinned).toBe(false);
|
||||
expect(session?.agent?.model).toEqual('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
it('should create a new session with custom ID', async () => {
|
||||
// 调用 create 方法,传入自定义 ID
|
||||
const customId = 'custom-id';
|
||||
const result = await sessionModel.create({
|
||||
type: 'agent',
|
||||
config: { model: 'gpt-3.5-turbo' },
|
||||
session: { title: 'New Session' },
|
||||
id: customId,
|
||||
});
|
||||
|
||||
// 断言结果
|
||||
expect(result.id).toBe(customId);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skip('batchCreate', () => {
|
||||
it('should batch create sessions', async () => {
|
||||
// 调用 batchCreate 方法
|
||||
const sessions: NewSession[] = [
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
type: 'agent',
|
||||
// config: { model: 'gpt-3.5-turbo' },
|
||||
title: 'Session 1',
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
userId,
|
||||
type: 'agent',
|
||||
// config: { model: 'gpt-4' },
|
||||
title: 'Session 2',
|
||||
},
|
||||
];
|
||||
const result = await sessionModel.batchCreate(sessions);
|
||||
|
||||
// 断言结果
|
||||
expect(result.rowCount).toEqual(2);
|
||||
});
|
||||
|
||||
it.skip('should set group to default if group does not exist', async () => {
|
||||
// 调用 batchCreate 方法,传入不存在的 group
|
||||
const sessions: NewSession[] = [
|
||||
{
|
||||
id: '1',
|
||||
userId,
|
||||
type: 'agent',
|
||||
// config: { model: 'gpt-3.5-turbo' },
|
||||
title: 'Session 1',
|
||||
groupId: 'non-existent-group',
|
||||
},
|
||||
];
|
||||
const result = await sessionModel.batchCreate(sessions);
|
||||
|
||||
// 断言结果
|
||||
// expect(result[0].group).toBe('default');
|
||||
});
|
||||
});
|
||||
|
||||
describe('duplicate', () => {
|
||||
it.skip('should duplicate a session', async () => {
|
||||
// 创建一个用户和一个 session
|
||||
await serverDB.transaction(async (trx) => {
|
||||
await trx
|
||||
.insert(sessions)
|
||||
.values({ id: '1', userId, type: 'agent', title: 'Original Session', pinned: true });
|
||||
await trx.insert(agents).values({ id: 'agent-1', userId, model: 'gpt-3.5-turbo' });
|
||||
await trx.insert(agentsToSessions).values({ agentId: 'agent-1', sessionId: '1' });
|
||||
});
|
||||
|
||||
// 调用 duplicate 方法
|
||||
const result = (await sessionModel.duplicate('1', 'Duplicated Session')) as SessionItem;
|
||||
|
||||
// 断言结果
|
||||
expect(result.id).not.toBe('1');
|
||||
expect(result.userId).toBe(userId);
|
||||
expect(result.type).toBe('agent');
|
||||
|
||||
const session = await sessionModel.findByIdOrSlug(result.id);
|
||||
|
||||
expect(session).toBeDefined();
|
||||
expect(session?.title).toEqual('Duplicated Session');
|
||||
expect(session?.pinned).toBe(true);
|
||||
expect(session?.agent?.model).toEqual('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
it('should return undefined if session does not exist', async () => {
|
||||
// 调用 duplicate 方法,传入不存在的 session ID
|
||||
const result = await sessionModel.duplicate('non-existent-id');
|
||||
|
||||
// 断言结果
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('update', () => {
|
||||
it('should update a session', async () => {
|
||||
// 创建一个测试 session
|
||||
const sessionId = '123';
|
||||
await serverDB.insert(sessions).values({ userId, id: sessionId, title: 'Test Session' });
|
||||
|
||||
// 调用 update 方法更新 session
|
||||
const updatedSessions = await sessionModel.update(sessionId, {
|
||||
title: 'Updated Test Session',
|
||||
description: 'This is an updated test session',
|
||||
});
|
||||
|
||||
// 断言更新后的结果
|
||||
expect(updatedSessions).toHaveLength(1);
|
||||
expect(updatedSessions[0].title).toBe('Updated Test Session');
|
||||
expect(updatedSessions[0].description).toBe('This is an updated test session');
|
||||
});
|
||||
|
||||
it('should not update a session if user ID does not match', async () => {
|
||||
// 创建一个测试 session,但使用不同的 user ID
|
||||
await serverDB.insert(users).values([{ id: '777' }]);
|
||||
|
||||
const sessionId = '123';
|
||||
|
||||
await serverDB
|
||||
.insert(sessions)
|
||||
.values({ userId: '777', id: sessionId, title: 'Test Session' });
|
||||
|
||||
// 尝试更新这个 session,应该不会有任何更新
|
||||
const updatedSessions = await sessionModel.update(sessionId, {
|
||||
title: 'Updated Test Session',
|
||||
});
|
||||
|
||||
expect(updatedSessions).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('delete', () => {
|
||||
it('should handle deleting a session with no associated messages or topics', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(sessions).values({ id: '1', userId });
|
||||
|
||||
// 调用 delete 方法
|
||||
await sessionModel.delete('1');
|
||||
|
||||
// 断言删除结果
|
||||
const result = await serverDB.select({ id: sessions.id }).from(sessions);
|
||||
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle concurrent deletions gracefully', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(sessions).values({ id: '1', userId });
|
||||
|
||||
// 并发调用 delete 方法
|
||||
await Promise.all([sessionModel.delete('1'), sessionModel.delete('1')]);
|
||||
|
||||
// 断言删除结果
|
||||
const result = await serverDB.select({ id: sessions.id }).from(sessions);
|
||||
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should delete a session and its associated topics and messages', async () => {
|
||||
// Create a session
|
||||
const sessionId = '1';
|
||||
await serverDB.insert(sessions).values({ id: sessionId, userId });
|
||||
|
||||
// Create some topics and messages associated with the session
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: '1', sessionId, userId },
|
||||
{ id: '2', sessionId, userId },
|
||||
]);
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', sessionId, userId, role: 'user' },
|
||||
{ id: '2', sessionId, userId, role: 'assistant' },
|
||||
]);
|
||||
|
||||
// Delete the session
|
||||
await sessionModel.delete(sessionId);
|
||||
|
||||
// Check that the session, topics, and messages are deleted
|
||||
expect(await serverDB.select().from(sessions).where(eq(sessions.id, sessionId))).toHaveLength(
|
||||
0,
|
||||
);
|
||||
expect(
|
||||
await serverDB.select().from(topics).where(eq(topics.sessionId, sessionId)),
|
||||
).toHaveLength(0);
|
||||
expect(
|
||||
await serverDB.select().from(messages).where(eq(messages.sessionId, sessionId)),
|
||||
).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not delete sessions belonging to other users', async () => {
|
||||
// Create two users
|
||||
const anotherUserId = idGenerator('user');
|
||||
await serverDB.insert(users).values({ id: anotherUserId });
|
||||
|
||||
// Create a session for each user
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId: anotherUserId },
|
||||
]);
|
||||
|
||||
// Delete the session belonging to the current user
|
||||
await sessionModel.delete('1');
|
||||
|
||||
// Check that only the session belonging to the current user is deleted
|
||||
expect(await serverDB.select().from(sessions).where(eq(sessions.id, '1'))).toHaveLength(0);
|
||||
expect(await serverDB.select().from(sessions).where(eq(sessions.id, '2'))).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batchDelete', () => {
|
||||
it('should handle deleting sessions with no associated messages or topics', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
]);
|
||||
|
||||
// 调用 batchDelete 方法
|
||||
await sessionModel.batchDelete(['1', '2']);
|
||||
|
||||
// 断言删除结果
|
||||
const result = await serverDB.select({ id: sessions.id }).from(sessions);
|
||||
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle concurrent batch deletions gracefully', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
]);
|
||||
|
||||
// 并发调用 batchDelete 方法
|
||||
await Promise.all([
|
||||
sessionModel.batchDelete(['1', '2']),
|
||||
sessionModel.batchDelete(['1', '2']),
|
||||
]);
|
||||
|
||||
// 断言删除结果
|
||||
const result = await serverDB.select({ id: sessions.id }).from(sessions);
|
||||
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should delete multiple sessions and their associated topics and messages', async () => {
|
||||
// Create some sessions
|
||||
const sessionIds = ['1', '2', '3'];
|
||||
await serverDB.insert(sessions).values(sessionIds.map((id) => ({ id, userId })));
|
||||
|
||||
// Create some topics and messages associated with the sessions
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: '1', sessionId: '1', userId },
|
||||
{ id: '2', sessionId: '2', userId },
|
||||
{ id: '3', sessionId: '3', userId },
|
||||
]);
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: '1', sessionId: '1', userId, role: 'user' },
|
||||
{ id: '2', sessionId: '2', userId, role: 'assistant' },
|
||||
{ id: '3', sessionId: '3', userId, role: 'user' },
|
||||
]);
|
||||
|
||||
// Delete the sessions
|
||||
await sessionModel.batchDelete(sessionIds);
|
||||
|
||||
// Check that the sessions, topics, and messages are deleted
|
||||
expect(
|
||||
await serverDB.select().from(sessions).where(inArray(sessions.id, sessionIds)),
|
||||
).toHaveLength(0);
|
||||
expect(
|
||||
await serverDB.select().from(topics).where(inArray(topics.sessionId, sessionIds)),
|
||||
).toHaveLength(0);
|
||||
expect(
|
||||
await serverDB.select().from(messages).where(inArray(messages.sessionId, sessionIds)),
|
||||
).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not delete sessions belonging to other users', async () => {
|
||||
// Create two users
|
||||
await serverDB.insert(users).values([{ id: '456' }]);
|
||||
|
||||
// Create some sessions for each user
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: '1', userId },
|
||||
{ id: '2', userId },
|
||||
{ id: '3', userId: '456' },
|
||||
]);
|
||||
|
||||
// Delete the sessions belonging to the current user
|
||||
await sessionModel.batchDelete(['1', '2']);
|
||||
|
||||
// Check that only the sessions belonging to the current user are deleted
|
||||
expect(
|
||||
await serverDB
|
||||
.select()
|
||||
.from(sessions)
|
||||
.where(inArray(sessions.id, ['1', '2'])),
|
||||
).toHaveLength(0);
|
||||
expect(await serverDB.select().from(sessions).where(eq(sessions.id, '3'))).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
623
src/database/server/models/__tests__/topic.test.ts
Normal file
623
src/database/server/models/__tests__/topic.test.ts
Normal file
|
|
@ -0,0 +1,623 @@
|
|||
import { eq, inArray } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
|
||||
import { messages, sessions, topics, users } from '../../schemas/lobechat';
|
||||
import { CreateTopicParams, TopicModel } from '../topic';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'topic-user-test';
|
||||
const sessionId = 'topic-session';
|
||||
const topicModel = new TopicModel(userId);
|
||||
|
||||
describe('TopicModel', () => {
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(users).values({ id: userId });
|
||||
await tx.insert(sessions).values({ id: sessionId, userId });
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// 在每个测试用例之后,清空表
|
||||
await serverDB.delete(users);
|
||||
});
|
||||
|
||||
describe('query', () => {
|
||||
it('should query topics by user ID', async () => {
|
||||
// 创建一些测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(users).values([{ id: '456' }]);
|
||||
|
||||
await tx.insert(topics).values([
|
||||
{ id: '1', userId, sessionId, updatedAt: new Date('2023-01-01') },
|
||||
{ id: '4', userId, sessionId, updatedAt: new Date('2023-03-01') },
|
||||
{ id: '2', userId, sessionId, updatedAt: new Date('2023-02-01'), favorite: true },
|
||||
{ id: '5', userId, sessionId, updatedAt: new Date('2023-05-01'), favorite: true },
|
||||
{ id: '3', userId: '456', sessionId, updatedAt: new Date('2023-03-01') },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 query 方法
|
||||
const result = await topicModel.query({ sessionId });
|
||||
|
||||
// 断言结果
|
||||
expect(result).toHaveLength(4);
|
||||
expect(result[0].id).toBe('5'); // favorite 的 topic 应该在前面,按照 updatedAt 降序排序
|
||||
expect(result[1].id).toBe('2');
|
||||
expect(result[2].id).toBe('4'); // 按照 updatedAt 降序排序
|
||||
});
|
||||
|
||||
it('should query topics with pagination', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: '1', userId, updatedAt: new Date('2023-01-01') },
|
||||
{ id: '2', userId, updatedAt: new Date('2023-02-01') },
|
||||
{ id: '3', userId, updatedAt: new Date('2023-03-01') },
|
||||
]);
|
||||
|
||||
// 应该返回 2 个 topics
|
||||
const result1 = await topicModel.query({ current: 0, pageSize: 2 });
|
||||
expect(result1).toHaveLength(2);
|
||||
|
||||
// 应该只返回 1 个 topic,并且是第 2 个
|
||||
const result2 = await topicModel.query({ current: 1, pageSize: 1 });
|
||||
expect(result2).toHaveLength(1);
|
||||
expect(result2[0].id).toBe('2');
|
||||
});
|
||||
|
||||
it('should query topics by session ID', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId },
|
||||
]);
|
||||
|
||||
await tx.insert(topics).values([
|
||||
{ id: '1', userId, sessionId: 'session1' },
|
||||
{ id: '2', userId, sessionId: 'session2' },
|
||||
{ id: '3', userId }, // 没有 sessionId
|
||||
]);
|
||||
});
|
||||
|
||||
// 应该只返回属于 session1 的 topic
|
||||
const result = await topicModel.query({ sessionId: 'session1' });
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('1');
|
||||
});
|
||||
|
||||
it('should return topics based on pagination parameters', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', sessionId, userId, updatedAt: new Date('2023-01-01') },
|
||||
{ id: 'topic2', sessionId, userId, updatedAt: new Date('2023-01-02') },
|
||||
{ id: 'topic3', sessionId, userId, updatedAt: new Date('2023-01-03') },
|
||||
]);
|
||||
|
||||
// 调用 query 方法
|
||||
const result1 = await topicModel.query({ current: 0, pageSize: 2, sessionId });
|
||||
const result2 = await topicModel.query({ current: 1, pageSize: 2, sessionId });
|
||||
|
||||
// 断言返回结果符合分页要求
|
||||
expect(result1).toHaveLength(2);
|
||||
expect(result1[0].id).toBe('topic3');
|
||||
expect(result1[1].id).toBe('topic2');
|
||||
|
||||
expect(result2).toHaveLength(1);
|
||||
expect(result2[0].id).toBe('topic1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('findById', () => {
|
||||
it('should return a topic by id', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values({ id: 'topic1', sessionId, userId });
|
||||
|
||||
// 调用 findById 方法
|
||||
const result = await topicModel.findById('topic1');
|
||||
|
||||
// 断言返回结果符合预期
|
||||
expect(result?.id).toBe('topic1');
|
||||
});
|
||||
|
||||
it('should return undefined for non-existent topic', async () => {
|
||||
// 调用 findById 方法
|
||||
const result = await topicModel.findById('non-existent');
|
||||
|
||||
// 断言返回 undefined
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryAll', () => {
|
||||
it('should return all topics', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', sessionId, userId },
|
||||
{ id: 'topic2', sessionId, userId },
|
||||
]);
|
||||
|
||||
// 调用 queryAll 方法
|
||||
const result = await topicModel.queryAll();
|
||||
|
||||
// 断言返回所有的 topics
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0].id).toBe('topic1');
|
||||
expect(result[1].id).toBe('topic2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('queryByKeyword', () => {
|
||||
it('should return topics matching topic title keyword', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(topics).values([
|
||||
{ id: 'topic1', title: 'Hello world', sessionId, userId },
|
||||
{ id: 'topic2', title: 'Goodbye', sessionId, userId },
|
||||
]);
|
||||
await tx
|
||||
.insert(messages)
|
||||
.values([
|
||||
{ id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId },
|
||||
]);
|
||||
});
|
||||
// 调用 queryByKeyword 方法
|
||||
const result = await topicModel.queryByKeyword('hello', sessionId);
|
||||
|
||||
// 断言返回匹配关键字的 topic
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('topic1');
|
||||
});
|
||||
|
||||
it('should return topics matching message content keyword', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(topics).values([
|
||||
{ id: 'topic1', title: 'abc world', sessionId, userId },
|
||||
{ id: 'topic2', title: 'Goodbye', sessionId, userId },
|
||||
]);
|
||||
await tx.insert(messages).values([
|
||||
{
|
||||
id: 'message1',
|
||||
role: 'assistant',
|
||||
content: 'Hello there',
|
||||
topicId: 'topic1',
|
||||
userId,
|
||||
},
|
||||
]);
|
||||
});
|
||||
// 调用 queryByKeyword 方法
|
||||
const result = await topicModel.queryByKeyword('hello', sessionId);
|
||||
|
||||
// 断言返回匹配关键字的 topic
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('topic1');
|
||||
});
|
||||
|
||||
it('should return nothing if not match', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', title: 'Hello world', userId },
|
||||
{ id: 'topic2', title: 'Goodbye', sessionId, userId },
|
||||
]);
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values([
|
||||
{ id: 'message1', role: 'assistant', content: 'abc there', topicId: 'topic1', userId },
|
||||
]);
|
||||
|
||||
// 调用 queryByKeyword 方法
|
||||
const result = await topicModel.queryByKeyword('hello', sessionId);
|
||||
|
||||
// 断言返回匹配关键字的 topic
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('count', () => {
|
||||
it('should return total number of topics', async () => {
|
||||
// 创建测试数据
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'abc_topic1', sessionId, userId },
|
||||
{ id: 'abc_topic2', sessionId, userId },
|
||||
]);
|
||||
|
||||
// 调用 count 方法
|
||||
const result = await topicModel.count();
|
||||
|
||||
// 断言返回 topics 总数
|
||||
expect(result).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('delete', () => {
|
||||
it('should delete a topic and its associated messages', async () => {
|
||||
const topicId = 'topic1';
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(users).values({ id: '345' });
|
||||
await tx.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId: '345' },
|
||||
]);
|
||||
await tx.insert(topics).values([
|
||||
{ id: topicId, sessionId: 'session1', userId },
|
||||
{ id: 'topic2', sessionId: 'session2', userId: '345' },
|
||||
]);
|
||||
await tx.insert(messages).values([
|
||||
{ id: 'message1', role: 'user', topicId: topicId, userId },
|
||||
{ id: 'message2', role: 'assistant', topicId: topicId, userId },
|
||||
{ id: 'message3', role: 'user', topicId: 'topic2', userId: '345' },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 delete 方法
|
||||
await topicModel.delete(topicId);
|
||||
|
||||
// 断言 topic 和关联的 messages 都被删除了
|
||||
expect(
|
||||
await serverDB.select().from(messages).where(eq(messages.topicId, topicId)),
|
||||
).toHaveLength(0);
|
||||
expect(await serverDB.select().from(topics)).toHaveLength(1);
|
||||
|
||||
expect(await serverDB.select().from(messages)).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batchDeleteBySessionId', () => {
|
||||
it('should delete all topics associated with a session', async () => {
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId },
|
||||
]);
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', sessionId: 'session1', userId },
|
||||
{ id: 'topic2', sessionId: 'session1', userId },
|
||||
{ id: 'topic3', sessionId: 'session2', userId },
|
||||
{ id: 'topic4', userId },
|
||||
]);
|
||||
|
||||
// 调用 batchDeleteBySessionId 方法
|
||||
await topicModel.batchDeleteBySessionId('session1');
|
||||
|
||||
// 断言属于 session1 的 topics 都被删除了
|
||||
expect(
|
||||
await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')),
|
||||
).toHaveLength(0);
|
||||
expect(await serverDB.select().from(topics)).toHaveLength(2);
|
||||
});
|
||||
it('should delete all topics associated without sessionId', async () => {
|
||||
await serverDB.insert(sessions).values([{ id: 'session1', userId }]);
|
||||
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', sessionId: 'session1', userId },
|
||||
{ id: 'topic2', sessionId: 'session1', userId },
|
||||
{ id: 'topic4', userId },
|
||||
]);
|
||||
|
||||
// 调用 batchDeleteBySessionId 方法
|
||||
await topicModel.batchDeleteBySessionId();
|
||||
|
||||
// 断言属于 session1 的 topics 都被删除了
|
||||
expect(
|
||||
await serverDB.select().from(topics).where(eq(topics.sessionId, 'session1')),
|
||||
).toHaveLength(2);
|
||||
expect(await serverDB.select().from(topics)).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batchDelete', () => {
|
||||
it('should delete multiple topics and their associated messages', async () => {
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(sessions).values({ id: 'session1', userId });
|
||||
await tx.insert(topics).values([
|
||||
{ id: 'topic1', sessionId: 'session1', userId },
|
||||
{ id: 'topic2', sessionId: 'session1', userId },
|
||||
{ id: 'topic3', sessionId: 'session1', userId },
|
||||
]);
|
||||
await tx.insert(messages).values([
|
||||
{ id: 'message1', role: 'user', topicId: 'topic1', userId },
|
||||
{ id: 'message2', role: 'assistant', topicId: 'topic2', userId },
|
||||
{ id: 'message3', role: 'user', topicId: 'topic3', userId },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 batchDelete 方法
|
||||
await topicModel.batchDelete(['topic1', 'topic2']);
|
||||
|
||||
// 断言指定的 topics 和关联的 messages 都被删除了
|
||||
expect(await serverDB.select().from(topics)).toHaveLength(1);
|
||||
expect(await serverDB.select().from(messages)).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteAll', () => {
|
||||
it('should delete all topics of the user', async () => {
|
||||
await serverDB.insert(users).values({ id: '345' });
|
||||
await serverDB.insert(sessions).values([
|
||||
{ id: 'session1', userId },
|
||||
{ id: 'session2', userId: '345' },
|
||||
]);
|
||||
await serverDB.insert(topics).values([
|
||||
{ id: 'topic1', sessionId: 'session1', userId },
|
||||
{ id: 'topic2', sessionId: 'session1', userId },
|
||||
{ id: 'topic3', sessionId: 'session2', userId: '345' },
|
||||
]);
|
||||
|
||||
// 调用 deleteAll 方法
|
||||
await topicModel.deleteAll();
|
||||
|
||||
// 断言当前用户的所有 topics 都被删除了
|
||||
expect(await serverDB.select().from(topics).where(eq(topics.userId, userId))).toHaveLength(0);
|
||||
expect(await serverDB.select().from(topics)).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('update', () => {
|
||||
it('should update a topic', async () => {
|
||||
// 创建一个测试 session
|
||||
const topicId = '123';
|
||||
await serverDB.insert(topics).values({ userId, id: topicId, title: 'Test', favorite: true });
|
||||
|
||||
// 调用 update 方法更新 session
|
||||
const item = await topicModel.update(topicId, {
|
||||
title: 'Updated Test',
|
||||
favorite: false,
|
||||
});
|
||||
|
||||
// 断言更新后的结果
|
||||
expect(item).toHaveLength(1);
|
||||
expect(item[0].title).toBe('Updated Test');
|
||||
expect(item[0].favorite).toBeFalsy();
|
||||
});
|
||||
|
||||
it('should not update a topic if user ID does not match', async () => {
|
||||
// 创建一个测试 topic, 但使用不同的 user ID
|
||||
await serverDB.insert(users).values([{ id: '456' }]);
|
||||
const topicId = '123';
|
||||
await serverDB
|
||||
.insert(topics)
|
||||
.values({ userId: '456', id: topicId, title: 'Test', favorite: true });
|
||||
|
||||
// 尝试更新这个 topic , 应该不会有任何更新
|
||||
const item = await topicModel.update(topicId, {
|
||||
title: 'Updated Test Session',
|
||||
});
|
||||
|
||||
expect(item).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('create', () => {
|
||||
it('should create a new topic and associate messages', async () => {
|
||||
const topicData = {
|
||||
title: 'New Topic',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
messages: ['message1', 'message2'],
|
||||
} satisfies CreateTopicParams;
|
||||
|
||||
const topicId = 'new-topic';
|
||||
|
||||
// 预先创建一些 messages
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: 'message1', role: 'user', userId, sessionId },
|
||||
{ id: 'message2', role: 'assistant', userId, sessionId },
|
||||
{ id: 'message3', role: 'user', userId, sessionId },
|
||||
]);
|
||||
|
||||
// 调用 create 方法
|
||||
const createdTopic = await topicModel.create(topicData, topicId);
|
||||
|
||||
// 断言返回的 topic 数据正确
|
||||
expect(createdTopic).toEqual({
|
||||
id: topicId,
|
||||
title: 'New Topic',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
userId,
|
||||
clientId: null,
|
||||
createdAt: expect.any(Date),
|
||||
updatedAt: expect.any(Date),
|
||||
});
|
||||
|
||||
// 断言 topic 已在数据库中创建
|
||||
const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
|
||||
expect(dbTopic).toHaveLength(1);
|
||||
expect(dbTopic[0]).toEqual(createdTopic);
|
||||
|
||||
// 断言关联的 messages 的 topicId 已更新
|
||||
const associatedMessages = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(inArray(messages.id, topicData.messages!));
|
||||
expect(associatedMessages).toHaveLength(2);
|
||||
expect(associatedMessages.every((msg) => msg.topicId === topicId)).toBe(true);
|
||||
|
||||
// 断言未关联的 message 的 topicId 没有更新
|
||||
const unassociatedMessage = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.id, 'message3'));
|
||||
|
||||
expect(unassociatedMessage[0].topicId).toBeNull();
|
||||
});
|
||||
|
||||
it('should create a new topic without associating messages', async () => {
|
||||
const topicData = {
|
||||
title: 'New Topic',
|
||||
favorite: false,
|
||||
sessionId,
|
||||
};
|
||||
|
||||
const topicId = 'new-topic';
|
||||
|
||||
// 调用 create 方法
|
||||
const createdTopic = await topicModel.create(topicData, topicId);
|
||||
|
||||
// 断言返回的 topic 数据正确
|
||||
expect(createdTopic).toEqual({
|
||||
id: topicId,
|
||||
title: 'New Topic',
|
||||
favorite: false,
|
||||
clientId: null,
|
||||
sessionId,
|
||||
userId,
|
||||
createdAt: expect.any(Date),
|
||||
updatedAt: expect.any(Date),
|
||||
});
|
||||
|
||||
// 断言 topic 已在数据库中创建
|
||||
const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
|
||||
expect(dbTopic).toHaveLength(1);
|
||||
expect(dbTopic[0]).toEqual(createdTopic);
|
||||
});
|
||||
});
|
||||
|
||||
describe('batchCreate', () => {
|
||||
it('should batch create topics and update associated messages', async () => {
|
||||
// 准备测试数据
|
||||
const topicParams = [
|
||||
{
|
||||
title: 'Topic 1',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
messages: ['message1', 'message2'],
|
||||
},
|
||||
{
|
||||
title: 'Topic 2',
|
||||
favorite: false,
|
||||
sessionId,
|
||||
messages: ['message3'],
|
||||
},
|
||||
];
|
||||
await serverDB.insert(messages).values([
|
||||
{ id: 'message1', role: 'user', userId },
|
||||
{ id: 'message2', role: 'assistant', userId },
|
||||
{ id: 'message3', role: 'user', userId },
|
||||
]);
|
||||
|
||||
// 调用 batchCreate 方法
|
||||
const createdTopics = await topicModel.batchCreate(topicParams);
|
||||
|
||||
// 断言返回的 topics 数据正确
|
||||
expect(createdTopics).toHaveLength(2);
|
||||
expect(createdTopics[0]).toMatchObject({
|
||||
title: 'Topic 1',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
userId,
|
||||
});
|
||||
expect(createdTopics[1]).toMatchObject({
|
||||
title: 'Topic 2',
|
||||
favorite: false,
|
||||
sessionId,
|
||||
userId,
|
||||
});
|
||||
|
||||
// 断言 topics 表中的数据正确
|
||||
const items = await serverDB.select().from(topics);
|
||||
expect(items).toHaveLength(2);
|
||||
expect(items[0]).toMatchObject({
|
||||
title: 'Topic 1',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
userId,
|
||||
});
|
||||
expect(items[1]).toMatchObject({
|
||||
title: 'Topic 2',
|
||||
favorite: false,
|
||||
sessionId,
|
||||
userId,
|
||||
});
|
||||
|
||||
// 断言关联的 messages 的 topicId 被正确更新
|
||||
const updatedMessages = await serverDB.select().from(messages);
|
||||
expect(updatedMessages).toHaveLength(3);
|
||||
expect(updatedMessages[0].topicId).toBe(createdTopics[0].id);
|
||||
expect(updatedMessages[1].topicId).toBe(createdTopics[0].id);
|
||||
expect(updatedMessages[2].topicId).toBe(createdTopics[1].id);
|
||||
});
|
||||
|
||||
it('should generate topic IDs if not provided', async () => {
|
||||
// 准备测试数据
|
||||
const topicParams = [
|
||||
{
|
||||
title: 'Topic 1',
|
||||
favorite: true,
|
||||
sessionId,
|
||||
},
|
||||
{
|
||||
title: 'Topic 2',
|
||||
favorite: false,
|
||||
sessionId,
|
||||
},
|
||||
];
|
||||
|
||||
// 调用 batchCreate 方法
|
||||
const createdTopics = await topicModel.batchCreate(topicParams);
|
||||
|
||||
// 断言生成了正确的 topic ID
|
||||
expect(createdTopics[0].id).toBeDefined();
|
||||
expect(createdTopics[1].id).toBeDefined();
|
||||
expect(createdTopics[0].id).not.toBe(createdTopics[1].id);
|
||||
});
|
||||
});
|
||||
|
||||
describe('duplicate', () => {
|
||||
it('should duplicate a topic and its associated messages', async () => {
|
||||
const topicId = 'topic-duplicate';
|
||||
const newTitle = 'Duplicated Topic';
|
||||
|
||||
// 创建原始的 topic 和 messages
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(topics).values({ id: topicId, sessionId, userId, title: 'Original Topic' });
|
||||
await tx.insert(messages).values([
|
||||
{ id: 'message1', role: 'user', topicId, userId, content: 'User message' },
|
||||
{ id: 'message2', role: 'assistant', topicId, userId, content: 'Assistant message' },
|
||||
]);
|
||||
});
|
||||
|
||||
// 调用 duplicate 方法
|
||||
const { topic: duplicatedTopic, messages: duplicatedMessages } = await topicModel.duplicate(
|
||||
topicId,
|
||||
newTitle,
|
||||
);
|
||||
|
||||
// 断言复制的 topic 的属性正确
|
||||
expect(duplicatedTopic.id).not.toBe(topicId);
|
||||
expect(duplicatedTopic.title).toBe(newTitle);
|
||||
expect(duplicatedTopic.sessionId).toBe(sessionId);
|
||||
expect(duplicatedTopic.userId).toBe(userId);
|
||||
|
||||
// 断言复制的 messages 的属性正确
|
||||
expect(duplicatedMessages).toHaveLength(2);
|
||||
expect(duplicatedMessages[0].id).not.toBe('message1');
|
||||
expect(duplicatedMessages[0].topicId).toBe(duplicatedTopic.id);
|
||||
expect(duplicatedMessages[0].content).toBe('User message');
|
||||
expect(duplicatedMessages[1].id).not.toBe('message2');
|
||||
expect(duplicatedMessages[1].topicId).toBe(duplicatedTopic.id);
|
||||
expect(duplicatedMessages[1].content).toBe('Assistant message');
|
||||
});
|
||||
|
||||
it('should throw an error if the topic to duplicate does not exist', async () => {
|
||||
const topicId = 'nonexistent-topic';
|
||||
|
||||
// 调用 duplicate 方法,期望抛出错误
|
||||
await expect(topicModel.duplicate(topicId)).rejects.toThrow(
|
||||
`Topic with id ${topicId} not found`,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
173
src/database/server/models/__tests__/user.test.ts
Normal file
173
src/database/server/models/__tests__/user.test.ts
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import { eq } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { INBOX_SESSION_ID } from '@/const/session';
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
import { KeyVaultsGateKeeper } from '@/server/keyVaultsEncrypt';
|
||||
import { UserPreference } from '@/types/user';
|
||||
import { UserSettings } from '@/types/user/settings';
|
||||
|
||||
import { userSettings, users } from '../../schemas/lobechat';
|
||||
import { SessionModel } from '../session';
|
||||
import { UserModel } from '../user';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'user-db';
|
||||
const userModel = new UserModel();
|
||||
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
await serverDB.delete(userSettings);
|
||||
process.env.KEY_VAULTS_SECRET = 'ofQiJCXLF8mYemwfMWLOHoHimlPu91YmLfU7YZ4lreQ=';
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
await serverDB.delete(userSettings);
|
||||
process.env.KEY_VAULTS_SECRET = undefined;
|
||||
});
|
||||
|
||||
describe('UserModel', () => {
|
||||
describe('createUser', () => {
|
||||
it('should create a new user and inbox session', async () => {
|
||||
const params = {
|
||||
id: userId,
|
||||
username: 'testuser',
|
||||
email: 'test@example.com',
|
||||
};
|
||||
|
||||
await userModel.createUser(params);
|
||||
|
||||
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
||||
expect(user).not.toBeNull();
|
||||
expect(user?.username).toBe('testuser');
|
||||
expect(user?.email).toBe('test@example.com');
|
||||
|
||||
const sessionModel = new SessionModel(userId);
|
||||
const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID);
|
||||
expect(inbox).not.toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteUser', () => {
|
||||
it('should delete a user', async () => {
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
|
||||
await userModel.deleteUser(userId);
|
||||
|
||||
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
||||
expect(user).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('findById', () => {
|
||||
it('should find a user by ID', async () => {
|
||||
await serverDB.insert(users).values({ id: userId, username: 'testuser' });
|
||||
|
||||
const user = await userModel.findById(userId);
|
||||
|
||||
expect(user).not.toBeNull();
|
||||
expect(user?.id).toBe(userId);
|
||||
expect(user?.username).toBe('testuser');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getUserState', () => {
|
||||
it('should get user state with decrypted keyVaults', async () => {
|
||||
const preference = { useCmdEnterToSend: true } as UserPreference;
|
||||
const keyVaults = { apiKey: 'secret' };
|
||||
|
||||
await serverDB.insert(users).values({ id: userId, preference });
|
||||
|
||||
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
const encryptedKeyVaults = await gateKeeper.encrypt(JSON.stringify(keyVaults));
|
||||
|
||||
await serverDB.insert(userSettings).values({
|
||||
id: userId,
|
||||
keyVaults: encryptedKeyVaults,
|
||||
});
|
||||
|
||||
const state = await userModel.getUserState(userId);
|
||||
|
||||
expect(state.userId).toBe(userId);
|
||||
expect(state.preference).toEqual(preference);
|
||||
expect(state.settings.keyVaults).toEqual(keyVaults);
|
||||
});
|
||||
|
||||
it('should throw an error if user not found', async () => {
|
||||
await expect(userModel.getUserState('invalid-user-id')).rejects.toThrow('user not found');
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateUser', () => {
|
||||
it('should update user fields', async () => {
|
||||
await serverDB.insert(users).values({ id: userId, username: 'oldname' });
|
||||
|
||||
await userModel.updateUser(userId, { username: 'newname' });
|
||||
|
||||
const updatedUser = await serverDB.query.users.findFirst({
|
||||
where: eq(users.id, userId),
|
||||
});
|
||||
expect(updatedUser?.username).toBe('newname');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteSetting', () => {
|
||||
it('should delete user settings', async () => {
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
await serverDB.insert(userSettings).values({ id: userId });
|
||||
|
||||
await userModel.deleteSetting(userId);
|
||||
|
||||
const settings = await serverDB.query.userSettings.findFirst({
|
||||
where: eq(users.id, userId),
|
||||
});
|
||||
|
||||
expect(settings).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateSetting', () => {
|
||||
it('should update user settings with encrypted keyVaults', async () => {
|
||||
const settings = {
|
||||
general: { language: 'en-US' },
|
||||
keyVaults: { openai: { apiKey: 'secret' } },
|
||||
} as UserSettings;
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
|
||||
await userModel.updateSetting(userId, settings);
|
||||
|
||||
const updatedSettings = await serverDB.query.userSettings.findFirst({
|
||||
where: eq(users.id, userId),
|
||||
});
|
||||
expect(updatedSettings?.general).toEqual(settings.general);
|
||||
expect(updatedSettings?.keyVaults).not.toBe(JSON.stringify(settings.keyVaults));
|
||||
|
||||
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
const { plaintext } = await gateKeeper.decrypt(updatedSettings!.keyVaults!);
|
||||
expect(JSON.parse(plaintext)).toEqual(settings.keyVaults);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updatePreference', () => {
|
||||
it('should update user preference', async () => {
|
||||
const preference = { guide: { topic: false } } as UserPreference;
|
||||
await serverDB.insert(users).values({ id: userId, preference });
|
||||
|
||||
const newPreference: Partial<UserPreference> = {
|
||||
guide: { topic: true, moveSettingsToAvatar: true },
|
||||
};
|
||||
await userModel.updatePreference(userId, newPreference);
|
||||
|
||||
const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
|
||||
expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference });
|
||||
});
|
||||
});
|
||||
});
|
||||
44
src/database/server/models/_template.ts
Normal file
44
src/database/server/models/_template.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import { eq } from 'drizzle-orm';
|
||||
import { and, desc } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server';
|
||||
|
||||
import { NewSessionGroup, UserItem, sessionGroups } from '../schemas/lobechat';
|
||||
|
||||
export class TemplateModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
create = async (params: NewSessionGroup) => {
|
||||
return serverDB.insert(sessionGroups).values({ ...params, userId: this.userId });
|
||||
};
|
||||
|
||||
delete = async (id: string) => {
|
||||
return serverDB
|
||||
.delete(sessionGroups)
|
||||
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
||||
};
|
||||
|
||||
query = async () => {
|
||||
return serverDB.query.sessionGroups.findMany({
|
||||
orderBy: [desc(sessionGroups.updatedAt)],
|
||||
where: eq(sessionGroups.userId, this.userId),
|
||||
});
|
||||
};
|
||||
|
||||
findById = async (id: string) => {
|
||||
return serverDB.query.sessionGroups.findFirst({
|
||||
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
|
||||
});
|
||||
};
|
||||
|
||||
async update(id: string, value: Partial<UserItem>) {
|
||||
return serverDB
|
||||
.update(sessionGroups)
|
||||
.set({ ...value, updatedAt: new Date() })
|
||||
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
||||
}
|
||||
}
|
||||
51
src/database/server/models/file.ts
Normal file
51
src/database/server/models/file.ts
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import { eq } from 'drizzle-orm';
|
||||
import { and, desc } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server/core/db';
|
||||
|
||||
import { FileItem, NewFile, files } from '../schemas/lobechat';
|
||||
|
||||
export class FileModel {
|
||||
private readonly userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
create = async (params: Omit<NewFile, 'id' | 'userId'>) => {
|
||||
const result = await serverDB
|
||||
.insert(files)
|
||||
.values({ ...params, userId: this.userId })
|
||||
.returning();
|
||||
|
||||
return { id: result[0].id };
|
||||
};
|
||||
|
||||
delete = async (id: string) => {
|
||||
return serverDB.delete(files).where(and(eq(files.id, id), eq(files.userId, this.userId)));
|
||||
};
|
||||
|
||||
clear = async () => {
|
||||
return serverDB.delete(files).where(eq(files.userId, this.userId));
|
||||
};
|
||||
|
||||
query = async () => {
|
||||
return serverDB.query.files.findMany({
|
||||
orderBy: [desc(files.updatedAt)],
|
||||
where: eq(files.userId, this.userId),
|
||||
});
|
||||
};
|
||||
|
||||
findById = async (id: string) => {
|
||||
return serverDB.query.files.findFirst({
|
||||
where: and(eq(files.id, id), eq(files.userId, this.userId)),
|
||||
});
|
||||
};
|
||||
|
||||
async update(id: string, value: Partial<FileItem>) {
|
||||
return serverDB
|
||||
.update(files)
|
||||
.set({ ...value, updatedAt: new Date() })
|
||||
.where(and(eq(files.id, id), eq(files.userId, this.userId)));
|
||||
}
|
||||
}
|
||||
378
src/database/server/models/message.ts
Normal file
378
src/database/server/models/message.ts
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
import { count, sql } from 'drizzle-orm';
|
||||
import { and, asc, desc, eq, isNull, like } from 'drizzle-orm/expressions';
|
||||
import { inArray } from 'drizzle-orm/sql/expressions/conditions';
|
||||
|
||||
import { CreateMessageParams } from '@/database/client/models/message';
|
||||
import { serverDB } from '@/database/server/core/db';
|
||||
import { idGenerator } from '@/database/server/utils/idGenerator';
|
||||
import { ChatTTS, ChatToolPayload } from '@/types/message';
|
||||
import { merge } from '@/utils/merge';
|
||||
|
||||
import {
|
||||
MessageItem,
|
||||
filesToMessages,
|
||||
messagePlugins,
|
||||
messageTTS,
|
||||
messageTranslates,
|
||||
messages,
|
||||
} from '../schemas/lobechat';
|
||||
|
||||
export interface QueryMessageParams {
|
||||
current?: number;
|
||||
pageSize?: number;
|
||||
sessionId?: string | null;
|
||||
topicId?: string | null;
|
||||
}
|
||||
|
||||
export class MessageModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
// **************** Query *************** //
|
||||
async query({
|
||||
current = 0,
|
||||
pageSize = 1000,
|
||||
sessionId,
|
||||
topicId,
|
||||
}: QueryMessageParams = {}): Promise<MessageItem[]> {
|
||||
const offset = current * pageSize;
|
||||
|
||||
const result = await serverDB
|
||||
.select({
|
||||
/* eslint-disable sort-keys-fix/sort-keys-fix*/
|
||||
id: messages.id,
|
||||
role: messages.role,
|
||||
content: messages.content,
|
||||
error: messages.error,
|
||||
|
||||
model: messages.model,
|
||||
provider: messages.provider,
|
||||
|
||||
createdAt: messages.createdAt,
|
||||
updatedAt: messages.updatedAt,
|
||||
|
||||
parentId: messages.parentId,
|
||||
|
||||
tools: messages.tools,
|
||||
tool_call_id: messagePlugins.toolCallId,
|
||||
|
||||
plugin: {
|
||||
apiName: messagePlugins.apiName,
|
||||
arguments: messagePlugins.arguments,
|
||||
identifier: messagePlugins.identifier,
|
||||
type: messagePlugins.type,
|
||||
},
|
||||
pluginError: messagePlugins.error,
|
||||
pluginState: messagePlugins.state,
|
||||
|
||||
translate: {
|
||||
content: messageTranslates.content,
|
||||
from: messageTranslates.from,
|
||||
to: messageTranslates.to,
|
||||
},
|
||||
|
||||
ttsId: messageTTS.id,
|
||||
|
||||
// TODO: 确认下如何处理 TTS 的读取
|
||||
// ttsContentMd5: messageTTS.contentMd5,
|
||||
// ttsFile: messageTTS.fileId,
|
||||
// ttsVoice: messageTTS.voice,
|
||||
/* eslint-enable */
|
||||
})
|
||||
.from(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.userId, this.userId),
|
||||
this.matchSession(sessionId),
|
||||
this.matchTopic(topicId),
|
||||
),
|
||||
)
|
||||
.leftJoin(messagePlugins, eq(messagePlugins.id, messages.id))
|
||||
.leftJoin(messageTranslates, eq(messageTranslates.id, messages.id))
|
||||
.leftJoin(messageTTS, eq(messageTTS.id, messages.id))
|
||||
.orderBy(asc(messages.createdAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset);
|
||||
|
||||
const messageIds = result.map((message) => message.id as string);
|
||||
|
||||
if (messageIds.length === 0) return result;
|
||||
|
||||
const fileIds = await serverDB
|
||||
.select({
|
||||
fileId: filesToMessages.fileId,
|
||||
messageId: filesToMessages.messageId,
|
||||
})
|
||||
.from(filesToMessages)
|
||||
.where(inArray(filesToMessages.messageId, messageIds));
|
||||
|
||||
return result.map(
|
||||
({
|
||||
model,
|
||||
provider,
|
||||
translate,
|
||||
ttsId,
|
||||
// ttsFile, ttsId, ttsContentMd5, ttsVoice,
|
||||
...item
|
||||
}) => ({
|
||||
...item,
|
||||
extra: {
|
||||
fromModel: model,
|
||||
fromProvider: provider,
|
||||
translate,
|
||||
tts: ttsId
|
||||
? {
|
||||
// contentMd5: ttsContentMd5,
|
||||
// file: ttsFile,
|
||||
// voice: ttsVoice,
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
files: fileIds.filter((relation) => relation.messageId === item.id).map((r) => r.fileId),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async findById(id: string) {
|
||||
return serverDB.query.messages.findFirst({
|
||||
where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
|
||||
});
|
||||
}
|
||||
|
||||
async queryAll(): Promise<MessageItem[]> {
|
||||
return serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.orderBy(messages.createdAt)
|
||||
.where(eq(messages.userId, this.userId))
|
||||
|
||||
.execute();
|
||||
}
|
||||
|
||||
async queryBySessionId(sessionId?: string | null): Promise<MessageItem[]> {
|
||||
return serverDB.query.messages.findMany({
|
||||
orderBy: [asc(messages.createdAt)],
|
||||
where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)),
|
||||
});
|
||||
}
|
||||
|
||||
async queryByKeyword(keyword: string): Promise<MessageItem[]> {
|
||||
if (!keyword) return [];
|
||||
|
||||
return serverDB.query.messages.findMany({
|
||||
orderBy: [desc(messages.createdAt)],
|
||||
where: and(eq(messages.userId, this.userId), like(messages.content, `%${keyword}%`)),
|
||||
});
|
||||
}
|
||||
|
||||
async count() {
|
||||
const result = await serverDB
|
||||
.select({
|
||||
count: count(),
|
||||
})
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, this.userId))
|
||||
.execute();
|
||||
|
||||
return result[0].count;
|
||||
}
|
||||
|
||||
async countToday() {
|
||||
const today = new Date();
|
||||
today.setHours(0, 0, 0, 0);
|
||||
const tomorrow = new Date(today);
|
||||
tomorrow.setDate(tomorrow.getDate() + 1);
|
||||
|
||||
const result = await serverDB
|
||||
.select({
|
||||
count: count(),
|
||||
})
|
||||
.from(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.userId, this.userId),
|
||||
sql`${messages.createdAt} >= ${today} AND ${messages.createdAt} < ${tomorrow}`,
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
|
||||
return result[0].count;
|
||||
}
|
||||
|
||||
// **************** Create *************** //
|
||||
|
||||
async create(
|
||||
{ fromModel, fromProvider, files, ...message }: CreateMessageParams,
|
||||
id: string = this.genId(),
|
||||
): Promise<MessageItem> {
|
||||
return serverDB.transaction(async (trx) => {
|
||||
const [item] = (await trx
|
||||
.insert(messages)
|
||||
.values({
|
||||
...message,
|
||||
id,
|
||||
model: fromModel,
|
||||
provider: fromProvider,
|
||||
userId: this.userId,
|
||||
})
|
||||
.returning()) as MessageItem[];
|
||||
|
||||
// Insert the plugin data if the message is a tool
|
||||
if (message.role === 'tool') {
|
||||
await trx.insert(messagePlugins).values({
|
||||
apiName: message.plugin?.apiName,
|
||||
arguments: message.plugin?.arguments,
|
||||
id,
|
||||
identifier: message.plugin?.identifier,
|
||||
toolCallId: message.tool_call_id,
|
||||
type: message.plugin?.type,
|
||||
});
|
||||
}
|
||||
|
||||
if (files && files.length > 0) {
|
||||
await trx
|
||||
.insert(filesToMessages)
|
||||
.values(files.map((file) => ({ fileId: file, messageId: id })));
|
||||
}
|
||||
|
||||
return item;
|
||||
});
|
||||
}
|
||||
|
||||
async batchCreate(newMessages: MessageItem[]) {
|
||||
const messagesToInsert = newMessages.map((m) => {
|
||||
return { ...m, userId: this.userId };
|
||||
});
|
||||
|
||||
return serverDB.insert(messages).values(messagesToInsert);
|
||||
}
|
||||
|
||||
// **************** Update *************** //
|
||||
|
||||
async update(id: string, message: Partial<MessageItem>) {
|
||||
return serverDB
|
||||
.update(messages)
|
||||
.set(message)
|
||||
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)));
|
||||
}
|
||||
|
||||
async updatePluginState(id: string, state: Record<string, any>) {
|
||||
const item = await serverDB.query.messagePlugins.findFirst({
|
||||
where: eq(messagePlugins.id, id),
|
||||
});
|
||||
if (!item) throw new Error('Plugin not found');
|
||||
|
||||
return serverDB
|
||||
.update(messagePlugins)
|
||||
.set({ state: merge(item.state || {}, state) })
|
||||
.where(eq(messagePlugins.id, id));
|
||||
}
|
||||
|
||||
async updateTranslate(id: string, translate: Partial<MessageItem>) {
|
||||
const result = await serverDB.query.messageTranslates.findFirst({
|
||||
where: and(eq(messageTranslates.id, id)),
|
||||
});
|
||||
|
||||
// If the message does not exist in the translate table, insert it
|
||||
if (!result) {
|
||||
return serverDB.insert(messageTranslates).values({ ...translate, id });
|
||||
}
|
||||
|
||||
// or just update the existing one
|
||||
return serverDB.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id));
|
||||
}
|
||||
|
||||
async updateTTS(id: string, tts: Partial<ChatTTS>) {
|
||||
const result = await serverDB.query.messageTTS.findFirst({
|
||||
where: and(eq(messageTTS.id, id)),
|
||||
});
|
||||
|
||||
// If the message does not exist in the translate table, insert it
|
||||
if (!result) {
|
||||
return serverDB
|
||||
.insert(messageTTS)
|
||||
.values({ contentMd5: tts.contentMd5, fileId: tts.file, id, voice: tts.voice });
|
||||
}
|
||||
|
||||
// or just update the existing one
|
||||
return serverDB
|
||||
.update(messageTTS)
|
||||
.set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice })
|
||||
.where(eq(messageTTS.id, id));
|
||||
}
|
||||
|
||||
// **************** Delete *************** //
|
||||
|
||||
async deleteMessage(id: string) {
|
||||
return serverDB.transaction(async (tx) => {
|
||||
// 1. 查询要删除的 message 的完整信息
|
||||
const message = await tx
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)))
|
||||
.limit(1);
|
||||
|
||||
// 如果找不到要删除的 message,直接返回
|
||||
if (message.length === 0) return;
|
||||
|
||||
// 2. 检查 message 是否包含 tools
|
||||
const toolCallIds = message[0].tools?.map((tool: ChatToolPayload) => tool.id).filter(Boolean);
|
||||
|
||||
let relatedMessageIds: string[] = [];
|
||||
|
||||
if (toolCallIds?.length > 0) {
|
||||
// 3. 如果 message 包含 tools,查询出所有相关联的 message id
|
||||
const res = await tx
|
||||
.select({ id: messagePlugins.id })
|
||||
.from(messagePlugins)
|
||||
.where(inArray(messagePlugins.toolCallId, toolCallIds))
|
||||
.execute();
|
||||
|
||||
relatedMessageIds = res.map((row) => row.id);
|
||||
}
|
||||
|
||||
// 4. 合并要删除的 message id 列表
|
||||
const messageIdsToDelete = [id, ...relatedMessageIds];
|
||||
|
||||
// 5. 删除所有相关的 message
|
||||
await tx.delete(messages).where(inArray(messages.id, messageIdsToDelete));
|
||||
});
|
||||
}
|
||||
|
||||
async deleteMessageTranslate(id: string) {
|
||||
return serverDB.delete(messageTranslates).where(and(eq(messageTranslates.id, id)));
|
||||
}
|
||||
|
||||
async deleteMessageTTS(id: string) {
|
||||
return serverDB.delete(messageTTS).where(and(eq(messageTTS.id, id)));
|
||||
}
|
||||
|
||||
async deleteMessages(sessionId?: string | null, topicId?: string | null) {
|
||||
return serverDB
|
||||
.delete(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.userId, this.userId),
|
||||
this.matchSession(sessionId),
|
||||
this.matchTopic(topicId),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async deleteAllMessages() {
|
||||
return serverDB.delete(messages).where(eq(messages.userId, this.userId));
|
||||
}
|
||||
|
||||
// **************** Helper *************** //
|
||||
|
||||
private genId = () => idGenerator('messages', 14);
|
||||
|
||||
private matchSession = (sessionId?: string | null) =>
|
||||
sessionId ? eq(messages.sessionId, sessionId) : isNull(messages.sessionId);
|
||||
|
||||
private matchTopic = (topicId?: string | null) =>
|
||||
topicId ? eq(messages.topicId, topicId) : isNull(messages.topicId);
|
||||
}
|
||||
63
src/database/server/models/plugin.ts
Normal file
63
src/database/server/models/plugin.ts
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import { and, desc, eq } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server';
|
||||
|
||||
import { InstalledPluginItem, NewInstalledPlugin, installedPlugins } from '../schemas/lobechat';
|
||||
|
||||
export class PluginModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
create = async (
|
||||
params: Pick<NewInstalledPlugin, 'type' | 'identifier' | 'manifest' | 'customParams'>,
|
||||
) => {
|
||||
const [result] = await serverDB
|
||||
.insert(installedPlugins)
|
||||
.values({ ...params, createdAt: new Date(), updatedAt: new Date(), userId: this.userId })
|
||||
.returning();
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
delete = async (id: string) => {
|
||||
return serverDB
|
||||
.delete(installedPlugins)
|
||||
.where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)));
|
||||
};
|
||||
|
||||
deleteAll = async () => {
|
||||
return serverDB.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId));
|
||||
};
|
||||
|
||||
query = async () => {
|
||||
return serverDB
|
||||
.select({
|
||||
createdAt: installedPlugins.createdAt,
|
||||
customParams: installedPlugins.customParams,
|
||||
identifier: installedPlugins.identifier,
|
||||
manifest: installedPlugins.manifest,
|
||||
settings: installedPlugins.settings,
|
||||
type: installedPlugins.type,
|
||||
updatedAt: installedPlugins.updatedAt,
|
||||
})
|
||||
.from(installedPlugins)
|
||||
.where(eq(installedPlugins.userId, this.userId))
|
||||
.orderBy(desc(installedPlugins.createdAt));
|
||||
};
|
||||
|
||||
findById = async (id: string) => {
|
||||
return serverDB.query.installedPlugins.findFirst({
|
||||
where: and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)),
|
||||
});
|
||||
};
|
||||
|
||||
async update(id: string, value: Partial<InstalledPluginItem>) {
|
||||
return serverDB
|
||||
.update(installedPlugins)
|
||||
.set({ ...value, updatedAt: new Date() })
|
||||
.where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)));
|
||||
}
|
||||
}
|
||||
290
src/database/server/models/session.ts
Normal file
290
src/database/server/models/session.ts
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
import { Column, asc, count, inArray, like, sql } from 'drizzle-orm';
|
||||
import { and, desc, eq, isNull, not, or } from 'drizzle-orm/expressions';
|
||||
|
||||
import { appEnv } from '@/config/app';
|
||||
import { INBOX_SESSION_ID } from '@/const/session';
|
||||
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
|
||||
import { serverDB } from '@/database/server/core/db';
|
||||
import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent';
|
||||
import { ChatSessionList, LobeAgentSession } from '@/types/session';
|
||||
import { merge } from '@/utils/merge';
|
||||
|
||||
import {
|
||||
AgentItem,
|
||||
NewAgent,
|
||||
NewSession,
|
||||
SessionItem,
|
||||
agents,
|
||||
agentsToSessions,
|
||||
sessionGroups,
|
||||
sessions,
|
||||
} from '../schemas/lobechat';
|
||||
import { idGenerator } from '../utils/idGenerator';
|
||||
|
||||
export class SessionModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
// **************** Query *************** //
|
||||
|
||||
async query({ current = 0, pageSize = 9999 } = {}) {
|
||||
const offset = current * pageSize;
|
||||
|
||||
return serverDB.query.sessions.findMany({
|
||||
limit: pageSize,
|
||||
offset,
|
||||
orderBy: [desc(sessions.updatedAt)],
|
||||
where: and(eq(sessions.userId, this.userId), not(eq(sessions.slug, INBOX_SESSION_ID))),
|
||||
with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
|
||||
});
|
||||
}
|
||||
|
||||
async queryWithGroups(): Promise<ChatSessionList> {
|
||||
// 查询所有会话
|
||||
const result = await this.query();
|
||||
|
||||
const groups = await serverDB.query.sessionGroups.findMany({
|
||||
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
|
||||
where: eq(sessions.userId, this.userId),
|
||||
});
|
||||
|
||||
return {
|
||||
sessionGroups: groups as unknown as ChatSessionList['sessionGroups'],
|
||||
sessions: result.map((item) => this.mapSessionItem(item as any)),
|
||||
};
|
||||
}
|
||||
|
||||
async queryByKeyword(keyword: string) {
|
||||
if (!keyword) return [];
|
||||
|
||||
const keywordLowerCase = keyword.toLowerCase();
|
||||
|
||||
const data = await this.findSessions({ keyword: keywordLowerCase });
|
||||
|
||||
return data.map((item) => this.mapSessionItem(item as any));
|
||||
}
|
||||
|
||||
async findByIdOrSlug(
|
||||
idOrSlug: string,
|
||||
): Promise<(SessionItem & { agent: AgentItem }) | undefined> {
|
||||
const result = await serverDB.query.sessions.findFirst({
|
||||
where: and(
|
||||
or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)),
|
||||
eq(sessions.userId, this.userId),
|
||||
),
|
||||
with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
|
||||
});
|
||||
|
||||
if (!result) return;
|
||||
|
||||
return { ...result, agent: (result?.agentsToSessions?.[0] as any)?.agent } as any;
|
||||
}
|
||||
|
||||
async count() {
|
||||
const result = await serverDB
|
||||
.select({
|
||||
count: count(),
|
||||
})
|
||||
.from(sessions)
|
||||
.where(eq(sessions.userId, this.userId))
|
||||
.execute();
|
||||
|
||||
return result[0].count;
|
||||
}
|
||||
|
||||
// **************** Create *************** //
|
||||
|
||||
async create({
|
||||
id = idGenerator('sessions'),
|
||||
type = 'agent',
|
||||
session = {},
|
||||
config = {},
|
||||
slug,
|
||||
}: {
|
||||
config?: Partial<NewAgent>;
|
||||
id?: string;
|
||||
session?: Partial<NewSession>;
|
||||
slug?: string;
|
||||
type: 'agent' | 'group';
|
||||
}): Promise<SessionItem> {
|
||||
return serverDB.transaction(async (trx) => {
|
||||
const newAgents = await trx
|
||||
.insert(agents)
|
||||
.values({
|
||||
...config,
|
||||
createdAt: new Date(),
|
||||
id: idGenerator('agents'),
|
||||
updatedAt: new Date(),
|
||||
userId: this.userId,
|
||||
})
|
||||
.returning();
|
||||
|
||||
const result = await trx
|
||||
.insert(sessions)
|
||||
.values({
|
||||
...session,
|
||||
createdAt: new Date(),
|
||||
id,
|
||||
slug,
|
||||
type,
|
||||
updatedAt: new Date(),
|
||||
userId: this.userId,
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(agentsToSessions).values({
|
||||
agentId: newAgents[0].id,
|
||||
sessionId: id,
|
||||
});
|
||||
|
||||
return result[0];
|
||||
});
|
||||
}
|
||||
|
||||
async createInbox() {
|
||||
const serverAgentConfig = parseAgentConfig(appEnv.DEFAULT_AGENT_CONFIG) || {};
|
||||
|
||||
return await this.create({
|
||||
config: merge(DEFAULT_AGENT_CONFIG, serverAgentConfig),
|
||||
slug: INBOX_SESSION_ID,
|
||||
type: 'agent',
|
||||
});
|
||||
}
|
||||
|
||||
async batchCreate(newSessions: NewSession[]) {
|
||||
const sessionsToInsert = newSessions.map((s) => {
|
||||
return {
|
||||
...s,
|
||||
id: this.genId(),
|
||||
userId: this.userId,
|
||||
};
|
||||
});
|
||||
|
||||
return serverDB.insert(sessions).values(sessionsToInsert);
|
||||
}
|
||||
|
||||
async duplicate(id: string, newTitle?: string) {
|
||||
const result = await this.findByIdOrSlug(id);
|
||||
|
||||
if (!result) return;
|
||||
|
||||
const { agent, ...session } = result;
|
||||
const sessionId = this.genId();
|
||||
|
||||
return this.create({
|
||||
config: agent,
|
||||
id: sessionId,
|
||||
session: {
|
||||
...session,
|
||||
title: newTitle || session.title,
|
||||
},
|
||||
type: 'agent',
|
||||
});
|
||||
}
|
||||
|
||||
// **************** Delete *************** //
|
||||
|
||||
/**
|
||||
* Delete a session, also delete all messages and topics associated with it.
|
||||
*/
|
||||
async delete(id: string) {
|
||||
return serverDB
|
||||
.delete(sessions)
|
||||
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch delete sessions, also delete all messages and topics associated with them.
|
||||
*/
|
||||
async batchDelete(ids: string[]) {
|
||||
return serverDB
|
||||
.delete(sessions)
|
||||
.where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId)));
|
||||
}
|
||||
|
||||
async deleteAll() {
|
||||
return serverDB.delete(sessions).where(eq(sessions.userId, this.userId));
|
||||
}
|
||||
// **************** Update *************** //
|
||||
|
||||
async update(id: string, data: Partial<SessionItem>) {
|
||||
return serverDB
|
||||
.update(sessions)
|
||||
.set(data)
|
||||
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)))
|
||||
.returning();
|
||||
}
|
||||
|
||||
async updateConfig(id: string, data: Partial<AgentItem>) {
|
||||
return serverDB
|
||||
.update(agents)
|
||||
.set(data)
|
||||
.where(and(eq(agents.id, id), eq(agents.userId, this.userId)));
|
||||
}
|
||||
|
||||
// **************** Helper *************** //
|
||||
|
||||
private genId = () => idGenerator('sessions');
|
||||
|
||||
private mapSessionItem = ({
|
||||
agentsToSessions,
|
||||
title,
|
||||
backgroundColor,
|
||||
description,
|
||||
avatar,
|
||||
groupId,
|
||||
...res
|
||||
}: SessionItem & { agentsToSessions?: { agent: AgentItem }[] }): LobeAgentSession => {
|
||||
// TODO: 未来这里需要更好的实现方案,目前只取第一个
|
||||
const agent = agentsToSessions?.[0]?.agent;
|
||||
return {
|
||||
...res,
|
||||
group: groupId,
|
||||
meta: {
|
||||
avatar: agent?.avatar ?? avatar ?? undefined,
|
||||
backgroundColor: agent?.backgroundColor ?? backgroundColor ?? undefined,
|
||||
description: agent?.description ?? description ?? undefined,
|
||||
title: agent?.title ?? title ?? undefined,
|
||||
},
|
||||
model: agent?.model,
|
||||
} as any;
|
||||
};
|
||||
|
||||
async findSessions(params: {
|
||||
current?: number;
|
||||
group?: string;
|
||||
keyword?: string;
|
||||
pageSize?: number;
|
||||
pinned?: boolean;
|
||||
}) {
|
||||
const { pinned, keyword, group, pageSize = 9999, current = 0 } = params;
|
||||
|
||||
const offset = current * pageSize;
|
||||
return serverDB.query.sessions.findMany({
|
||||
limit: pageSize,
|
||||
offset,
|
||||
orderBy: [desc(sessions.updatedAt)],
|
||||
where: and(
|
||||
eq(sessions.userId, this.userId),
|
||||
pinned !== undefined ? eq(sessions.pinned, pinned) : eq(sessions.userId, this.userId),
|
||||
keyword
|
||||
? or(
|
||||
like(
|
||||
sql`lower(${sessions.title})` as unknown as Column,
|
||||
`%${keyword.toLowerCase()}%`,
|
||||
),
|
||||
like(
|
||||
sql`lower(${sessions.description})` as unknown as Column,
|
||||
`%${keyword.toLowerCase()}%`,
|
||||
),
|
||||
)
|
||||
: eq(sessions.userId, this.userId),
|
||||
group ? eq(sessions.groupId, group) : isNull(sessions.groupId),
|
||||
),
|
||||
|
||||
with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
|
||||
});
|
||||
}
|
||||
}
|
||||
69
src/database/server/models/sessionGroup.ts
Normal file
69
src/database/server/models/sessionGroup.ts
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import { eq } from 'drizzle-orm';
|
||||
import { and, asc, desc } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server';
|
||||
import { idGenerator } from '@/database/server/utils/idGenerator';
|
||||
|
||||
import { SessionGroupItem, sessionGroups } from '../schemas/lobechat';
|
||||
|
||||
export class SessionGroupModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
create = async (params: { name: string; sort?: number }) => {
|
||||
const [result] = await serverDB
|
||||
.insert(sessionGroups)
|
||||
.values({ ...params, id: this.genId(), userId: this.userId })
|
||||
.returning();
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
delete = async (id: string) => {
|
||||
return serverDB
|
||||
.delete(sessionGroups)
|
||||
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
||||
};
|
||||
|
||||
deleteAll = async () => {
|
||||
return serverDB.delete(sessionGroups);
|
||||
};
|
||||
|
||||
query = async () => {
|
||||
return serverDB.query.sessionGroups.findMany({
|
||||
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
|
||||
where: eq(sessionGroups.userId, this.userId),
|
||||
});
|
||||
};
|
||||
|
||||
findById = async (id: string) => {
|
||||
return serverDB.query.sessionGroups.findFirst({
|
||||
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
|
||||
});
|
||||
};
|
||||
|
||||
async update(id: string, value: Partial<SessionGroupItem>) {
|
||||
return serverDB
|
||||
.update(sessionGroups)
|
||||
.set({ ...value, updatedAt: new Date() })
|
||||
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
||||
}
|
||||
|
||||
async updateOrder(sortMap: { id: string; sort: number }[]) {
|
||||
await serverDB.transaction(async (tx) => {
|
||||
const updates = sortMap.map(({ id, sort }) => {
|
||||
return tx
|
||||
.update(sessionGroups)
|
||||
.set({ sort, updatedAt: new Date() })
|
||||
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
|
||||
});
|
||||
|
||||
await Promise.all(updates);
|
||||
});
|
||||
}
|
||||
|
||||
private genId = () => idGenerator('sessionGroups');
|
||||
}
|
||||
265
src/database/server/models/topic.ts
Normal file
265
src/database/server/models/topic.ts
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
import { Column, count, inArray, sql } from 'drizzle-orm';
|
||||
import { and, desc, eq, exists, isNull, like, or } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server/core/db';
|
||||
|
||||
import { NewMessage, TopicItem, messages, topics } from '../schemas/lobechat';
|
||||
import { idGenerator } from '../utils/idGenerator';
|
||||
|
||||
export interface CreateTopicParams {
|
||||
favorite?: boolean;
|
||||
messages?: string[];
|
||||
sessionId?: string | null;
|
||||
title: string;
|
||||
}
|
||||
|
||||
interface QueryTopicParams {
|
||||
current?: number;
|
||||
pageSize?: number;
|
||||
sessionId?: string | null;
|
||||
}
|
||||
|
||||
export class TopicModel {
|
||||
private userId: string;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
// **************** Query *************** //
|
||||
|
||||
async query({ current = 0, pageSize = 9999, sessionId }: QueryTopicParams = {}) {
|
||||
const offset = current * pageSize;
|
||||
|
||||
return (
|
||||
serverDB
|
||||
.select({
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
id: topics.id,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(and(eq(topics.userId, this.userId), this.matchSession(sessionId)))
|
||||
// In boolean sorting, false is considered "smaller" than true.
|
||||
// So here we use desc to ensure that topics with favorite as true are in front.
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset)
|
||||
);
|
||||
}
|
||||
|
||||
async findById(id: string) {
|
||||
return serverDB.query.topics.findFirst({
|
||||
where: and(eq(topics.id, id), eq(topics.userId, this.userId)),
|
||||
});
|
||||
}
|
||||
|
||||
async queryAll(): Promise<TopicItem[]> {
|
||||
return serverDB
|
||||
.select()
|
||||
.from(topics)
|
||||
.orderBy(topics.updatedAt)
|
||||
.where(eq(topics.userId, this.userId))
|
||||
.execute();
|
||||
}
|
||||
|
||||
async queryByKeyword(keyword: string, sessionId?: string | null): Promise<TopicItem[]> {
|
||||
if (!keyword) return [];
|
||||
|
||||
const keywordLowerCase = keyword.toLowerCase();
|
||||
|
||||
const matchKeyword = (field: any) =>
|
||||
like(sql`lower(${field})` as unknown as Column, `%${keywordLowerCase}%`);
|
||||
|
||||
return serverDB.query.topics.findMany({
|
||||
orderBy: [desc(topics.updatedAt)],
|
||||
where: and(
|
||||
eq(topics.userId, this.userId),
|
||||
this.matchSession(sessionId),
|
||||
or(
|
||||
matchKeyword(topics.title),
|
||||
exists(
|
||||
serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(and(eq(messages.topicId, topics.id), or(matchKeyword(messages.content)))),
|
||||
),
|
||||
),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
async count() {
|
||||
const result = await serverDB
|
||||
.select({
|
||||
count: count(),
|
||||
})
|
||||
.from(topics)
|
||||
.where(eq(topics.userId, this.userId))
|
||||
.execute();
|
||||
|
||||
return result[0].count;
|
||||
}
|
||||
|
||||
// **************** Create *************** //
|
||||
|
||||
async create(
|
||||
{ messages: messageIds, ...params }: CreateTopicParams,
|
||||
id: string = this.genId(),
|
||||
): Promise<TopicItem> {
|
||||
return serverDB.transaction(async (tx) => {
|
||||
// 在 topics 表中插入新的 topic
|
||||
const [topic] = await tx
|
||||
.insert(topics)
|
||||
.values({
|
||||
...params,
|
||||
id: id,
|
||||
userId: this.userId,
|
||||
})
|
||||
.returning();
|
||||
|
||||
// 如果有关联的 messages, 更新它们的 topicId
|
||||
if (messageIds && messageIds.length > 0) {
|
||||
await tx
|
||||
.update(messages)
|
||||
.set({ topicId: topic.id })
|
||||
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
|
||||
}
|
||||
|
||||
return topic;
|
||||
});
|
||||
}
|
||||
|
||||
async batchCreate(topicParams: (CreateTopicParams & { id?: string })[]) {
|
||||
// 开始一个事务
|
||||
return serverDB.transaction(async (tx) => {
|
||||
// 在 topics 表中批量插入新的 topics
|
||||
const createdTopics = await tx
|
||||
.insert(topics)
|
||||
.values(
|
||||
topicParams.map((params) => ({
|
||||
favorite: params.favorite,
|
||||
id: params.id || this.genId(),
|
||||
sessionId: params.sessionId,
|
||||
title: params.title,
|
||||
userId: this.userId,
|
||||
})),
|
||||
)
|
||||
.returning();
|
||||
|
||||
// 对每个新创建的 topic,更新关联的 messages 的 topicId
|
||||
await Promise.all(
|
||||
createdTopics.map(async (topic, index) => {
|
||||
const messageIds = topicParams[index].messages;
|
||||
if (messageIds && messageIds.length > 0) {
|
||||
await tx
|
||||
.update(messages)
|
||||
.set({ topicId: topic.id })
|
||||
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
return createdTopics;
|
||||
});
|
||||
}
|
||||
|
||||
async duplicate(topicId: string, newTitle?: string) {
|
||||
return serverDB.transaction(async (tx) => {
|
||||
// find original topic
|
||||
const originalTopic = await tx.query.topics.findFirst({
|
||||
where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)),
|
||||
});
|
||||
|
||||
if (!originalTopic) {
|
||||
throw new Error(`Topic with id ${topicId} not found`);
|
||||
}
|
||||
|
||||
// copy topic
|
||||
const [duplicatedTopic] = await tx
|
||||
.insert(topics)
|
||||
.values({
|
||||
...originalTopic,
|
||||
id: this.genId(),
|
||||
title: newTitle || originalTopic?.title,
|
||||
})
|
||||
.returning();
|
||||
|
||||
// 查找与原始 topic 关联的 messages
|
||||
const originalMessages = await tx
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId)));
|
||||
|
||||
// copy messages
|
||||
const duplicatedMessages = await Promise.all(
|
||||
originalMessages.map(async (message) => {
|
||||
const result = (await tx
|
||||
.insert(messages)
|
||||
.values({
|
||||
...message,
|
||||
id: idGenerator('messages'),
|
||||
topicId: duplicatedTopic.id,
|
||||
})
|
||||
.returning()) as NewMessage[];
|
||||
|
||||
return result[0];
|
||||
}),
|
||||
);
|
||||
|
||||
return {
|
||||
messages: duplicatedMessages,
|
||||
topic: duplicatedTopic,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// **************** Delete *************** //
|
||||
|
||||
/**
|
||||
* Delete a session, also delete all messages and topics associated with it.
|
||||
*/
|
||||
async delete(id: string) {
|
||||
return serverDB.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes multiple topics based on the sessionId.
|
||||
*/
|
||||
async batchDeleteBySessionId(sessionId?: string | null) {
|
||||
return serverDB
|
||||
.delete(topics)
|
||||
.where(and(this.matchSession(sessionId), eq(topics.userId, this.userId)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes multiple topics and all messages associated with them in a transaction.
|
||||
*/
|
||||
async batchDelete(ids: string[]) {
|
||||
return serverDB
|
||||
.delete(topics)
|
||||
.where(and(inArray(topics.id, ids), eq(topics.userId, this.userId)));
|
||||
}
|
||||
|
||||
async deleteAll() {
|
||||
return serverDB.delete(topics).where(eq(topics.userId, this.userId));
|
||||
}
|
||||
|
||||
// **************** Update *************** //
|
||||
|
||||
async update(id: string, data: Partial<TopicItem>) {
|
||||
return serverDB
|
||||
.update(topics)
|
||||
.set({ ...data, updatedAt: new Date() })
|
||||
.where(and(eq(topics.id, id), eq(topics.userId, this.userId)))
|
||||
.returning();
|
||||
}
|
||||
|
||||
// **************** Helper *************** //
|
||||
|
||||
private genId = () => idGenerator('topics');
|
||||
|
||||
private matchSession = (sessionId?: string | null) =>
|
||||
sessionId ? eq(topics.sessionId, sessionId) : isNull(topics.sessionId);
|
||||
}
|
||||
138
src/database/server/models/user.ts
Normal file
138
src/database/server/models/user.ts
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
import { TRPCError } from '@trpc/server';
|
||||
import { eq } from 'drizzle-orm';
|
||||
import { DeepPartial } from 'utility-types';
|
||||
|
||||
import { serverDB } from '@/database/server';
|
||||
import { KeyVaultsGateKeeper } from '@/server/keyVaultsEncrypt';
|
||||
import { UserPreference } from '@/types/user';
|
||||
import { UserSettings } from '@/types/user/settings';
|
||||
import { merge } from '@/utils/merge';
|
||||
|
||||
import { NewUser, UserItem, userSettings, users } from '../schemas/lobechat';
|
||||
import { SessionModel } from './session';
|
||||
|
||||
export class UserModel {
|
||||
createUser = async (params: NewUser) => {
|
||||
const [user] = await serverDB
|
||||
.insert(users)
|
||||
.values({ ...params })
|
||||
.returning();
|
||||
|
||||
// Create an inbox session for the user
|
||||
const model = new SessionModel(user.id);
|
||||
|
||||
await model.createInbox();
|
||||
};
|
||||
|
||||
deleteUser = async (id: string) => {
|
||||
return serverDB.delete(users).where(eq(users.id, id));
|
||||
};
|
||||
|
||||
findById = async (id: string) => {
|
||||
return serverDB.query.users.findFirst({ where: eq(users.id, id) });
|
||||
};
|
||||
|
||||
getUserState = async (id: string) => {
|
||||
const result = await serverDB
|
||||
.select({
|
||||
isOnboarded: users.isOnboarded,
|
||||
preference: users.preference,
|
||||
|
||||
settingsDefaultAgent: userSettings.defaultAgent,
|
||||
settingsGeneral: userSettings.general,
|
||||
settingsKeyVaults: userSettings.keyVaults,
|
||||
settingsLanguageModel: userSettings.languageModel,
|
||||
settingsSystemAgent: userSettings.systemAgent,
|
||||
settingsTTS: userSettings.tts,
|
||||
settingsTool: userSettings.tool,
|
||||
})
|
||||
.from(users)
|
||||
.where(eq(users.id, id))
|
||||
.leftJoin(userSettings, eq(users.id, userSettings.id));
|
||||
|
||||
if (!result || !result[0]) {
|
||||
throw new TRPCError({ code: 'UNAUTHORIZED', message: 'user not found' });
|
||||
}
|
||||
|
||||
const state = result[0];
|
||||
|
||||
// Decrypt keyVaults
|
||||
let decryptKeyVaults = {};
|
||||
if (state.settingsKeyVaults) {
|
||||
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults);
|
||||
|
||||
if (wasAuthentic) {
|
||||
try {
|
||||
decryptKeyVaults = JSON.parse(plaintext);
|
||||
} catch (e) {
|
||||
console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const settings: DeepPartial<UserSettings> = {
|
||||
defaultAgent: state.settingsDefaultAgent || {},
|
||||
general: state.settingsGeneral || {},
|
||||
keyVaults: decryptKeyVaults,
|
||||
languageModel: state.settingsLanguageModel || {},
|
||||
systemAgent: state.settingsSystemAgent || {},
|
||||
tool: state.settingsTool || {},
|
||||
tts: state.settingsTTS || {},
|
||||
};
|
||||
|
||||
return {
|
||||
isOnboarded: state.isOnboarded,
|
||||
preference: state.preference as UserPreference,
|
||||
settings,
|
||||
userId: id,
|
||||
};
|
||||
};
|
||||
|
||||
async updateUser(id: string, value: Partial<UserItem>) {
|
||||
return serverDB
|
||||
.update(users)
|
||||
.set({ ...value, updatedAt: new Date() })
|
||||
.where(eq(users.id, id));
|
||||
}
|
||||
|
||||
async deleteSetting(id: string) {
|
||||
return serverDB.delete(userSettings).where(eq(userSettings.id, id));
|
||||
}
|
||||
|
||||
async updateSetting(id: string, value: Partial<UserSettings>) {
|
||||
const { keyVaults, ...res } = value;
|
||||
|
||||
// Encrypt keyVaults
|
||||
let encryptedKeyVaults: string | null = null;
|
||||
|
||||
if (keyVaults) {
|
||||
// TODO: better to add a validation
|
||||
const data = JSON.stringify(keyVaults);
|
||||
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
|
||||
encryptedKeyVaults = await gateKeeper.encrypt(data);
|
||||
}
|
||||
|
||||
const newValue = { ...res, keyVaults: encryptedKeyVaults };
|
||||
|
||||
// update or create user settings
|
||||
const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, id) });
|
||||
if (!settings) {
|
||||
await serverDB.insert(userSettings).values({ id, ...newValue });
|
||||
return;
|
||||
}
|
||||
|
||||
return serverDB.update(userSettings).set(newValue).where(eq(userSettings.id, id));
|
||||
}
|
||||
|
||||
async updatePreference(id: string, value: Partial<UserPreference>) {
|
||||
const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) });
|
||||
if (!user) return;
|
||||
|
||||
return serverDB
|
||||
.update(users)
|
||||
.set({ preference: merge(user.preference, value) })
|
||||
.where(eq(users.id, id));
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
954
src/database/server/modules/DataImporter/__tests__/index.test.ts
Normal file
954
src/database/server/modules/DataImporter/__tests__/index.test.ts
Normal file
|
|
@ -0,0 +1,954 @@
|
|||
// @vitest-environment node
|
||||
import { eq, inArray } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getTestDBInstance } from '@/database/server/core/dbForTest';
|
||||
import {
|
||||
agents,
|
||||
agentsToSessions,
|
||||
messages,
|
||||
sessionGroups,
|
||||
sessions,
|
||||
topics,
|
||||
users,
|
||||
} from '@/database/server/schemas/lobechat';
|
||||
import { CURRENT_CONFIG_VERSION } from '@/migrations';
|
||||
import { ImportResult } from '@/services/config';
|
||||
import { ImporterEntryData } from '@/types/importer';
|
||||
|
||||
import { DataImporter } from '../index';
|
||||
import mockImportData from './fixtures/messages.json';
|
||||
|
||||
let serverDB = await getTestDBInstance();
|
||||
|
||||
vi.mock('@/database/server/core/db', async () => ({
|
||||
get serverDB() {
|
||||
return serverDB;
|
||||
},
|
||||
}));
|
||||
|
||||
const userId = 'test-user-id';
|
||||
let importer: DataImporter;
|
||||
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
|
||||
// 创建测试数据
|
||||
await serverDB.transaction(async (tx) => {
|
||||
await tx.insert(users).values({ id: userId });
|
||||
});
|
||||
|
||||
importer = new DataImporter(userId);
|
||||
});
|
||||
|
||||
describe('DataImporter', () => {
|
||||
describe('import sessionGroups', () => {
|
||||
it('should import session groups and return correct result', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessionGroups: [
|
||||
{ id: 'group1', name: 'Group 1', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
{ id: 'group2', name: 'Group 2', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.sessionGroups.added).toBe(2);
|
||||
expect(result.sessionGroups.skips).toBe(0);
|
||||
expect(result.sessionGroups.errors).toBe(0);
|
||||
|
||||
const groups = await serverDB.query.sessionGroups.findMany({
|
||||
where: eq(sessionGroups.userId, userId),
|
||||
});
|
||||
expect(groups).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should skip existing session groups and return correct result', async () => {
|
||||
await serverDB
|
||||
.insert(sessionGroups)
|
||||
.values({ clientId: 'group1', name: 'Existing Group', userId })
|
||||
.execute();
|
||||
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessionGroups: [
|
||||
{ id: 'group1', name: 'Group 1', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
{ id: 'group2', name: 'Group 2', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.sessionGroups.added).toBe(1);
|
||||
expect(result.sessionGroups.skips).toBe(1);
|
||||
expect(result.sessionGroups.errors).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('import sessions', () => {
|
||||
it('should import sessions and return correct result', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session2',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 2',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.sessions.added).toBe(2);
|
||||
expect(result.sessions.skips).toBe(0);
|
||||
expect(result.sessions.errors).toBe(0);
|
||||
|
||||
const importedSessions = await serverDB.query.sessions.findMany({
|
||||
where: eq(sessions.userId, userId),
|
||||
});
|
||||
expect(importedSessions).toHaveLength(2);
|
||||
|
||||
const agentCount = await serverDB.query.agents.findMany({
|
||||
where: eq(agents.userId, userId),
|
||||
});
|
||||
|
||||
expect(agentCount.length).toBe(2);
|
||||
|
||||
const agentSessionCount = await serverDB.query.agentsToSessions.findMany();
|
||||
expect(agentSessionCount.length).toBe(2);
|
||||
});
|
||||
|
||||
it('should skip existing sessions and return correct result', async () => {
|
||||
await serverDB.insert(sessions).values({ clientId: 'session1', userId }).execute();
|
||||
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session2',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 2',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.sessions.added).toBe(1);
|
||||
expect(result.sessions.skips).toBe(1);
|
||||
expect(result.sessions.errors).toBe(0);
|
||||
});
|
||||
|
||||
it('should associate imported sessions with session groups', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessionGroups: [
|
||||
{ id: 'group1', name: 'Group 1', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
{ id: 'group2', name: 'Group 2', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
],
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
group: 'group1',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session2',
|
||||
group: 'group2',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 2',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session3',
|
||||
group: 'group4',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 3',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.sessionGroups.added).toBe(2);
|
||||
expect(result.sessionGroups.skips).toBe(0);
|
||||
|
||||
expect(result.sessions.added).toBe(3);
|
||||
expect(result.sessions.skips).toBe(0);
|
||||
|
||||
// session 1 should be associated with group 1
|
||||
const session1 = await serverDB.query.sessions.findFirst({
|
||||
where: eq(sessions.clientId, 'session1'),
|
||||
with: { group: true },
|
||||
});
|
||||
expect(session1?.group).toBeDefined();
|
||||
|
||||
// session 3 should not have group
|
||||
const session3 = await serverDB.query.sessions.findFirst({
|
||||
where: eq(sessions.clientId, 'session3'),
|
||||
with: { group: true },
|
||||
});
|
||||
expect(session3?.group).toBeNull();
|
||||
});
|
||||
|
||||
it('should create agents and associate them with imported sessions', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'Test Agent 1',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session2',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'def',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'Test Agent 2',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 2',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
await importer.importData(data);
|
||||
|
||||
// 验证是否为每个 session 创建了对应的 agent
|
||||
const agentCount = await serverDB.query.agents.findMany({
|
||||
where: eq(agents.userId, userId),
|
||||
});
|
||||
expect(agentCount).toHaveLength(2);
|
||||
|
||||
// 验证 agent 的属性是否正确设置
|
||||
const agent1 = await serverDB.query.agents.findFirst({
|
||||
where: eq(agents.systemRole, 'Test Agent 1'),
|
||||
});
|
||||
expect(agent1?.model).toBe('abc');
|
||||
|
||||
const agent2 = await serverDB.query.agents.findFirst({
|
||||
where: eq(agents.systemRole, 'Test Agent 2'),
|
||||
});
|
||||
expect(agent2?.model).toBe('def');
|
||||
|
||||
// 验证 agentsToSessions 关联是否正确建立
|
||||
const session1 = await serverDB.query.sessions.findFirst({
|
||||
where: eq(sessions.clientId, 'session1'),
|
||||
});
|
||||
const session1Agent = await serverDB.query.agentsToSessions.findFirst({
|
||||
where: eq(agentsToSessions.sessionId, session1?.id!),
|
||||
with: { agent: true },
|
||||
});
|
||||
|
||||
expect((session1Agent?.agent as any).systemRole).toBe('Test Agent 1');
|
||||
|
||||
const session2 = await serverDB.query.sessions.findFirst({
|
||||
where: eq(sessions.clientId, 'session2'),
|
||||
});
|
||||
const session2Agent = await serverDB.query.agentsToSessions.findFirst({
|
||||
where: eq(agentsToSessions.sessionId, session2?.id!),
|
||||
with: { agent: true },
|
||||
});
|
||||
|
||||
expect((session2Agent?.agent as any).systemRole).toBe('Test Agent 2');
|
||||
});
|
||||
|
||||
it('should not create duplicate agents for existing sessions', async () => {
|
||||
// 先导入一些 sessions
|
||||
await importer.importData({
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'Test Agent 1',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
});
|
||||
|
||||
// 再次导入相同的 sessions
|
||||
await importer.importData({
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'Test Agent 1',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
});
|
||||
|
||||
// 验证只创建了一个 agent
|
||||
const agentCount = await serverDB.query.agents.findMany({
|
||||
where: eq(agents.userId, userId),
|
||||
});
|
||||
expect(agentCount).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('import topics', () => {
|
||||
it('should import topics and return correct result', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
topics: [
|
||||
{
|
||||
id: 'topic1',
|
||||
title: 'Topic 1',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
},
|
||||
{
|
||||
id: 'topic2',
|
||||
title: 'Topic 2',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session2',
|
||||
},
|
||||
],
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'session2',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 2',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.topics.added).toBe(2);
|
||||
expect(result.topics.skips).toBe(0);
|
||||
expect(result.topics.errors).toBe(0);
|
||||
|
||||
const importedTopics = await serverDB.query.topics.findMany({
|
||||
where: eq(topics.userId, userId),
|
||||
});
|
||||
expect(importedTopics).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should skip existing topics and return correct result', async () => {
|
||||
await serverDB
|
||||
.insert(topics)
|
||||
.values({ clientId: 'topic1', title: 'Existing Topic', userId })
|
||||
.execute();
|
||||
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
topics: [
|
||||
{ id: 'topic1', title: 'Topic 1', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
{ id: 'topic2', title: 'Topic 2', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
expect(result.topics.added).toBe(1);
|
||||
expect(result.topics.skips).toBe(1);
|
||||
expect(result.topics.errors).toBe(0);
|
||||
});
|
||||
|
||||
it('should associate imported topics with sessions', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
topics: [
|
||||
{
|
||||
id: 'topic1',
|
||||
title: 'Topic 1',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
},
|
||||
{ id: 'topic2', title: 'Topic 2', createdAt: 1715186011586, updatedAt: 1715186015053 },
|
||||
],
|
||||
};
|
||||
|
||||
await importer.importData(data);
|
||||
|
||||
// topic1 should be associated with session1
|
||||
const [topic1] = await serverDB
|
||||
.select({ sessionClientId: sessions.clientId })
|
||||
.from(topics)
|
||||
.where(eq(topics.clientId, 'topic1'))
|
||||
.leftJoin(sessions, eq(topics.sessionId, sessions.id));
|
||||
|
||||
expect(topic1?.sessionClientId).toBe('session1');
|
||||
|
||||
// topic2 should not have session
|
||||
const topic2 = await serverDB.query.topics.findFirst({
|
||||
where: eq(topics.clientId, 'topic2'),
|
||||
with: { session: true },
|
||||
});
|
||||
expect(topic2?.session).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('import messages', () => {
|
||||
it('should import messages and return correct result', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
messages: [
|
||||
{
|
||||
id: 'msg1',
|
||||
content: 'Message 1',
|
||||
role: 'user',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
topicId: 'topic1',
|
||||
},
|
||||
{
|
||||
id: 'msg2',
|
||||
content: 'Message 2',
|
||||
role: 'assistant',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
topicId: 'topic1',
|
||||
parentId: 'msg1',
|
||||
},
|
||||
],
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
topics: [
|
||||
{
|
||||
id: 'topic1',
|
||||
title: 'Topic 1',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.messages.added).toBe(2);
|
||||
expect(result.messages.skips).toBe(0);
|
||||
expect(result.messages.errors).toBe(0);
|
||||
|
||||
const importedMessages = await serverDB.query.messages.findMany({
|
||||
where: eq(messages.userId, userId),
|
||||
});
|
||||
expect(importedMessages).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should skip existing messages and return correct result', async () => {
|
||||
await serverDB
|
||||
.insert(messages)
|
||||
.values({
|
||||
clientId: 'msg1',
|
||||
content: 'Existing Message',
|
||||
role: 'user',
|
||||
userId,
|
||||
})
|
||||
.execute();
|
||||
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
messages: [
|
||||
{
|
||||
id: 'msg1',
|
||||
content: 'Message 1',
|
||||
role: 'user',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
},
|
||||
{
|
||||
id: 'msg2',
|
||||
content: 'Message 2',
|
||||
role: 'assistant',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = await importer.importData(data);
|
||||
|
||||
expect(result.messages.added).toBe(1);
|
||||
expect(result.messages.skips).toBe(1);
|
||||
expect(result.messages.errors).toBe(0);
|
||||
});
|
||||
|
||||
it('should associate imported messages with sessions and topics', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
sessions: [
|
||||
{
|
||||
id: 'session1',
|
||||
createdAt: '2022-05-14T18:18:10.494Z',
|
||||
updatedAt: '2023-01-01',
|
||||
type: 'agent',
|
||||
config: {
|
||||
model: 'abc',
|
||||
chatConfig: {} as any,
|
||||
params: {},
|
||||
systemRole: 'abc',
|
||||
tts: {} as any,
|
||||
},
|
||||
meta: {
|
||||
title: 'Session 1',
|
||||
},
|
||||
},
|
||||
],
|
||||
topics: [
|
||||
{
|
||||
id: 'topic1',
|
||||
title: 'Topic 1',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
},
|
||||
],
|
||||
messages: [
|
||||
{
|
||||
id: 'msg1',
|
||||
content: 'Message 1',
|
||||
role: 'user',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
topicId: 'topic1',
|
||||
},
|
||||
{
|
||||
id: 'msg2',
|
||||
content: 'Message 2',
|
||||
role: 'assistant',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
sessionId: 'session1',
|
||||
topicId: 'topic1',
|
||||
parentId: 'msg1',
|
||||
},
|
||||
{
|
||||
id: 'msg3',
|
||||
content: 'Message 3',
|
||||
role: 'user',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
await importer.importData(data);
|
||||
|
||||
// msg1 and msg2 should be associated with session1 and topic1
|
||||
const [msg1, msg2] = await serverDB.query.messages.findMany({
|
||||
where: inArray(messages.clientId, ['msg1', 'msg2']),
|
||||
with: {
|
||||
session: true,
|
||||
topic: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(msg1.session?.clientId).toBe('session1');
|
||||
expect(msg1.topic?.clientId).toBe('topic1');
|
||||
expect(msg2.session?.clientId).toBe('session1');
|
||||
expect(msg2.topic?.clientId).toBe('topic1');
|
||||
|
||||
// msg3 should not have session and topic
|
||||
const msg3 = await serverDB.query.messages.findFirst({
|
||||
where: eq(messages.clientId, 'msg3'),
|
||||
with: {
|
||||
session: true,
|
||||
topic: true,
|
||||
},
|
||||
});
|
||||
expect(msg3?.session).toBeNull();
|
||||
expect(msg3?.topic).toBeNull();
|
||||
});
|
||||
|
||||
it('should set parentId for messages', async () => {
|
||||
const data: ImporterEntryData = {
|
||||
version: CURRENT_CONFIG_VERSION,
|
||||
messages: [
|
||||
{
|
||||
id: 'msg1',
|
||||
content: 'Message 1',
|
||||
role: 'user',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
},
|
||||
{
|
||||
id: 'msg2',
|
||||
content: 'Message 2',
|
||||
role: 'assistant',
|
||||
createdAt: 1715186011586,
|
||||
updatedAt: 1715186015053,
|
||||
parentId: 'msg1',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
await importer.importData(data);
|
||||
|
||||
const msg2 = await serverDB.query.messages.findFirst({
|
||||
where: eq(messages.clientId, 'msg2'),
|
||||
with: { parent: true },
|
||||
});
|
||||
|
||||
expect(msg2?.parent?.clientId).toBe('msg1');
|
||||
});
|
||||
|
||||
it('should import parentId Success', () => {});
|
||||
});
|
||||
|
||||
describe(
|
||||
'real world examples',
|
||||
() => {
|
||||
it('should import successfully', async () => {
|
||||
const result = await importer.importData({
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'hello',
|
||||
files: [],
|
||||
sessionId: 'inbox',
|
||||
topicId: '2wcF8yaS',
|
||||
createdAt: 1714236590340,
|
||||
id: 'DCG1G1EH',
|
||||
updatedAt: 1714236590340,
|
||||
extra: {},
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '...',
|
||||
parentId: 'DCG1G1EH',
|
||||
sessionId: 'inbox',
|
||||
topicId: '2wcF8yaS',
|
||||
createdAt: 1714236590441,
|
||||
id: 'gY41w5vQ',
|
||||
updatedAt: 1714236590518,
|
||||
error: {
|
||||
body: {
|
||||
error: {
|
||||
message: "model 'mixtral' not found, try pulling it first",
|
||||
name: 'ResponseError',
|
||||
status_code: 404,
|
||||
},
|
||||
provider: 'ollama',
|
||||
},
|
||||
message:
|
||||
'Error requesting Ollama service, please troubleshoot or retry based on the following information',
|
||||
type: 'OllamaBizError',
|
||||
},
|
||||
extra: { fromModel: 'mixtral', fromProvider: 'ollama' },
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'hello',
|
||||
files: [],
|
||||
sessionId: 'a5fefc88-f6c1-44fb-9e98-3d366b1ed589',
|
||||
topicId: 'v38snJ0A',
|
||||
createdAt: 1717080410895,
|
||||
id: 'qOIxEGEB',
|
||||
updatedAt: 1717080410895,
|
||||
extra: {},
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '...',
|
||||
parentId: 'qOIxEGEB',
|
||||
sessionId: 'a5fefc88-f6c1-44fb-9e98-3d366b1ed589',
|
||||
topicId: 'v38snJ0A',
|
||||
createdAt: 1717080410970,
|
||||
id: 'w28FcqY5',
|
||||
updatedAt: 1717080411485,
|
||||
error: {
|
||||
body: { error: { errorType: 'NoOpenAIAPIKey' }, provider: 'openai' },
|
||||
message: 'OpenAI API Key is empty, please add a custom OpenAI API Key',
|
||||
type: 'NoOpenAIAPIKey',
|
||||
},
|
||||
extra: { fromModel: 'gpt-3.5-turbo', fromProvider: 'openai' },
|
||||
},
|
||||
],
|
||||
sessionGroups: [
|
||||
{
|
||||
name: 'Writter',
|
||||
sort: 0,
|
||||
createdAt: 1706114744425,
|
||||
id: 'XlUbvOvL',
|
||||
updatedAt: 1706114747468,
|
||||
},
|
||||
],
|
||||
sessions: [
|
||||
{
|
||||
config: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
params: {
|
||||
frequency_penalty: 0,
|
||||
presence_penalty: 0,
|
||||
temperature: 0.6,
|
||||
top_p: 1,
|
||||
},
|
||||
plugins: [],
|
||||
systemRole:
|
||||
"You are a LobeChat technical operator 🍐🐊. You now need to write a developer's guide for LobeChat as a guide for them to develop LobeChat. This guide will include several sections, and you need to output the corresponding document content based on the user's input.\n\nHere is the technical introduction of LobeChat\n\n LobeChat is an AI conversation application built with the Next.js framework. It uses a series of technology stacks to implement various functions and features.\n\n\n ## Basic Technology Stack\n\n The core technology stack of LobeChat is as follows:\n\n - **Framework**: We chose [Next.js](https://nextjs.org/), a powerful React framework that provides key features such as server-side rendering, routing framework, and Router Handler for our project.\n - **Component Library**: We use [Ant Design (antd)](https://ant.design/) as the basic component library, and introduce [lobe-ui](https://github.com/lobehub/lobe-ui) as our business component library.\n - **State Management**: We use [zustand](https://github.com/pmndrs/zustand), a lightweight and easy-to-use state management library.\n - **Network Request**: We adopt [swr](https://swr.vercel.app/), a React Hooks library for data fetching.\n - **Routing**: We directly use the routing solution provided by [Next.js](https://nextjs.org/) itself.\n - **Internationalization**: We use [i18next](https://www.i18next.com/) to implement multi-language support for the application.\n - **Styling**: We use [antd-style](https://github.com/ant-design/antd-style), a CSS-in-JS library that is compatible with Ant Design.\n - **Unit Testing**: We use [vitest](https://github.com/vitejs/vitest) for unit testing.\n\n ## Folder Directory Structure\n\n The folder directory structure of LobeChat is as follows:\n\n \\`\\`\\`bash\n src\n ├── app # Main logic and state management related code of the application\n ├── components # Reusable UI components\n ├── config # Application configuration files, including client environment variables and server environment variables\n ├── const # Used to define constants, such as action types, route names, etc.\n ├── features # Function modules related to business functions, such as Agent settings, plugin development pop-ups, etc.\n ├── hooks # Custom utility Hooks reused throughout the application\n ├── layout # Layout components of the application, such as navigation bar, sidebar, etc.\n ├── locales # Language files for internationalization\n ├── services # Encapsulated backend service interfaces, such as HTTP requests\n ├── store # Zustand store for state management\n ├── types # TypeScript type definition files\n └── utils # Common utility functions\n \\`\\`\\`\n",
|
||||
tts: {
|
||||
showAllLocaleVoice: false,
|
||||
sttLocale: 'auto',
|
||||
ttsService: 'openai',
|
||||
voice: { openai: 'alloy' },
|
||||
},
|
||||
chatConfig: {
|
||||
autoCreateTopicThreshold: 2,
|
||||
displayMode: 'chat',
|
||||
enableAutoCreateTopic: true,
|
||||
historyCount: 1,
|
||||
},
|
||||
},
|
||||
group: 'XlUbvOvL',
|
||||
meta: {
|
||||
avatar: '📝',
|
||||
description:
|
||||
'LobeChat is an AI conversation application built with the Next.js framework. I will help you write the development documentation for LobeChat.',
|
||||
tags: [
|
||||
'Development Documentation',
|
||||
'Technical Introduction',
|
||||
'next-js',
|
||||
'react',
|
||||
'lobe-chat',
|
||||
],
|
||||
title: 'LobeChat Technical Documentation Expert',
|
||||
},
|
||||
type: 'agent',
|
||||
createdAt: '2024-01-24T16:43:12.164Z',
|
||||
id: 'a5fefc88-f6c1-44fb-9e98-3d366b1ed589',
|
||||
updatedAt: '2024-01-24T16:46:15.226Z',
|
||||
pinned: false,
|
||||
},
|
||||
],
|
||||
topics: [
|
||||
{
|
||||
title: 'Default Topic',
|
||||
sessionId: 'inbox',
|
||||
createdAt: 1714236590531,
|
||||
id: '2wcF8yaS',
|
||||
updatedAt: 1714236590531,
|
||||
},
|
||||
{
|
||||
title: 'Default Topic',
|
||||
sessionId: 'a5fefc88-f6c1-44fb-9e98-3d366b1ed589',
|
||||
createdAt: 1717080410825,
|
||||
id: 'v38snJ0A',
|
||||
updatedAt: 1717080410825,
|
||||
},
|
||||
],
|
||||
version: mockImportData.version,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
sessionGroups: { added: 1, errors: 0, skips: 0 },
|
||||
sessions: { added: 1, errors: 0, skips: 0 },
|
||||
topics: { added: 2, errors: 0, skips: 0 },
|
||||
messages: { added: 4, errors: 0, skips: 0 },
|
||||
});
|
||||
});
|
||||
|
||||
it('should import real world data', async () => {
|
||||
const result = await importer.importData({
|
||||
...(mockImportData.state as any),
|
||||
version: mockImportData.version,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
sessionGroups: { added: 2, errors: 0, skips: 0 },
|
||||
sessions: { added: 15, errors: 0, skips: 0 },
|
||||
topics: { added: 4, errors: 0, skips: 0 },
|
||||
messages: { added: 32, errors: 0, skips: 0 },
|
||||
});
|
||||
});
|
||||
},
|
||||
{ timeout: 15000 },
|
||||
);
|
||||
});
|
||||
333
src/database/server/modules/DataImporter/index.ts
Normal file
333
src/database/server/modules/DataImporter/index.ts
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
import { eq, inArray, sql } from 'drizzle-orm';
|
||||
import { and } from 'drizzle-orm/expressions';
|
||||
|
||||
import { serverDB } from '@/database/server';
|
||||
import {
|
||||
agents,
|
||||
agentsToSessions,
|
||||
messagePlugins,
|
||||
messageTranslates,
|
||||
messages,
|
||||
sessionGroups,
|
||||
sessions,
|
||||
topics,
|
||||
} from '@/database/server/schemas/lobechat';
|
||||
import { ImportResult } from '@/services/config';
|
||||
import { ImporterEntryData } from '@/types/importer';
|
||||
|
||||
export class DataImporter {
|
||||
private userId: string;
|
||||
|
||||
/**
|
||||
* The version of the importer that this module supports
|
||||
*/
|
||||
supportVersion = 7;
|
||||
|
||||
constructor(userId: string) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
importData = async (data: ImporterEntryData) => {
|
||||
if (data.version > this.supportVersion) throw new Error('Unsupported version');
|
||||
|
||||
let sessionGroupResult: ImportResult = { added: 0, errors: 0, skips: 0 };
|
||||
let sessionResult: ImportResult = { added: 0, errors: 0, skips: 0 };
|
||||
let topicResult: ImportResult = { added: 0, errors: 0, skips: 0 };
|
||||
let messageResult: ImportResult = { added: 0, errors: 0, skips: 0 };
|
||||
|
||||
let sessionGroupIdMap: Record<string, string> = {};
|
||||
let sessionIdMap: Record<string, string> = {};
|
||||
let topicIdMap: Record<string, string> = {};
|
||||
|
||||
// import sessionGroups
|
||||
await serverDB.transaction(async (trx) => {
|
||||
if (data.sessionGroups && data.sessionGroups.length > 0) {
|
||||
const query = await trx.query.sessionGroups.findMany({
|
||||
where: and(
|
||||
eq(sessionGroups.userId, this.userId),
|
||||
inArray(
|
||||
sessionGroups.clientId,
|
||||
data.sessionGroups.map(({ id }) => id),
|
||||
),
|
||||
),
|
||||
});
|
||||
|
||||
sessionGroupResult.skips = query.length;
|
||||
|
||||
const mapArray = await trx
|
||||
.insert(sessionGroups)
|
||||
.values(
|
||||
data.sessionGroups.map(({ id, createdAt, updatedAt, ...res }) => ({
|
||||
...res,
|
||||
clientId: id,
|
||||
createdAt: new Date(createdAt),
|
||||
updatedAt: new Date(updatedAt),
|
||||
userId: this.userId,
|
||||
})),
|
||||
)
|
||||
.onConflictDoUpdate({
|
||||
set: { updatedAt: new Date() },
|
||||
target: [sessionGroups.clientId, sessionGroups.userId],
|
||||
})
|
||||
.returning({ clientId: sessionGroups.clientId, id: sessionGroups.id })
|
||||
.execute();
|
||||
|
||||
sessionGroupResult.added = mapArray.length - query.length;
|
||||
|
||||
sessionGroupIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));
|
||||
}
|
||||
|
||||
// import sessions
|
||||
if (data.sessions && data.sessions.length > 0) {
|
||||
const query = await trx.query.sessions.findMany({
|
||||
where: and(
|
||||
eq(sessions.userId, this.userId),
|
||||
inArray(
|
||||
sessions.clientId,
|
||||
data.sessions.map(({ id }) => id),
|
||||
),
|
||||
),
|
||||
});
|
||||
|
||||
sessionResult.skips = query.length;
|
||||
|
||||
const mapArray = await trx
|
||||
.insert(sessions)
|
||||
.values(
|
||||
data.sessions.map(({ id, createdAt, updatedAt, group, ...res }) => ({
|
||||
...res,
|
||||
clientId: id,
|
||||
createdAt: new Date(createdAt),
|
||||
groupId: group ? sessionGroupIdMap[group] : null,
|
||||
updatedAt: new Date(updatedAt),
|
||||
userId: this.userId,
|
||||
})),
|
||||
)
|
||||
.onConflictDoUpdate({
|
||||
set: { updatedAt: new Date() },
|
||||
target: [sessions.clientId, sessions.userId],
|
||||
})
|
||||
.returning({ clientId: sessions.clientId, id: sessions.id })
|
||||
.execute();
|
||||
|
||||
// get the session client-server id map
|
||||
sessionIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));
|
||||
|
||||
// update added count
|
||||
sessionResult.added = mapArray.length - query.length;
|
||||
|
||||
const shouldInsertSessionAgents = data.sessions
|
||||
// filter out existing session, only insert new ones
|
||||
.filter((s) => query.every((q) => q.clientId !== s.id));
|
||||
|
||||
// 只有当需要有新的 session 时,才会插入 agent
|
||||
if (shouldInsertSessionAgents.length > 0) {
|
||||
const agentMapArray = await trx
|
||||
.insert(agents)
|
||||
.values(
|
||||
shouldInsertSessionAgents.map(({ config, meta }) => ({
|
||||
...config,
|
||||
...meta,
|
||||
userId: this.userId,
|
||||
})),
|
||||
)
|
||||
.returning({ id: agents.id })
|
||||
.execute();
|
||||
|
||||
await trx
|
||||
.insert(agentsToSessions)
|
||||
.values(
|
||||
shouldInsertSessionAgents.map(({ id }, index) => ({
|
||||
agentId: agentMapArray[index].id,
|
||||
sessionId: sessionIdMap[id],
|
||||
})),
|
||||
)
|
||||
.execute();
|
||||
}
|
||||
}
|
||||
|
||||
// import topics
|
||||
if (data.topics && data.topics.length > 0) {
|
||||
const skipQuery = await trx.query.topics.findMany({
|
||||
where: and(
|
||||
eq(topics.userId, this.userId),
|
||||
inArray(
|
||||
topics.clientId,
|
||||
data.topics.map(({ id }) => id),
|
||||
),
|
||||
),
|
||||
});
|
||||
topicResult.skips = skipQuery.length;
|
||||
|
||||
const mapArray = await trx
|
||||
.insert(topics)
|
||||
.values(
|
||||
data.topics.map(({ id, createdAt, updatedAt, sessionId, ...res }) => ({
|
||||
...res,
|
||||
clientId: id,
|
||||
createdAt: new Date(createdAt),
|
||||
sessionId: sessionId ? sessionIdMap[sessionId] : null,
|
||||
updatedAt: new Date(updatedAt),
|
||||
userId: this.userId,
|
||||
})),
|
||||
)
|
||||
.onConflictDoUpdate({
|
||||
set: { updatedAt: new Date() },
|
||||
target: [topics.clientId, topics.userId],
|
||||
})
|
||||
.returning({ clientId: topics.clientId, id: topics.id })
|
||||
.execute();
|
||||
|
||||
topicIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));
|
||||
|
||||
topicResult.added = mapArray.length - skipQuery.length;
|
||||
}
|
||||
|
||||
// import messages
|
||||
if (data.messages && data.messages.length > 0) {
|
||||
// 1. find skip ones
|
||||
console.time('find messages');
|
||||
const skipQuery = await trx.query.messages.findMany({
|
||||
where: and(
|
||||
eq(messages.userId, this.userId),
|
||||
inArray(
|
||||
messages.clientId,
|
||||
data.messages.map(({ id }) => id),
|
||||
),
|
||||
),
|
||||
});
|
||||
console.timeEnd('find messages');
|
||||
|
||||
messageResult.skips = skipQuery.length;
|
||||
|
||||
// filter out existing messages, only insert new ones
|
||||
const shouldInsertMessages = data.messages.filter((s) =>
|
||||
skipQuery.every((q) => q.clientId !== s.id),
|
||||
);
|
||||
|
||||
// 2. insert messages
|
||||
if (shouldInsertMessages.length > 0) {
|
||||
const inertValues = shouldInsertMessages.map(
|
||||
({ id, extra, createdAt, updatedAt, sessionId, topicId, ...res }) => ({
|
||||
...res,
|
||||
clientId: id,
|
||||
createdAt: new Date(createdAt),
|
||||
model: extra?.fromModel,
|
||||
parentId: null,
|
||||
provider: extra?.fromProvider,
|
||||
sessionId: sessionId ? sessionIdMap[sessionId] : null,
|
||||
topicId: topicId ? topicIdMap[topicId] : null, // 暂时设为 NULL
|
||||
updatedAt: new Date(updatedAt),
|
||||
userId: this.userId,
|
||||
}),
|
||||
);
|
||||
|
||||
console.time('insert messages');
|
||||
const BATCH_SIZE = 100; // 每批次插入的记录数
|
||||
|
||||
for (let i = 0; i < inertValues.length; i += BATCH_SIZE) {
|
||||
const batch = inertValues.slice(i, i + BATCH_SIZE);
|
||||
await trx.insert(messages).values(batch).execute();
|
||||
}
|
||||
|
||||
console.timeEnd('insert messages');
|
||||
|
||||
const messageIdArray = await trx
|
||||
.select({ clientId: messages.clientId, id: messages.id })
|
||||
.from(messages)
|
||||
.where(
|
||||
and(
|
||||
eq(messages.userId, this.userId),
|
||||
inArray(
|
||||
messages.clientId,
|
||||
data.messages.map(({ id }) => id),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
const messageIdMap = Object.fromEntries(
|
||||
messageIdArray.map(({ clientId, id }) => [clientId, id]),
|
||||
);
|
||||
|
||||
// 3. update parentId for messages
|
||||
console.time('execute updates parentId');
|
||||
const parentIdUpdates = shouldInsertMessages
|
||||
.filter((msg) => msg.parentId) // 只处理有 parentId 的消息
|
||||
.map((msg) => {
|
||||
if (messageIdMap[msg.parentId as string])
|
||||
return sql`WHEN ${messages.clientId} = ${msg.id} THEN ${messageIdMap[msg.parentId as string]} `;
|
||||
|
||||
return undefined;
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
if (parentIdUpdates.length > 0) {
|
||||
const updateQuery = trx
|
||||
.update(messages)
|
||||
.set({
|
||||
parentId: sql`CASE ${sql.join(parentIdUpdates)} END`,
|
||||
})
|
||||
.where(
|
||||
inArray(
|
||||
messages.clientId,
|
||||
data.messages.map((msg) => msg.id),
|
||||
),
|
||||
);
|
||||
|
||||
// if needed, you can print the sql and params
|
||||
// const SQL = updateQuery.toSQL();
|
||||
// console.log('sql:', SQL.sql);
|
||||
// console.log('params:', SQL.params);
|
||||
|
||||
await updateQuery.execute();
|
||||
}
|
||||
console.timeEnd('execute updates parentId');
|
||||
|
||||
// 4. insert message plugins
|
||||
const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin);
|
||||
if (pluginInserts.length > 0) {
|
||||
await trx
|
||||
.insert(messagePlugins)
|
||||
.values(
|
||||
pluginInserts.map((msg) => ({
|
||||
apiName: msg.plugin?.apiName,
|
||||
arguments: msg.plugin?.arguments,
|
||||
id: messageIdMap[msg.id],
|
||||
identifier: msg.plugin?.identifier,
|
||||
state: msg.pluginState,
|
||||
toolCallId: msg.tool_call_id,
|
||||
type: msg.plugin?.type,
|
||||
})),
|
||||
)
|
||||
.execute();
|
||||
}
|
||||
|
||||
// 5. insert message translate
|
||||
const translateInserts = shouldInsertMessages.filter((msg) => msg.extra?.translate);
|
||||
if (translateInserts.length > 0) {
|
||||
await trx
|
||||
.insert(messageTranslates)
|
||||
.values(
|
||||
translateInserts.map((msg) => ({
|
||||
id: messageIdMap[msg.id],
|
||||
...msg.extra?.translate,
|
||||
})),
|
||||
)
|
||||
.execute();
|
||||
}
|
||||
|
||||
// TODO: 未来需要处理 TTS 和图片的插入 (目前存在 file 的部分,不方便处理)
|
||||
}
|
||||
|
||||
messageResult.added = shouldInsertMessages.length;
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
messages: messageResult,
|
||||
sessionGroups: sessionGroupResult,
|
||||
sessions: sessionResult,
|
||||
topics: topicResult,
|
||||
};
|
||||
};
|
||||
}
|
||||
15
src/database/server/schemas/_id.ts
Normal file
15
src/database/server/schemas/_id.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// refs: https://unkey.dev/blog/uuid-ux
|
||||
|
||||
// If I have 100 million users, each generating up to 1 million messages.
|
||||
// Then the total number of IDs that need to be generated: 100 million × 1 million = 10^14 (100 trillion)
|
||||
// 11-digit Nano ID: 36^11 ≈ 1.3 × 10^17 (130 trillion trillion)
|
||||
|
||||
export const FILE_ID_LENGTH = 19; // 5 prefix + 14 random, e.g. file_ydGX5gmaxL32fh
|
||||
|
||||
export const MESSAGE_ID_LENGTH = 18; // 4 prefix + 14 random, e.g. msg_GX5ymaxL3d2ds2
|
||||
|
||||
export const SESSION_ID_LENGTH = 16; // 4 prefix + 12 random, e.g. ssn_GX5y3d2dmaxL
|
||||
|
||||
export const TOPIC_ID_LENGTH = 16; // 4 prefix + 12 random, e.g. tpc_GX5ymd7axL3y
|
||||
|
||||
export const USER_ID_LENGTH = 14; // 4 prefix + 10 random, e.g. user_GXyxLmd75a
|
||||
601
src/database/server/schemas/lobechat.ts
Normal file
601
src/database/server/schemas/lobechat.ts
Normal file
|
|
@ -0,0 +1,601 @@
|
|||
/* eslint-disable sort-keys-fix/sort-keys-fix */
|
||||
import { LobeChatPluginManifest } from '@lobehub/chat-plugin-sdk';
|
||||
import { relations } from 'drizzle-orm';
|
||||
import {
|
||||
boolean,
|
||||
index,
|
||||
integer,
|
||||
jsonb,
|
||||
pgTable,
|
||||
primaryKey,
|
||||
serial,
|
||||
text,
|
||||
timestamp,
|
||||
unique,
|
||||
uniqueIndex,
|
||||
varchar,
|
||||
} from 'drizzle-orm/pg-core';
|
||||
import { createInsertSchema } from 'drizzle-zod';
|
||||
|
||||
import { DEFAULT_PREFERENCE } from '@/const/user';
|
||||
import { LobeAgentChatConfig, LobeAgentTTSConfig } from '@/types/agent';
|
||||
import { CustomPluginParams } from '@/types/tool/plugin';
|
||||
|
||||
import { idGenerator, randomSlug } from '../utils/idGenerator';
|
||||
|
||||
const timestamptz = (name: string) => timestamp(name, { withTimezone: true });
|
||||
|
||||
const createdAt = () => timestamptz('created_at').notNull().defaultNow();
|
||||
const updatedAt = () => timestamptz('updated_at').notNull().defaultNow();
|
||||
|
||||
/**
|
||||
* This table stores users. Users are created in Clerk, then Clerk calls a
|
||||
* webhook at /api/webhook/clerk to inform this application a user was created.
|
||||
*/
|
||||
export const users = pgTable('users', {
|
||||
// The ID will be the user's ID from Clerk
|
||||
id: text('id').primaryKey().notNull(),
|
||||
username: text('username').unique(),
|
||||
email: text('email'),
|
||||
|
||||
avatar: text('avatar'),
|
||||
phone: text('phone'),
|
||||
firstName: text('first_name'),
|
||||
lastName: text('last_name'),
|
||||
|
||||
isOnboarded: boolean('is_onboarded').default(false),
|
||||
// Time user was created in Clerk
|
||||
clerkCreatedAt: timestamptz('clerk_created_at'),
|
||||
|
||||
preference: jsonb('preference').default(DEFAULT_PREFERENCE),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
|
||||
key: text('key'),
|
||||
});
|
||||
|
||||
export const userSettings = pgTable('user_settings', {
|
||||
id: text('id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.primaryKey(),
|
||||
|
||||
tts: jsonb('tts'),
|
||||
keyVaults: text('key_vaults'),
|
||||
general: jsonb('general'),
|
||||
languageModel: jsonb('language_model'),
|
||||
systemAgent: jsonb('system_agent'),
|
||||
defaultAgent: jsonb('default_agent'),
|
||||
tool: jsonb('tool'),
|
||||
});
|
||||
|
||||
export const tags = pgTable('tags', {
|
||||
id: serial('id').primaryKey(),
|
||||
slug: text('slug').notNull().unique(),
|
||||
name: text('name'),
|
||||
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
});
|
||||
|
||||
export type NewUser = typeof users.$inferInsert;
|
||||
export type UserItem = typeof users.$inferSelect;
|
||||
|
||||
export const files = pgTable('files', {
|
||||
id: text('id')
|
||||
.$defaultFn(() => idGenerator('files'))
|
||||
.primaryKey(),
|
||||
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
fileType: varchar('file_type', { length: 255 }).notNull(),
|
||||
name: text('name').notNull(),
|
||||
size: integer('size').notNull(),
|
||||
url: text('url').notNull(),
|
||||
|
||||
metadata: jsonb('metadata'),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
});
|
||||
|
||||
export type NewFile = typeof files.$inferInsert;
|
||||
export type FileItem = typeof files.$inferSelect;
|
||||
|
||||
export const plugins = pgTable('plugins', {
|
||||
id: serial('id').primaryKey(),
|
||||
identifier: text('identifier').notNull().unique(),
|
||||
|
||||
title: text('title').notNull(),
|
||||
description: text('description'),
|
||||
avatar: text('avatar'),
|
||||
author: text('author'),
|
||||
|
||||
manifest: text('manifest').notNull(),
|
||||
locale: text('locale').notNull(),
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
});
|
||||
|
||||
export const installedPlugins = pgTable(
|
||||
'user_installed_plugins',
|
||||
{
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
|
||||
identifier: text('identifier').notNull(),
|
||||
type: text('type', { enum: ['plugin', 'customPlugin'] }).notNull(),
|
||||
manifest: jsonb('manifest').$type<LobeChatPluginManifest>(),
|
||||
settings: jsonb('settings'),
|
||||
customParams: jsonb('custom_params').$type<CustomPluginParams>(),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
},
|
||||
(self) => ({
|
||||
id: primaryKey({ columns: [self.userId, self.identifier] }),
|
||||
}),
|
||||
);
|
||||
|
||||
export type NewInstalledPlugin = typeof installedPlugins.$inferInsert;
|
||||
export type InstalledPluginItem = typeof installedPlugins.$inferSelect;
|
||||
|
||||
export const pluginsTags = pgTable(
|
||||
'plugins_tags',
|
||||
{
|
||||
pluginId: integer('plugin_id')
|
||||
.notNull()
|
||||
.references(() => plugins.id, { onDelete: 'cascade' }),
|
||||
tagId: integer('tag_id')
|
||||
.notNull()
|
||||
.references(() => tags.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.pluginId, t.tagId] }),
|
||||
}),
|
||||
);
|
||||
|
||||
// ======= agents ======= //
|
||||
export const agents = pgTable('agents', {
|
||||
id: text('id')
|
||||
.primaryKey()
|
||||
.$defaultFn(() => idGenerator('agents'))
|
||||
.notNull(),
|
||||
slug: varchar('slug', { length: 100 })
|
||||
.$defaultFn(() => randomSlug())
|
||||
.unique(),
|
||||
title: text('title'),
|
||||
description: text('description'),
|
||||
tags: jsonb('tags').$type<string[]>().default([]),
|
||||
avatar: text('avatar'),
|
||||
backgroundColor: text('background_color'),
|
||||
|
||||
plugins: jsonb('plugins').$type<string[]>().default([]),
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
|
||||
chatConfig: jsonb('chat_config').$type<LobeAgentChatConfig>(),
|
||||
|
||||
fewShots: jsonb('few_shots'),
|
||||
model: text('model'),
|
||||
params: jsonb('params').default({}),
|
||||
provider: text('provider'),
|
||||
systemRole: text('system_role'),
|
||||
tts: jsonb('tts').$type<LobeAgentTTSConfig>(),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
});
|
||||
|
||||
export const agentsTags = pgTable(
|
||||
'agents_tags',
|
||||
{
|
||||
agentId: text('agent_id')
|
||||
.notNull()
|
||||
.references(() => agents.id, { onDelete: 'cascade' }),
|
||||
tagId: integer('tag_id')
|
||||
.notNull()
|
||||
.references(() => tags.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.agentId, t.tagId] }),
|
||||
}),
|
||||
);
|
||||
export const insertAgentSchema = createInsertSchema(agents);
|
||||
|
||||
export type NewAgent = typeof agents.$inferInsert;
|
||||
export type AgentItem = typeof agents.$inferSelect;
|
||||
|
||||
// ======= market ======= //
|
||||
|
||||
export const market = pgTable('market', {
|
||||
id: serial('id').primaryKey(),
|
||||
|
||||
agentId: text('agent_id').references(() => agents.id, { onDelete: 'cascade' }),
|
||||
pluginId: integer('plugin_id').references(() => plugins.id, { onDelete: 'cascade' }),
|
||||
|
||||
type: text('type', { enum: ['plugin', 'model', 'agent', 'group'] }).notNull(),
|
||||
|
||||
view: integer('view').default(0),
|
||||
like: integer('like').default(0),
|
||||
used: integer('used').default(0),
|
||||
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
});
|
||||
|
||||
// ======= sessionGroups ======= //
|
||||
|
||||
export const sessionGroups = pgTable(
|
||||
'session_groups',
|
||||
{
|
||||
id: text('id')
|
||||
.$defaultFn(() => idGenerator('sessionGroups'))
|
||||
.primaryKey(),
|
||||
name: text('name').notNull(),
|
||||
sort: integer('sort'),
|
||||
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
|
||||
clientId: text('client_id'),
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
},
|
||||
(table) => ({
|
||||
clientIdUnique: unique('session_group_client_id_user_unique').on(table.clientId, table.userId),
|
||||
}),
|
||||
);
|
||||
|
||||
export const insertSessionGroupSchema = createInsertSchema(sessionGroups);
|
||||
|
||||
export type NewSessionGroup = typeof sessionGroups.$inferInsert;
|
||||
export type SessionGroupItem = typeof sessionGroups.$inferSelect;
|
||||
|
||||
// ======= sessions ======= //
|
||||
|
||||
export const sessions = pgTable(
|
||||
'sessions',
|
||||
{
|
||||
id: text('id')
|
||||
.$defaultFn(() => idGenerator('sessions'))
|
||||
.primaryKey(),
|
||||
slug: varchar('slug', { length: 100 })
|
||||
.notNull()
|
||||
.$defaultFn(() => randomSlug()),
|
||||
title: text('title'),
|
||||
description: text('description'),
|
||||
avatar: text('avatar'),
|
||||
backgroundColor: text('background_color'),
|
||||
|
||||
type: text('type', { enum: ['agent', 'group'] }).default('agent'),
|
||||
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
groupId: text('group_id').references(() => sessionGroups.id, { onDelete: 'set null' }),
|
||||
clientId: text('client_id'),
|
||||
pinned: boolean('pinned').default(false),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
},
|
||||
(t) => ({
|
||||
slugUserIdUnique: uniqueIndex('slug_user_id_unique').on(t.slug, t.userId),
|
||||
|
||||
clientIdUnique: unique('sessions_client_id_user_id_unique').on(t.clientId, t.userId),
|
||||
}),
|
||||
);
|
||||
|
||||
export const insertSessionSchema = createInsertSchema(sessions);
|
||||
// export const selectSessionSchema = createSelectSchema(sessions);
|
||||
|
||||
export type NewSession = typeof sessions.$inferInsert;
|
||||
export type SessionItem = typeof sessions.$inferSelect;
|
||||
|
||||
// ======== topics ======= //
|
||||
export const topics = pgTable(
|
||||
'topics',
|
||||
{
|
||||
id: text('id')
|
||||
.$defaultFn(() => idGenerator('topics'))
|
||||
.primaryKey(),
|
||||
sessionId: text('session_id').references(() => sessions.id, { onDelete: 'cascade' }),
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
favorite: boolean('favorite').default(false),
|
||||
title: text('title'),
|
||||
clientId: text('client_id'),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
},
|
||||
(t) => ({
|
||||
clientIdUnique: unique('topic_client_id_user_id_unique').on(t.clientId, t.userId),
|
||||
}),
|
||||
);
|
||||
|
||||
export type NewTopic = typeof topics.$inferInsert;
|
||||
export type TopicItem = typeof topics.$inferSelect;
|
||||
|
||||
// ======== messages ======= //
|
||||
// @ts-ignore
|
||||
export const messages = pgTable(
|
||||
'messages',
|
||||
{
|
||||
id: text('id')
|
||||
.$defaultFn(() => idGenerator('messages'))
|
||||
.primaryKey(),
|
||||
|
||||
role: text('role', { enum: ['user', 'system', 'assistant', 'tool'] }).notNull(),
|
||||
content: text('content'),
|
||||
|
||||
model: text('model'),
|
||||
provider: text('provider'),
|
||||
|
||||
favorite: boolean('favorite').default(false),
|
||||
error: jsonb('error'),
|
||||
|
||||
tools: jsonb('tools'),
|
||||
|
||||
traceId: text('trace_id'),
|
||||
observationId: text('observation_id'),
|
||||
|
||||
clientId: text('client_id'),
|
||||
|
||||
// foreign keys
|
||||
userId: text('user_id')
|
||||
.references(() => users.id, { onDelete: 'cascade' })
|
||||
.notNull(),
|
||||
sessionId: text('session_id').references(() => sessions.id, { onDelete: 'cascade' }),
|
||||
topicId: text('topic_id').references(() => topics.id, { onDelete: 'cascade' }),
|
||||
parentId: text('parent_id').references(() => messages.id, { onDelete: 'set null' }),
|
||||
quotaId: text('quota_id').references(() => messages.id, { onDelete: 'set null' }),
|
||||
|
||||
// used for group chat
|
||||
agentId: text('agent_id').references(() => agents.id, { onDelete: 'set null' }),
|
||||
|
||||
createdAt: createdAt(),
|
||||
updatedAt: updatedAt(),
|
||||
},
|
||||
(table) => ({
|
||||
createdAtIdx: index('messages_created_at_idx').on(table.createdAt),
|
||||
messageClientIdUnique: uniqueIndex('message_client_id_user_unique').on(
|
||||
table.clientId,
|
||||
table.userId,
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
export type NewMessage = typeof messages.$inferInsert;
|
||||
export type MessageItem = typeof messages.$inferSelect;
|
||||
|
||||
export const messagePlugins = pgTable('message_plugins', {
|
||||
id: text('id')
|
||||
.references(() => messages.id, { onDelete: 'cascade' })
|
||||
.primaryKey(),
|
||||
|
||||
toolCallId: text('tool_call_id'),
|
||||
type: text('type', {
|
||||
enum: ['default', 'markdown', 'standalone', 'builtin'],
|
||||
}).default('default'),
|
||||
|
||||
apiName: text('api_name'),
|
||||
arguments: text('arguments'),
|
||||
identifier: text('identifier'),
|
||||
state: jsonb('state'),
|
||||
error: jsonb('error'),
|
||||
});
|
||||
|
||||
export const messageTTS = pgTable('message_tts', {
|
||||
id: text('id')
|
||||
.references(() => messages.id, { onDelete: 'cascade' })
|
||||
.primaryKey(),
|
||||
contentMd5: text('content_md5'),
|
||||
fileId: text('file_id').references(() => files.id, { onDelete: 'cascade' }),
|
||||
voice: text('voice'),
|
||||
});
|
||||
|
||||
export const messageTranslates = pgTable('message_translates', {
|
||||
id: text('id')
|
||||
.references(() => messages.id, { onDelete: 'cascade' })
|
||||
.primaryKey(),
|
||||
content: text('content'),
|
||||
from: text('from'),
|
||||
to: text('to'),
|
||||
});
|
||||
|
||||
export const agentsToSessions = pgTable(
|
||||
'agents_to_sessions',
|
||||
{
|
||||
agentId: text('agent_id')
|
||||
.notNull()
|
||||
.references(() => agents.id, { onDelete: 'cascade' }),
|
||||
sessionId: text('session_id')
|
||||
.notNull()
|
||||
.references(() => sessions.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.agentId, t.sessionId] }),
|
||||
}),
|
||||
);
|
||||
|
||||
export const filesToMessages = pgTable(
|
||||
'files_to_messages',
|
||||
{
|
||||
fileId: text('file_id')
|
||||
.notNull()
|
||||
.references(() => files.id, { onDelete: 'cascade' }),
|
||||
messageId: text('message_id')
|
||||
.notNull()
|
||||
.references(() => messages.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.fileId, t.messageId] }),
|
||||
}),
|
||||
);
|
||||
|
||||
export const filesToSessions = pgTable(
|
||||
'files_to_sessions',
|
||||
{
|
||||
fileId: text('file_id')
|
||||
.notNull()
|
||||
.references(() => files.id, { onDelete: 'cascade' }),
|
||||
sessionId: text('session_id')
|
||||
.notNull()
|
||||
.references(() => sessions.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.fileId, t.sessionId] }),
|
||||
}),
|
||||
);
|
||||
|
||||
export const filesToAgents = pgTable(
|
||||
'files_to_agents',
|
||||
{
|
||||
fileId: text('file_id')
|
||||
.notNull()
|
||||
.references(() => files.id, { onDelete: 'cascade' }),
|
||||
agentId: text('agent_id')
|
||||
.notNull()
|
||||
.references(() => agents.id, { onDelete: 'cascade' }),
|
||||
},
|
||||
(t) => ({
|
||||
pk: primaryKey({ columns: [t.fileId, t.agentId] }),
|
||||
}),
|
||||
);
|
||||
|
||||
export const filesRelations = relations(files, ({ many }) => ({
|
||||
filesToMessages: many(filesToMessages),
|
||||
filesToSessions: many(filesToSessions),
|
||||
filesToAgents: many(filesToAgents),
|
||||
}));
|
||||
|
||||
export const topicRelations = relations(topics, ({ one }) => ({
|
||||
session: one(sessions, {
|
||||
fields: [topics.sessionId],
|
||||
references: [sessions.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const pluginsRelations = relations(plugins, ({ many }) => ({
|
||||
pluginsTags: many(pluginsTags),
|
||||
}));
|
||||
|
||||
export const pluginsTagsRelations = relations(pluginsTags, ({ one }) => ({
|
||||
plugin: one(plugins, {
|
||||
fields: [pluginsTags.pluginId],
|
||||
references: [plugins.id],
|
||||
}),
|
||||
tag: one(tags, {
|
||||
fields: [pluginsTags.tagId],
|
||||
references: [tags.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const tagsRelations = relations(tags, ({ many }) => ({
|
||||
agentsTags: many(agentsTags),
|
||||
pluginsTags: many(pluginsTags),
|
||||
}));
|
||||
|
||||
export const messagesRelations = relations(messages, ({ many, one }) => ({
|
||||
filesToMessages: many(filesToMessages),
|
||||
|
||||
session: one(sessions, {
|
||||
fields: [messages.sessionId],
|
||||
references: [sessions.id],
|
||||
}),
|
||||
|
||||
parent: one(messages, {
|
||||
fields: [messages.parentId],
|
||||
references: [messages.id],
|
||||
}),
|
||||
|
||||
topic: one(topics, {
|
||||
fields: [messages.topicId],
|
||||
references: [topics.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const agentsRelations = relations(agents, ({ many }) => ({
|
||||
agentsToSessions: many(agentsToSessions),
|
||||
filesToAgents: many(filesToAgents),
|
||||
agentsTags: many(agentsTags),
|
||||
}));
|
||||
|
||||
export const agentsToSessionsRelations = relations(agentsToSessions, ({ one }) => ({
|
||||
session: one(sessions, {
|
||||
fields: [agentsToSessions.sessionId],
|
||||
references: [sessions.id],
|
||||
}),
|
||||
agent: one(agents, {
|
||||
fields: [agentsToSessions.agentId],
|
||||
references: [agents.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const filesToAgentsRelations = relations(filesToAgents, ({ one }) => ({
|
||||
agent: one(agents, {
|
||||
fields: [filesToAgents.agentId],
|
||||
references: [agents.id],
|
||||
}),
|
||||
file: one(files, {
|
||||
fields: [filesToAgents.fileId],
|
||||
references: [files.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const filesToMessagesRelations = relations(filesToMessages, ({ one }) => ({
|
||||
file: one(files, {
|
||||
fields: [filesToMessages.fileId],
|
||||
references: [files.id],
|
||||
}),
|
||||
message: one(messages, {
|
||||
fields: [filesToMessages.messageId],
|
||||
references: [messages.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const filesToSessionsRelations = relations(filesToSessions, ({ one }) => ({
|
||||
file: one(files, {
|
||||
fields: [filesToSessions.fileId],
|
||||
references: [files.id],
|
||||
}),
|
||||
session: one(sessions, {
|
||||
fields: [filesToSessions.sessionId],
|
||||
references: [sessions.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const agentsTagsRelations = relations(agentsTags, ({ one }) => ({
|
||||
agent: one(agents, {
|
||||
fields: [agentsTags.agentId],
|
||||
references: [agents.id],
|
||||
}),
|
||||
tag: one(tags, {
|
||||
fields: [agentsTags.tagId],
|
||||
references: [tags.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const sessionsRelations = relations(sessions, ({ many, one }) => ({
|
||||
filesToSessions: many(filesToSessions),
|
||||
agentsToSessions: many(agentsToSessions),
|
||||
group: one(sessionGroups, {
|
||||
fields: [sessions.groupId],
|
||||
references: [sessionGroups.id],
|
||||
}),
|
||||
}));
|
||||
39
src/database/server/utils/idGenerator.test.ts
Normal file
39
src/database/server/utils/idGenerator.test.ts
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { idGenerator } from './idGenerator';
|
||||
|
||||
describe('idGenerator', () => {
|
||||
it('should generate an ID with the correct prefix and length', () => {
|
||||
const fileId = idGenerator('files');
|
||||
expect(fileId).toMatch(/^file_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const messageId = idGenerator('messages');
|
||||
expect(messageId).toMatch(/^msg_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const pluginId = idGenerator('plugins');
|
||||
expect(pluginId).toMatch(/^plg_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const sessionGroupId = idGenerator('sessionGroups');
|
||||
expect(sessionGroupId).toMatch(/^sg_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const sessionId = idGenerator('sessions');
|
||||
expect(sessionId).toMatch(/^ssn_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const topicId = idGenerator('topics');
|
||||
expect(topicId).toMatch(/^tpc_[a-zA-Z0-9]{12}$/);
|
||||
|
||||
const userId = idGenerator('user');
|
||||
expect(userId).toMatch(/^user_[a-zA-Z0-9]{12}$/);
|
||||
});
|
||||
|
||||
it('should generate an ID with custom size', () => {
|
||||
const fileId = idGenerator('files', 12);
|
||||
expect(fileId).toMatch(/^file_[a-zA-Z0-9]{12}$/);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid namespace', () => {
|
||||
expect(() => idGenerator('invalid' as any)).toThrowError(
|
||||
'Invalid namespace: invalid, please check your code.',
|
||||
);
|
||||
});
|
||||
});
|
||||
26
src/database/server/utils/idGenerator.ts
Normal file
26
src/database/server/utils/idGenerator.ts
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import { generate } from 'random-words';
|
||||
|
||||
import { createNanoId } from '@/utils/uuid';
|
||||
|
||||
const prefixes = {
|
||||
agents: 'agt',
|
||||
files: 'file',
|
||||
messages: 'msg',
|
||||
plugins: 'plg',
|
||||
sessionGroups: 'sg',
|
||||
sessions: 'ssn',
|
||||
topics: 'tpc',
|
||||
user: 'user',
|
||||
} as const;
|
||||
|
||||
export const idGenerator = (namespace: keyof typeof prefixes, size = 12) => {
|
||||
const hash = createNanoId(size);
|
||||
const prefix = prefixes[namespace];
|
||||
|
||||
if (!prefix) throw new Error(`Invalid namespace: ${namespace}, please check your code.`);
|
||||
|
||||
return `${prefix}_${hash()}`;
|
||||
};
|
||||
export const randomSlug = () => (generate(2) as string[]).join('-');
|
||||
|
||||
export const inboxSessionId = (userId: string) => `ssn_inbox_${userId}`;
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import { ActionIcon, DiscordIcon, Icon } from '@lobehub/ui';
|
||||
import { Badge } from 'antd';
|
||||
import { ItemType } from 'antd/es/menu/interface';
|
||||
import {
|
||||
Book,
|
||||
CircleUserRound,
|
||||
|
|
@ -21,6 +22,7 @@ import urlJoin from 'url-join';
|
|||
|
||||
import type { MenuProps } from '@/components/Menu';
|
||||
import { DISCORD, DOCUMENTS, EMAIL_SUPPORT, GITHUB_ISSUES, mailTo } from '@/const/url';
|
||||
import { isServerMode } from '@/const/version';
|
||||
import DataImporter from '@/features/DataImporter';
|
||||
import { useOpenSettings } from '@/hooks/useInterceptingRoutes';
|
||||
import { usePWAInstall } from '@/hooks/usePWAInstall';
|
||||
|
|
@ -115,46 +117,50 @@ export const useMenu = () => {
|
|||
},
|
||||
];
|
||||
|
||||
const data: MenuProps['items'] = [
|
||||
{
|
||||
icon: <Icon icon={HardDriveUpload} />,
|
||||
key: 'import',
|
||||
label: <DataImporter>{t('import')}</DataImporter>,
|
||||
},
|
||||
{
|
||||
children: [
|
||||
const data = !isLogin
|
||||
? []
|
||||
: ([
|
||||
{
|
||||
key: 'allAgent',
|
||||
label: t('exportType.allAgent'),
|
||||
onClick: configService.exportAgents,
|
||||
},
|
||||
{
|
||||
key: 'allAgentWithMessage',
|
||||
label: t('exportType.allAgentWithMessage'),
|
||||
onClick: configService.exportSessions,
|
||||
},
|
||||
{
|
||||
key: 'globalSetting',
|
||||
label: t('exportType.globalSetting'),
|
||||
onClick: configService.exportSettings,
|
||||
icon: <Icon icon={HardDriveDownload} />,
|
||||
key: 'import',
|
||||
label: <DataImporter>{t('import')}</DataImporter>,
|
||||
},
|
||||
isServerMode
|
||||
? null
|
||||
: {
|
||||
children: [
|
||||
{
|
||||
key: 'allAgent',
|
||||
label: t('exportType.allAgent'),
|
||||
onClick: configService.exportAgents,
|
||||
},
|
||||
{
|
||||
key: 'allAgentWithMessage',
|
||||
label: t('exportType.allAgentWithMessage'),
|
||||
onClick: configService.exportSessions,
|
||||
},
|
||||
{
|
||||
key: 'globalSetting',
|
||||
label: t('exportType.globalSetting'),
|
||||
onClick: configService.exportSettings,
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
key: 'all',
|
||||
label: t('exportType.all'),
|
||||
onClick: configService.exportAll,
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={HardDriveUpload} />,
|
||||
key: 'export',
|
||||
label: t('export'),
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
{
|
||||
key: 'all',
|
||||
label: t('exportType.all'),
|
||||
onClick: configService.exportAll,
|
||||
},
|
||||
],
|
||||
icon: <Icon icon={HardDriveDownload} />,
|
||||
key: 'export',
|
||||
label: t('export'),
|
||||
},
|
||||
{
|
||||
type: 'divider',
|
||||
},
|
||||
];
|
||||
].filter(Boolean) as ItemType[]);
|
||||
|
||||
const helps: MenuProps['items'] = [
|
||||
{
|
||||
|
|
@ -209,13 +215,13 @@ export const useMenu = () => {
|
|||
{
|
||||
type: 'divider',
|
||||
},
|
||||
...(isLoginWithClerk ? profile : []),
|
||||
...(isLogin ? settings : []),
|
||||
...(isLoginWithClerk ? profile : []),
|
||||
/* ↓ cloud slot ↓ */
|
||||
|
||||
/* ↑ cloud slot ↑ */
|
||||
...(canInstall ? pwa : []),
|
||||
...(isLogin ? data : []),
|
||||
...data,
|
||||
...helps,
|
||||
].filter(Boolean) as MenuProps['items'];
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,65 @@
|
|||
import { createTRPCClient, httpBatchLink } from '@trpc/client';
|
||||
import superjson from 'superjson';
|
||||
|
||||
import type { EdgeRouter } from '@/server/routers';
|
||||
import { createHeaderWithAuth } from '@/services/_auth';
|
||||
import { fetchErrorNotification } from '@/components/FetchErrorNotification';
|
||||
import type { EdgeRouter } from '@/server/routers/edge';
|
||||
import type { LambdaRouter } from '@/server/routers/lambda';
|
||||
import { withBasePath } from '@/utils/basePath';
|
||||
|
||||
export const edgeClient = createTRPCClient<EdgeRouter>({
|
||||
links: [
|
||||
httpBatchLink({
|
||||
headers: async () => createHeaderWithAuth(),
|
||||
headers: async () => {
|
||||
// dynamic import to avoid circular dependency
|
||||
const { createHeaderWithAuth } = await import('@/services/_auth');
|
||||
|
||||
return createHeaderWithAuth();
|
||||
},
|
||||
transformer: superjson,
|
||||
url: withBasePath('/trpc/edge'),
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
export type ErrorResponse = ErrorItem[];
|
||||
|
||||
export interface ErrorItem {
|
||||
error: {
|
||||
json: {
|
||||
code: number;
|
||||
data: Data;
|
||||
message: string;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export interface Data {
|
||||
code: string;
|
||||
httpStatus: number;
|
||||
path: string;
|
||||
stack: string;
|
||||
}
|
||||
|
||||
export const lambdaClient = createTRPCClient<LambdaRouter>({
|
||||
links: [
|
||||
httpBatchLink({
|
||||
fetch: async (input, init) => {
|
||||
const response = await fetch(input, init);
|
||||
if (response.ok) return response;
|
||||
|
||||
const errorRes: ErrorResponse = await response.clone().json();
|
||||
|
||||
errorRes.forEach((item) => {
|
||||
const errorData = item.error.json;
|
||||
|
||||
const status = errorData.data.httpStatus;
|
||||
fetchErrorNotification.error({ errorMessage: errorData.message, status });
|
||||
});
|
||||
|
||||
return response;
|
||||
},
|
||||
transformer: superjson,
|
||||
url: '/trpc/lambda',
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
import { ListObjectsCommand, PutObjectCommand, S3Client } from '@aws-sdk/client-s3';
|
||||
import {
|
||||
GetObjectCommand,
|
||||
ListObjectsCommand,
|
||||
PutObjectCommand,
|
||||
S3Client,
|
||||
} from '@aws-sdk/client-s3';
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner';
|
||||
import { z } from 'zod';
|
||||
|
||||
|
|
@ -46,6 +51,21 @@ export class S3 {
|
|||
return listFileSchema.parse(res.Contents);
|
||||
}
|
||||
|
||||
public async getFileContent(key: string): Promise<string> {
|
||||
const command = new GetObjectCommand({
|
||||
Bucket: this.bucket,
|
||||
Key: key,
|
||||
});
|
||||
|
||||
const response = await this.client.send(command);
|
||||
|
||||
if (!response.Body) {
|
||||
throw new Error(`No body in response with ${key}`);
|
||||
}
|
||||
|
||||
return response.Body.transformToString();
|
||||
}
|
||||
|
||||
public async createPreSignedUrl(key: string): Promise<string> {
|
||||
const command = new PutObjectCommand({
|
||||
ACL: 'public-read',
|
||||
|
|
|
|||
62
src/server/keyVaultsEncrypt/index.test.ts
Normal file
62
src/server/keyVaultsEncrypt/index.test.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// @vitest-environment node
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import { KeyVaultsGateKeeper } from './index';
|
||||
|
||||
describe('KeyVaultsGateKeeper', () => {
|
||||
let gateKeeper: KeyVaultsGateKeeper;
|
||||
|
||||
beforeEach(async () => {
|
||||
process.env.KEY_VAULTS_SECRET = 'Q10pwdq00KXUu9R+c8A8p4PSlIRWi7KwgUophBtkHVk=';
|
||||
// 在每个测试用例运行前初始化 KeyVaultsGateKeeper 实例
|
||||
gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
});
|
||||
|
||||
it('should encrypt and decrypt data correctly', async () => {
|
||||
const originalData = 'sensitive user data';
|
||||
|
||||
// 加密数据
|
||||
const encryptedData = await gateKeeper.encrypt(originalData);
|
||||
|
||||
// 解密数据
|
||||
const decryptionResult = await gateKeeper.decrypt(encryptedData);
|
||||
|
||||
// 断言解密后的明文与原始数据相同
|
||||
expect(decryptionResult.plaintext).toBe(originalData);
|
||||
// 断言解密是真实的(通过认证)
|
||||
expect(decryptionResult.wasAuthentic).toBe(true);
|
||||
});
|
||||
|
||||
it('should return empty plaintext and false authenticity for invalid encrypted data', async () => {
|
||||
const invalidEncryptedData = 'invalid:encrypted:data';
|
||||
|
||||
// 尝试解密无效的加密数据
|
||||
const decryptionResult = await gateKeeper.decrypt(invalidEncryptedData);
|
||||
|
||||
// 断言解密后的明文为空字符串
|
||||
expect(decryptionResult.plaintext).toBe('');
|
||||
// 断言解密是不真实的(未通过认证)
|
||||
expect(decryptionResult.wasAuthentic).toBe(false);
|
||||
});
|
||||
|
||||
it('should throw an error if KEY_VAULTS_SECRET is not set', async () => {
|
||||
// 将 KEY_VAULTS_SECRET 设为 undefined
|
||||
const originalSecretKey = process.env.KEY_VAULTS_SECRET;
|
||||
process.env.KEY_VAULTS_SECRET = '';
|
||||
|
||||
// 断言在 KEY_VAULTS_SECRET 未设置时会抛出错误
|
||||
try {
|
||||
await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
} catch (e) {
|
||||
expect(e).toEqual(
|
||||
Error(` \`KEY_VAULTS_SECRET\` is not set, please set it in your environment variables.
|
||||
|
||||
If you don't have it, please run \`openssl rand -base64 32\` to create one.
|
||||
`),
|
||||
);
|
||||
}
|
||||
|
||||
// 恢复 KEY_VAULTS_SECRET 的原始值
|
||||
process.env.KEY_VAULTS_SECRET = originalSecretKey;
|
||||
});
|
||||
});
|
||||
93
src/server/keyVaultsEncrypt/index.ts
Normal file
93
src/server/keyVaultsEncrypt/index.ts
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import { getServerDBConfig } from '@/config/db';
|
||||
|
||||
interface DecryptionResult {
|
||||
plaintext: string;
|
||||
wasAuthentic: boolean;
|
||||
}
|
||||
|
||||
export class KeyVaultsGateKeeper {
|
||||
private aesKey: CryptoKey;
|
||||
|
||||
constructor(aesKey: CryptoKey) {
|
||||
this.aesKey = aesKey;
|
||||
}
|
||||
|
||||
static initWithEnvKey = async () => {
|
||||
const { KEY_VAULTS_SECRET } = getServerDBConfig();
|
||||
if (!KEY_VAULTS_SECRET)
|
||||
throw new Error(` \`KEY_VAULTS_SECRET\` is not set, please set it in your environment variables.
|
||||
|
||||
If you don't have it, please run \`openssl rand -base64 32\` to create one.
|
||||
`);
|
||||
|
||||
const rawKey = Buffer.from(KEY_VAULTS_SECRET, 'base64'); // 确保密钥是32字节(256位)
|
||||
const aesKey = await crypto.subtle.importKey(
|
||||
'raw',
|
||||
rawKey,
|
||||
{ length: 256, name: 'AES-GCM' },
|
||||
false,
|
||||
['encrypt', 'decrypt'],
|
||||
);
|
||||
return new KeyVaultsGateKeeper(aesKey);
|
||||
};
|
||||
|
||||
/**
|
||||
* encrypt user private data
|
||||
*/
|
||||
encrypt = async (keyVault: string): Promise<string> => {
|
||||
const iv = crypto.getRandomValues(new Uint8Array(12)); // 对于GCM,推荐使用12字节的IV
|
||||
const encodedKeyVault = new TextEncoder().encode(keyVault);
|
||||
|
||||
const encryptedData = await crypto.subtle.encrypt(
|
||||
{
|
||||
iv: iv,
|
||||
name: 'AES-GCM',
|
||||
},
|
||||
this.aesKey,
|
||||
encodedKeyVault,
|
||||
);
|
||||
|
||||
const buffer = Buffer.from(encryptedData);
|
||||
const authTag = buffer.slice(-16); // 认证标签在加密数据的最后16字节
|
||||
const encrypted = buffer.slice(0, -16); // 剩下的是加密数据
|
||||
|
||||
return `${Buffer.from(iv).toString('hex')}:${authTag.toString('hex')}:${encrypted.toString('hex')}`;
|
||||
};
|
||||
|
||||
// 假设密钥和加密数据是从外部获取的
|
||||
decrypt = async (encryptedData: string): Promise<DecryptionResult> => {
|
||||
const parts = encryptedData.split(':');
|
||||
if (parts.length !== 3) {
|
||||
throw new Error('Invalid encrypted data format');
|
||||
}
|
||||
|
||||
const iv = Buffer.from(parts[0], 'hex');
|
||||
const authTag = Buffer.from(parts[1], 'hex');
|
||||
const encrypted = Buffer.from(parts[2], 'hex');
|
||||
|
||||
// 合并加密数据和认证标签
|
||||
const combined = Buffer.concat([encrypted, authTag]);
|
||||
|
||||
try {
|
||||
const decryptedBuffer = await crypto.subtle.decrypt(
|
||||
{
|
||||
iv: iv,
|
||||
name: 'AES-GCM',
|
||||
},
|
||||
this.aesKey,
|
||||
combined,
|
||||
);
|
||||
|
||||
const decrypted = new TextDecoder().decode(decryptedBuffer);
|
||||
return {
|
||||
plaintext: decrypted,
|
||||
wasAuthentic: true,
|
||||
};
|
||||
} catch {
|
||||
return {
|
||||
plaintext: '',
|
||||
wasAuthentic: false,
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
@ -3,6 +3,6 @@
|
|||
*/
|
||||
import { createCallerFactory } from '@/libs/trpc';
|
||||
|
||||
import { edgeRouter } from './routers';
|
||||
import { edgeRouter } from './routers/edge';
|
||||
|
||||
export const createCaller = createCallerFactory(edgeRouter);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
/**
|
||||
* This file contains the root router of lobe chat tRPC-backend
|
||||
* This file contains the root router of Lobe Chat tRPC-backend
|
||||
*/
|
||||
import { publicProcedure, router } from '@/libs/trpc';
|
||||
|
||||
import { configRouter } from './edge/config';
|
||||
import { uploadRouter } from './edge/upload';
|
||||
import { configRouter } from './config';
|
||||
import { uploadRouter } from './upload';
|
||||
|
||||
export const edgeRouter = router({
|
||||
config: configRouter,
|
||||
49
src/server/routers/lambda/file.ts
Normal file
49
src/server/routers/lambda/file.ts
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { FileModel } from '@/database/server/models/file';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { UploadFileSchema } from '@/types/files';
|
||||
|
||||
const fileProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { fileModel: new FileModel(ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
export const fileRouter = router({
|
||||
createFile: fileProcedure
|
||||
.input(
|
||||
UploadFileSchema.omit({ data: true, saveMode: true, url: true }).extend({ url: z.string() }),
|
||||
)
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
return ctx.fileModel.create({
|
||||
fileType: input.fileType,
|
||||
metadata: input.metadata,
|
||||
name: input.name,
|
||||
size: input.size,
|
||||
url: input.url,
|
||||
});
|
||||
}),
|
||||
|
||||
findById: fileProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ ctx, input }) => {
|
||||
return ctx.fileModel.findById(input.id);
|
||||
}),
|
||||
|
||||
removeAllFiles: fileProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.fileModel.clear();
|
||||
}),
|
||||
|
||||
removeFile: fileProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => {
|
||||
return ctx.fileModel.delete(input.id);
|
||||
}),
|
||||
});
|
||||
|
||||
export type FileRouter = typeof fileRouter;
|
||||
54
src/server/routers/lambda/importer.ts
Normal file
54
src/server/routers/lambda/importer.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// import urlJoin from 'url-join';
|
||||
import { TRPCError } from '@trpc/server';
|
||||
import { z } from 'zod';
|
||||
|
||||
// import { fileEnv } from '@/config/file';
|
||||
import { DataImporter } from '@/database/server/modules/DataImporter';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { S3 } from '@/server/files/s3';
|
||||
import { ImportResults, ImporterEntryData } from '@/types/importer';
|
||||
|
||||
export const importerRouter = router({
|
||||
importByFile: authedProcedure
|
||||
.input(z.object({ pathname: z.string() }))
|
||||
.mutation(async ({ input, ctx }): Promise<ImportResults> => {
|
||||
let data: ImporterEntryData | undefined;
|
||||
|
||||
try {
|
||||
const s3 = new S3();
|
||||
const dataStr = await s3.getFileContent(input.pathname);
|
||||
data = JSON.parse(dataStr);
|
||||
} catch {
|
||||
data = undefined;
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: `Failed to read file at ${input.pathname}`,
|
||||
});
|
||||
}
|
||||
|
||||
const dataImporter = new DataImporter(ctx.userId);
|
||||
|
||||
return dataImporter.importData(data);
|
||||
}),
|
||||
|
||||
importByPost: authedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
data: z.object({
|
||||
messages: z.array(z.any()).optional(),
|
||||
sessionGroups: z.array(z.any()).optional(),
|
||||
sessions: z.array(z.any()).optional(),
|
||||
topics: z.array(z.any()).optional(),
|
||||
version: z.number(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }): Promise<ImportResults> => {
|
||||
const dataImporter = new DataImporter(ctx.userId);
|
||||
|
||||
return dataImporter.importData(input.data);
|
||||
}),
|
||||
});
|
||||
28
src/server/routers/lambda/index.ts
Normal file
28
src/server/routers/lambda/index.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* This file contains the root router of Lobe Chat tRPC-backend
|
||||
*/
|
||||
import { publicProcedure, router } from '@/libs/trpc';
|
||||
|
||||
// router that connect to db
|
||||
import { fileRouter } from './file';
|
||||
import { importerRouter } from './importer';
|
||||
import { messageRouter } from './message';
|
||||
import { pluginRouter } from './plugin';
|
||||
import { sessionRouter } from './session';
|
||||
import { sessionGroupRouter } from './sessionGroup';
|
||||
import { topicRouter } from './topic';
|
||||
import { userRouter } from './user';
|
||||
|
||||
export const lambdaRouter = router({
|
||||
file: fileRouter,
|
||||
healthcheck: publicProcedure.query(() => "i'm live!"),
|
||||
importer: importerRouter,
|
||||
message: messageRouter,
|
||||
plugin: pluginRouter,
|
||||
session: sessionRouter,
|
||||
sessionGroup: sessionGroupRouter,
|
||||
topic: topicRouter,
|
||||
user: userRouter,
|
||||
});
|
||||
|
||||
export type LambdaRouter = typeof lambdaRouter;
|
||||
165
src/server/routers/lambda/message.ts
Normal file
165
src/server/routers/lambda/message.ts
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { MessageModel } from '@/database/server/models/message';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { ChatMessage } from '@/types/message';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
|
||||
type ChatMessageList = ChatMessage[];
|
||||
|
||||
const messageProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { messageModel: new MessageModel(ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
export const messageRouter = router({
|
||||
batchCreateMessages: messageProcedure
|
||||
.input(z.array(z.any()))
|
||||
.mutation(async ({ input, ctx }): Promise<BatchTaskResult> => {
|
||||
const data = await ctx.messageModel.batchCreate(input);
|
||||
|
||||
return { added: data.rowCount as number, ids: [], skips: [], success: true };
|
||||
}),
|
||||
|
||||
count: messageProcedure.query(async ({ ctx }) => {
|
||||
return ctx.messageModel.count();
|
||||
}),
|
||||
countToday: messageProcedure.query(async ({ ctx }) => {
|
||||
return ctx.messageModel.countToday();
|
||||
}),
|
||||
|
||||
createMessage: messageProcedure
|
||||
.input(z.object({}).passthrough().partial())
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.messageModel.create(input as any);
|
||||
|
||||
return data.id;
|
||||
}),
|
||||
|
||||
getAllMessages: messageProcedure.query(async ({ ctx }): Promise<ChatMessageList> => {
|
||||
return ctx.messageModel.queryAll();
|
||||
}),
|
||||
|
||||
getAllMessagesInSession: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
sessionId: z.string().nullable().optional(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ ctx, input }): Promise<ChatMessageList> => {
|
||||
return ctx.messageModel.queryBySessionId(input.sessionId);
|
||||
}),
|
||||
|
||||
getMessages: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
current: z.number().optional(),
|
||||
pageSize: z.number().optional(),
|
||||
sessionId: z.string().nullable().optional(),
|
||||
topicId: z.string().nullable().optional(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (!ctx.userId) return [];
|
||||
|
||||
const messageModel = new MessageModel(ctx.userId);
|
||||
|
||||
return messageModel.query(input);
|
||||
}),
|
||||
|
||||
removeAllMessages: messageProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.messageModel.deleteAllMessages();
|
||||
}),
|
||||
|
||||
removeMessage: messageProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.messageModel.deleteMessage(input.id);
|
||||
}),
|
||||
|
||||
removeMessages: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
sessionId: z.string().nullable().optional(),
|
||||
topicId: z.string().nullable().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.messageModel.deleteMessages(input.sessionId, input.topicId);
|
||||
}),
|
||||
|
||||
searchMessages: messageProcedure
|
||||
.input(z.object({ keywords: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
return ctx.messageModel.queryByKeyword(input.keywords);
|
||||
}),
|
||||
|
||||
update: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z.object({}).passthrough().partial(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.messageModel.update(input.id, input.value);
|
||||
}),
|
||||
|
||||
updatePluginState: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z.object({}).passthrough(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.messageModel.updatePluginState(input.id, input.value);
|
||||
}),
|
||||
|
||||
updateTTS: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z
|
||||
.object({
|
||||
contentMd5: z.string().optional(),
|
||||
fileId: z.string().optional(),
|
||||
voice: z.string().optional(),
|
||||
})
|
||||
.or(z.literal(false)),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
if (input.value === false) {
|
||||
return ctx.messageModel.deleteMessageTTS(input.id);
|
||||
}
|
||||
|
||||
return ctx.messageModel.updateTTS(input.id, input.value);
|
||||
}),
|
||||
|
||||
updateTranslate: messageProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z
|
||||
.object({
|
||||
content: z.string().optional(),
|
||||
from: z.string().optional(),
|
||||
to: z.string(),
|
||||
})
|
||||
.or(z.literal(false)),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
if (input.value === false) {
|
||||
return ctx.messageModel.deleteMessageTranslate(input.id);
|
||||
}
|
||||
|
||||
return ctx.messageModel.updateTranslate(input.id, input.value);
|
||||
}),
|
||||
});
|
||||
|
||||
export type MessageRouter = typeof messageRouter;
|
||||
100
src/server/routers/lambda/plugin.ts
Normal file
100
src/server/routers/lambda/plugin.ts
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { PluginModel } from '@/database/server/models/plugin';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { LobeTool } from '@/types/tool';
|
||||
|
||||
const pluginProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { pluginModel: new PluginModel(ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
export const pluginRouter = router({
|
||||
createOrInstallPlugin: pluginProcedure
|
||||
.input(
|
||||
z.object({
|
||||
customParams: z.any(),
|
||||
identifier: z.string(),
|
||||
manifest: z.any(),
|
||||
type: z.enum(['plugin', 'customPlugin']),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const result = await ctx.pluginModel.findById(input.identifier);
|
||||
|
||||
// if not exist, we should create the plugin
|
||||
if (!result) {
|
||||
const data = await ctx.pluginModel.create({
|
||||
customParams: input.customParams,
|
||||
identifier: input.identifier,
|
||||
manifest: input.manifest,
|
||||
type: input.type,
|
||||
});
|
||||
|
||||
return data.identifier;
|
||||
}
|
||||
|
||||
// or we can just update the plugin manifest
|
||||
await ctx.pluginModel.update(input.identifier, { manifest: input.manifest });
|
||||
}),
|
||||
|
||||
createPlugin: pluginProcedure
|
||||
.input(
|
||||
z.object({
|
||||
customParams: z.any(),
|
||||
identifier: z.string(),
|
||||
manifest: z.any(),
|
||||
type: z.enum(['plugin', 'customPlugin']),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.pluginModel.create({
|
||||
customParams: input.customParams,
|
||||
identifier: input.identifier,
|
||||
manifest: input.manifest,
|
||||
type: input.type,
|
||||
});
|
||||
|
||||
return data.identifier;
|
||||
}),
|
||||
|
||||
getPlugins: publicProcedure.query(async ({ ctx }): Promise<LobeTool[]> => {
|
||||
if (!ctx.userId) return [];
|
||||
|
||||
const pluginModel = new PluginModel(ctx.userId);
|
||||
|
||||
return pluginModel.query();
|
||||
}),
|
||||
|
||||
removeAllPlugins: pluginProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.pluginModel.deleteAll();
|
||||
}),
|
||||
|
||||
removePlugin: pluginProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.pluginModel.delete(input.id);
|
||||
}),
|
||||
|
||||
updatePlugin: pluginProcedure
|
||||
.input(
|
||||
z.object({
|
||||
customParams: z.any().optional(),
|
||||
id: z.string(),
|
||||
manifest: z.any().optional(),
|
||||
settings: z.any().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.pluginModel.update(input.id, {
|
||||
customParams: input.customParams,
|
||||
manifest: input.manifest,
|
||||
settings: input.settings,
|
||||
});
|
||||
}),
|
||||
});
|
||||
|
||||
export type PluginRouter = typeof pluginRouter;
|
||||
194
src/server/routers/lambda/session.ts
Normal file
194
src/server/routers/lambda/session.ts
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { INBOX_SESSION_ID } from '@/const/session';
|
||||
import { SessionModel } from '@/database/server/models/session';
|
||||
import { SessionGroupModel } from '@/database/server/models/sessionGroup';
|
||||
import { insertAgentSchema, insertSessionSchema } from '@/database/server/schemas/lobechat';
|
||||
import { pino } from '@/libs/logger';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { AgentChatConfigSchema } from '@/types/agent';
|
||||
import { LobeMetaDataSchema } from '@/types/meta';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
import { ChatSessionList } from '@/types/session';
|
||||
import { merge } from '@/utils/merge';
|
||||
|
||||
const sessionProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
sessionGroupModel: new SessionGroupModel(ctx.userId),
|
||||
sessionModel: new SessionModel(ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
export const sessionRouter = router({
|
||||
batchCreateSessions: sessionProcedure
|
||||
.input(
|
||||
z.array(
|
||||
z
|
||||
.object({
|
||||
config: z.object({}).passthrough(),
|
||||
group: z.string().optional(),
|
||||
id: z.string(),
|
||||
meta: LobeMetaDataSchema,
|
||||
pinned: z.boolean().optional(),
|
||||
type: z.string(),
|
||||
})
|
||||
.partial(),
|
||||
),
|
||||
)
|
||||
.mutation(async ({ input, ctx }): Promise<BatchTaskResult> => {
|
||||
const data = await ctx.sessionModel.batchCreate(
|
||||
input.map((item) => ({
|
||||
...item,
|
||||
...item.meta,
|
||||
})) as any,
|
||||
);
|
||||
|
||||
return { added: data.rowCount as number, ids: [], skips: [], success: true };
|
||||
}),
|
||||
|
||||
cloneSession: sessionProcedure
|
||||
.input(z.object({ id: z.string(), newTitle: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.sessionModel.duplicate(input.id, input.newTitle);
|
||||
|
||||
return data?.id;
|
||||
}),
|
||||
|
||||
countSessions: sessionProcedure.query(async ({ ctx }) => {
|
||||
return ctx.sessionModel.count();
|
||||
}),
|
||||
|
||||
createSession: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
config: insertAgentSchema
|
||||
.omit({ chatConfig: true, plugins: true, tags: true, tts: true })
|
||||
.passthrough()
|
||||
.partial(),
|
||||
session: insertSessionSchema.omit({ createdAt: true, updatedAt: true }).partial(),
|
||||
type: z.enum(['agent', 'group']),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.sessionModel.create(input);
|
||||
|
||||
return data.id;
|
||||
}),
|
||||
|
||||
getGroupedSessions: publicProcedure.query(async ({ ctx }): Promise<ChatSessionList> => {
|
||||
if (!ctx.userId)
|
||||
return {
|
||||
sessionGroups: [],
|
||||
sessions: [],
|
||||
};
|
||||
|
||||
const sessionModel = new SessionModel(ctx.userId);
|
||||
|
||||
return sessionModel.queryWithGroups();
|
||||
}),
|
||||
|
||||
getSessionConfig: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (input.id === INBOX_SESSION_ID) {
|
||||
const item = await ctx.sessionModel.findByIdOrSlug(INBOX_SESSION_ID);
|
||||
// if there is no session for user, create one
|
||||
if (!item) {
|
||||
const res = await ctx.sessionModel.createInbox();
|
||||
pino.info('create inbox session', res);
|
||||
}
|
||||
}
|
||||
|
||||
const session = await ctx.sessionModel.findByIdOrSlug(input.id);
|
||||
|
||||
if (!session) throw new Error('Session not found');
|
||||
|
||||
return session.agent;
|
||||
}),
|
||||
|
||||
getSessions: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
current: z.number().optional(),
|
||||
pageSize: z.number().optional(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
const { current, pageSize } = input;
|
||||
|
||||
return ctx.sessionModel.query({ current, pageSize });
|
||||
}),
|
||||
|
||||
removeAllSessions: sessionProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.sessionModel.deleteAll();
|
||||
}),
|
||||
|
||||
removeSession: sessionProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.sessionModel.delete(input.id);
|
||||
}),
|
||||
|
||||
searchSessions: sessionProcedure
|
||||
.input(z.object({ keywords: z.string() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
return ctx.sessionModel.queryByKeyword(input.keywords);
|
||||
}),
|
||||
|
||||
updateSession: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: insertSessionSchema.partial(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.sessionModel.update(input.id, input.value);
|
||||
}),
|
||||
updateSessionChatConfig: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: AgentChatConfigSchema.partial(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const session = await ctx.sessionModel.findByIdOrSlug(input.id);
|
||||
|
||||
if (!session) return;
|
||||
|
||||
return ctx.sessionModel.updateConfig(session.agent.id, {
|
||||
chatConfig: merge(session.agent.chatConfig, input.value),
|
||||
});
|
||||
}),
|
||||
updateSessionConfig: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z.object({}).passthrough().partial(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const session = await ctx.sessionModel.findByIdOrSlug(input.id);
|
||||
|
||||
if (!session || !input.value) return;
|
||||
|
||||
if (!session.agent) {
|
||||
throw new Error(
|
||||
'this session is not assign with agent, please contact with admin to fix this issue.',
|
||||
);
|
||||
}
|
||||
|
||||
return ctx.sessionModel.updateConfig(session.agent.id, input.value);
|
||||
}),
|
||||
});
|
||||
|
||||
export type SessionRouter = typeof sessionRouter;
|
||||
77
src/server/routers/lambda/sessionGroup.ts
Normal file
77
src/server/routers/lambda/sessionGroup.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { SessionGroupModel } from '@/database/server/models/sessionGroup';
|
||||
import { insertSessionGroupSchema } from '@/database/server/schemas/lobechat';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { SessionGroupItem } from '@/types/session';
|
||||
|
||||
const sessionProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
sessionGroupModel: new SessionGroupModel(ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
export const sessionGroupRouter = router({
|
||||
createSessionGroup: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
name: z.string(),
|
||||
sort: z.number().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.sessionGroupModel.create({
|
||||
name: input.name,
|
||||
sort: input.sort,
|
||||
});
|
||||
|
||||
return data?.id;
|
||||
}),
|
||||
|
||||
getSessionGroup: sessionProcedure.query(async ({ ctx }): Promise<SessionGroupItem[]> => {
|
||||
return ctx.sessionGroupModel.query() as any;
|
||||
}),
|
||||
|
||||
removeAllSessionGroups: sessionProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.sessionGroupModel.deleteAll();
|
||||
}),
|
||||
|
||||
removeSessionGroup: sessionProcedure
|
||||
.input(z.object({ id: z.string(), removeChildren: z.boolean().optional() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.sessionGroupModel.delete(input.id);
|
||||
}),
|
||||
|
||||
updateSessionGroup: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: insertSessionGroupSchema.partial(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.sessionGroupModel.update(input.id, input.value);
|
||||
}),
|
||||
updateSessionGroupOrder: sessionProcedure
|
||||
.input(
|
||||
z.object({
|
||||
sortMap: z.array(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
sort: z.number(),
|
||||
}),
|
||||
),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
console.log('sortMap:', input.sortMap);
|
||||
|
||||
return ctx.sessionGroupModel.updateOrder(input.sortMap);
|
||||
}),
|
||||
});
|
||||
|
||||
export type SessionGroupRouter = typeof sessionGroupRouter;
|
||||
134
src/server/routers/lambda/topic.ts
Normal file
134
src/server/routers/lambda/topic.ts
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { TopicModel } from '@/database/server/models/topic';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
|
||||
const topicProcedure = authedProcedure.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { topicModel: new TopicModel(ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
export const topicRouter = router({
|
||||
batchCreateTopics: topicProcedure
|
||||
.input(
|
||||
z.array(
|
||||
z.object({
|
||||
favorite: z.boolean().optional(),
|
||||
id: z.string().optional(),
|
||||
messages: z.array(z.string()).optional(),
|
||||
sessionId: z.string().optional(),
|
||||
title: z.string(),
|
||||
}),
|
||||
),
|
||||
)
|
||||
.mutation(async ({ input, ctx }): Promise<BatchTaskResult> => {
|
||||
const data = await ctx.topicModel.batchCreate(
|
||||
input.map((item) => ({
|
||||
...item,
|
||||
})) as any,
|
||||
);
|
||||
|
||||
return { added: data.length, ids: [], skips: [], success: true };
|
||||
}),
|
||||
|
||||
batchDelete: topicProcedure
|
||||
.input(z.object({ ids: z.array(z.string()) }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.topicModel.batchDelete(input.ids);
|
||||
}),
|
||||
|
||||
batchDeleteBySessionId: topicProcedure
|
||||
.input(z.object({ id: z.string().nullable().optional() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.topicModel.batchDeleteBySessionId(input.id);
|
||||
}),
|
||||
|
||||
cloneTopic: topicProcedure
|
||||
.input(z.object({ id: z.string(), newTitle: z.string().optional() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.topicModel.duplicate(input.id, input.newTitle);
|
||||
|
||||
return data.topic.id;
|
||||
}),
|
||||
|
||||
countTopics: topicProcedure.query(async ({ ctx }) => {
|
||||
return ctx.topicModel.count();
|
||||
}),
|
||||
|
||||
createTopic: topicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
favorite: z.boolean().optional(),
|
||||
messages: z.array(z.string()).optional(),
|
||||
sessionId: z.string().nullable().optional(),
|
||||
title: z.string(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const data = await ctx.topicModel.create(input);
|
||||
|
||||
return data.id;
|
||||
}),
|
||||
|
||||
getAllTopics: topicProcedure.query(async ({ ctx }) => {
|
||||
return ctx.topicModel.queryAll();
|
||||
}),
|
||||
|
||||
getTopics: publicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
current: z.number().optional(),
|
||||
pageSize: z.number().optional(),
|
||||
sessionId: z.string().nullable().optional(),
|
||||
}),
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (!ctx.userId) return [];
|
||||
|
||||
const topicModel = new TopicModel(ctx.userId);
|
||||
|
||||
return topicModel.query(input);
|
||||
}),
|
||||
|
||||
hasTopics: topicProcedure.query(async ({ ctx }) => {
|
||||
return (await ctx.topicModel.count()) === 0;
|
||||
}),
|
||||
|
||||
removeAllTopics: topicProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.topicModel.deleteAll();
|
||||
}),
|
||||
|
||||
removeTopic: topicProcedure
|
||||
.input(z.object({ id: z.string() }))
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.topicModel.delete(input.id);
|
||||
}),
|
||||
|
||||
searchTopics: topicProcedure
|
||||
.input(z.object({ keywords: z.string(), sessionId: z.string().nullable().optional() }))
|
||||
.query(async ({ input, ctx }) => {
|
||||
return ctx.topicModel.queryByKeyword(input.keywords, input.sessionId);
|
||||
}),
|
||||
|
||||
updateTopic: topicProcedure
|
||||
.input(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
value: z.object({
|
||||
favorite: z.boolean().optional(),
|
||||
messages: z.array(z.string()).optional(),
|
||||
sessionId: z.string().optional(),
|
||||
title: z.string().optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
return ctx.topicModel.update(input.id, input.value);
|
||||
}),
|
||||
});
|
||||
|
||||
export type TopicRouter = typeof topicRouter;
|
||||
57
src/server/routers/lambda/user.ts
Normal file
57
src/server/routers/lambda/user.ts
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import { z } from 'zod';
|
||||
|
||||
import { MessageModel } from '@/database/server/models/message';
|
||||
import { SessionModel } from '@/database/server/models/session';
|
||||
import { UserModel } from '@/database/server/models/user';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { UserInitializationState, UserPreference } from '@/types/user';
|
||||
|
||||
const userProcedure = authedProcedure.use(async (opts) => {
|
||||
return opts.next({
|
||||
ctx: { userModel: new UserModel() },
|
||||
});
|
||||
});
|
||||
|
||||
export const userRouter = router({
|
||||
getUserState: userProcedure.query(async ({ ctx }): Promise<UserInitializationState> => {
|
||||
const state = await ctx.userModel.getUserState(ctx.userId);
|
||||
|
||||
const messageModel = new MessageModel(ctx.userId);
|
||||
const messageCount = await messageModel.count();
|
||||
|
||||
const sessionModel = new SessionModel(ctx.userId);
|
||||
const sessionCount = await sessionModel.count();
|
||||
|
||||
return {
|
||||
canEnablePWAGuide: messageCount >= 2,
|
||||
canEnableTrace: messageCount >= 4,
|
||||
// 有消息,或者创建过助手,则认为有 conversation
|
||||
hasConversation: messageCount > 0 || sessionCount > 1,
|
||||
|
||||
isOnboard: state.isOnboarded || false,
|
||||
preference: state.preference as UserPreference,
|
||||
settings: state.settings,
|
||||
userId: ctx.userId,
|
||||
};
|
||||
}),
|
||||
|
||||
makeUserOnboarded: userProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.userModel.updateUser(ctx.userId, { isOnboarded: true });
|
||||
}),
|
||||
|
||||
resetSettings: userProcedure.mutation(async ({ ctx }) => {
|
||||
return ctx.userModel.deleteSetting(ctx.userId);
|
||||
}),
|
||||
|
||||
updatePreference: userProcedure.input(z.any()).mutation(async ({ ctx, input }) => {
|
||||
return ctx.userModel.updatePreference(ctx.userId, input);
|
||||
}),
|
||||
|
||||
updateSettings: userProcedure
|
||||
.input(z.object({}).passthrough())
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
return ctx.userModel.updateSetting(ctx.userId, input);
|
||||
}),
|
||||
});
|
||||
|
||||
export type UserRouter = typeof userRouter;
|
||||
|
|
@ -1,9 +1,6 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
import { ClientService } from './client';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
// import { ServerService } from './server';
|
||||
//
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
//
|
||||
// export const fileService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
export const fileService = new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
export const fileService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
45
src/services/file/server.ts
Normal file
45
src/services/file/server.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import urlJoin from 'url-join';
|
||||
|
||||
import { fileEnv } from '@/config/file';
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { FilePreview, UploadFileParams } from '@/types/files';
|
||||
|
||||
import { IFileService } from './type';
|
||||
|
||||
interface CreateFileParams extends Omit<UploadFileParams, 'url'> {
|
||||
url: string;
|
||||
}
|
||||
|
||||
export class ServerService implements IFileService {
|
||||
async createFile(params: UploadFileParams) {
|
||||
return lambdaClient.file.createFile.mutate(params as CreateFileParams);
|
||||
}
|
||||
|
||||
async getFile(id: string): Promise<FilePreview> {
|
||||
if (!fileEnv.NEXT_PUBLIC_S3_DOMAIN) {
|
||||
throw new Error('fileEnv.NEXT_PUBLIC_S3_DOMAIN is not set while enable server upload');
|
||||
}
|
||||
|
||||
const item = await lambdaClient.file.findById.query({ id });
|
||||
|
||||
if (!item) {
|
||||
throw new Error('file not found');
|
||||
}
|
||||
|
||||
return {
|
||||
fileType: item.fileType,
|
||||
id: item.id,
|
||||
name: item.name,
|
||||
saveMode: 'url',
|
||||
url: urlJoin(fileEnv.NEXT_PUBLIC_S3_DOMAIN!, item.url!),
|
||||
};
|
||||
}
|
||||
|
||||
async removeFile(id: string) {
|
||||
await lambdaClient.file.removeFile.mutate({ id });
|
||||
}
|
||||
|
||||
async removeAllFiles() {
|
||||
await lambdaClient.file.removeAllFiles.mutate();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
import { ClientService } from './client';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
export const importService = new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
export const importService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
115
src/services/import/server.ts
Normal file
115
src/services/import/server.ts
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import { DefaultErrorShape } from '@trpc/server/unstable-core-do-not-import';
|
||||
|
||||
import { edgeClient, lambdaClient } from '@/libs/trpc/client';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { ImportStage, ImporterEntryData, OnImportCallbacks } from '@/types/importer';
|
||||
import { UserSettings } from '@/types/user/settings';
|
||||
import { uuid } from '@/utils/uuid';
|
||||
|
||||
export class ServerService {
|
||||
importSettings = async (settings: UserSettings) => {
|
||||
await useUserStore.getState().importAppSettings(settings);
|
||||
};
|
||||
|
||||
importData = async (data: ImporterEntryData, callbacks?: OnImportCallbacks): Promise<void> => {
|
||||
const handleError = (e: unknown) => {
|
||||
callbacks?.onStageChange?.(ImportStage.Error);
|
||||
const error = e as DefaultErrorShape;
|
||||
|
||||
callbacks?.onError?.({
|
||||
code: error.data.code,
|
||||
httpStatus: error.data.httpStatus,
|
||||
message: error.message,
|
||||
path: error.data.path,
|
||||
});
|
||||
};
|
||||
|
||||
const totalLength =
|
||||
(data.messages?.length || 0) +
|
||||
(data.sessionGroups?.length || 0) +
|
||||
(data.sessions?.length || 0) +
|
||||
(data.topics?.length || 0);
|
||||
|
||||
if (totalLength < 500) {
|
||||
callbacks?.onStageChange?.(ImportStage.Importing);
|
||||
const time = Date.now();
|
||||
try {
|
||||
const result = await lambdaClient.importer.importByPost.mutate({ data });
|
||||
const duration = Date.now() - time;
|
||||
|
||||
callbacks?.onStageChange?.(ImportStage.Success);
|
||||
callbacks?.onSuccess?.(result, duration);
|
||||
} catch (e) {
|
||||
handleError(e);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// if the data is too large, upload it to S3 and upload by file
|
||||
const filename = `${uuid()}.json`;
|
||||
|
||||
const pathname = `import_config/${filename}`;
|
||||
|
||||
const url = await edgeClient.upload.createS3PreSignedUrl.mutate({ pathname });
|
||||
|
||||
try {
|
||||
callbacks?.onStageChange?.(ImportStage.Uploading);
|
||||
await this.uploadWithProgress(url, data, callbacks?.onFileUploading);
|
||||
} catch {
|
||||
throw new Error('Upload Error');
|
||||
}
|
||||
|
||||
callbacks?.onStageChange?.(ImportStage.Importing);
|
||||
const time = Date.now();
|
||||
try {
|
||||
const result = await lambdaClient.importer.importByFile.mutate({ pathname });
|
||||
const duration = Date.now() - time;
|
||||
callbacks?.onStageChange?.(ImportStage.Success);
|
||||
callbacks?.onSuccess?.(result, duration);
|
||||
} catch (e) {
|
||||
handleError(e);
|
||||
}
|
||||
};
|
||||
|
||||
private uploadWithProgress = async (
|
||||
url: string,
|
||||
data: object,
|
||||
onProgress: OnImportCallbacks['onFileUploading'],
|
||||
) => {
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
let startTime = Date.now();
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
if (event.lengthComputable) {
|
||||
const progress = Number(((event.loaded / event.total) * 100).toFixed(1));
|
||||
|
||||
const speedInByte = event.loaded / ((Date.now() - startTime) / 1000);
|
||||
|
||||
onProgress?.({
|
||||
// if the progress is 100, it means the file is uploaded
|
||||
// but the server is still processing it
|
||||
// so make it as 99.5 and let users think it's still uploading
|
||||
progress: progress === 100 ? 99.5 : progress,
|
||||
restTime: (event.total - event.loaded) / speedInByte,
|
||||
speed: speedInByte / 1024,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
xhr.open('PUT', url);
|
||||
xhr.setRequestHeader('Content-Type', 'application/json');
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
xhr.addEventListener('load', () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
resolve(xhr.response);
|
||||
} else {
|
||||
reject(xhr.statusText);
|
||||
}
|
||||
});
|
||||
xhr.addEventListener('error', () => reject(xhr.statusText));
|
||||
xhr.send(JSON.stringify(data));
|
||||
});
|
||||
};
|
||||
}
|
||||
|
|
@ -1,12 +1,8 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
// import { ServerService } from './server';
|
||||
// import { ClientService } from './client';
|
||||
//
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
//
|
||||
// export const messageService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
export type { CreateMessageParams } from './type';
|
||||
|
||||
export const messageService = new ClientService();
|
||||
export const messageService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
93
src/services/message/server.ts
Normal file
93
src/services/message/server.ts
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
/* eslint-disable @typescript-eslint/no-unused-vars */
|
||||
import { INBOX_SESSION_ID } from '@/const/session';
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { ChatMessage, ChatMessageError, ChatTTS, ChatTranslate } from '@/types/message';
|
||||
|
||||
import { CreateMessageParams, IMessageService } from './type';
|
||||
|
||||
export class ServerService implements IMessageService {
|
||||
createMessage({ sessionId, ...params }: CreateMessageParams): Promise<string> {
|
||||
return lambdaClient.message.createMessage.mutate({
|
||||
...params,
|
||||
sessionId: this.toDbSessionId(sessionId),
|
||||
});
|
||||
}
|
||||
|
||||
batchCreateMessages(messages: ChatMessage[]): Promise<any> {
|
||||
return lambdaClient.message.batchCreateMessages.mutate(messages);
|
||||
}
|
||||
|
||||
getMessages(sessionId?: string, topicId?: string | undefined): Promise<ChatMessage[]> {
|
||||
return lambdaClient.message.getMessages.query({
|
||||
sessionId: this.toDbSessionId(sessionId),
|
||||
topicId,
|
||||
});
|
||||
}
|
||||
|
||||
getAllMessages(): Promise<ChatMessage[]> {
|
||||
return lambdaClient.message.getAllMessages.query();
|
||||
}
|
||||
getAllMessagesInSession(sessionId: string): Promise<ChatMessage[]> {
|
||||
return lambdaClient.message.getAllMessagesInSession.query({
|
||||
sessionId: this.toDbSessionId(sessionId),
|
||||
});
|
||||
}
|
||||
|
||||
countMessages(): Promise<number> {
|
||||
return lambdaClient.message.count.query();
|
||||
}
|
||||
countTodayMessages(): Promise<number> {
|
||||
return lambdaClient.message.countToday.query();
|
||||
}
|
||||
|
||||
updateMessageError(id: string, error: ChatMessageError): Promise<any> {
|
||||
return lambdaClient.message.update.mutate({ id, value: { error } });
|
||||
}
|
||||
|
||||
updateMessage(id: string, message: Partial<ChatMessage>): Promise<any> {
|
||||
return lambdaClient.message.update.mutate({ id, value: message });
|
||||
}
|
||||
|
||||
updateMessageTranslate(id: string, translate: Partial<ChatTranslate> | false): Promise<any> {
|
||||
return lambdaClient.message.updateTranslate.mutate({ id, value: translate as ChatTranslate });
|
||||
}
|
||||
|
||||
updateMessageTTS(id: string, tts: Partial<ChatTTS> | false): Promise<any> {
|
||||
return lambdaClient.message.updateTTS.mutate({ id, value: tts });
|
||||
}
|
||||
|
||||
updateMessagePluginState(id: string, value: any): Promise<any> {
|
||||
return lambdaClient.message.updatePluginState.mutate({ id, value });
|
||||
}
|
||||
|
||||
bindMessagesToTopic(topicId: string, messageIds: string[]): Promise<any> {
|
||||
throw new Error('Method not implemented.');
|
||||
}
|
||||
|
||||
removeMessage(id: string): Promise<any> {
|
||||
return lambdaClient.message.removeMessage.mutate({ id });
|
||||
}
|
||||
removeMessages(sessionId: string, topicId?: string | undefined): Promise<any> {
|
||||
return lambdaClient.message.removeMessages.mutate({
|
||||
sessionId: this.toDbSessionId(sessionId),
|
||||
topicId,
|
||||
});
|
||||
}
|
||||
removeAllMessages(): Promise<any> {
|
||||
return lambdaClient.message.removeAllMessages.mutate();
|
||||
}
|
||||
|
||||
private toDbSessionId(sessionId: string | undefined) {
|
||||
return sessionId === INBOX_SESSION_ID ? null : sessionId;
|
||||
}
|
||||
|
||||
async hasMessages() {
|
||||
const number = await this.countMessages();
|
||||
return number > 0;
|
||||
}
|
||||
|
||||
async messageCountToCheckTrace() {
|
||||
const number = await this.countMessages();
|
||||
return number >= 4;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +1,6 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
// import { ServerService } from './server';
|
||||
//
|
||||
// export type { InstallPluginParams } from './client';
|
||||
//
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
|
||||
// export const pluginService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
export const pluginService = new ClientService();
|
||||
export const pluginService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
46
src/services/plugin/server.ts
Normal file
46
src/services/plugin/server.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import { LobeChatPluginManifest } from '@lobehub/chat-plugin-sdk';
|
||||
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { LobeTool } from '@/types/tool';
|
||||
import { LobeToolCustomPlugin } from '@/types/tool/plugin';
|
||||
|
||||
import { IPluginService, InstallPluginParams } from './type';
|
||||
|
||||
export class ServerService implements IPluginService {
|
||||
installPlugin = async (plugin: InstallPluginParams) => {
|
||||
await lambdaClient.plugin.createOrInstallPlugin.mutate(plugin);
|
||||
};
|
||||
|
||||
getInstalledPlugins = (): Promise<LobeTool[]> => {
|
||||
return lambdaClient.plugin.getPlugins.query();
|
||||
};
|
||||
|
||||
async uninstallPlugin(identifier: string) {
|
||||
await lambdaClient.plugin.removePlugin.mutate({ id: identifier });
|
||||
}
|
||||
|
||||
async createCustomPlugin(customPlugin: LobeToolCustomPlugin) {
|
||||
await lambdaClient.plugin.createPlugin.mutate({ ...customPlugin, type: 'customPlugin' });
|
||||
}
|
||||
|
||||
async updatePlugin(id: string, value: LobeToolCustomPlugin) {
|
||||
await lambdaClient.plugin.updatePlugin.mutate({
|
||||
customParams: value.customParams,
|
||||
id,
|
||||
manifest: value.manifest,
|
||||
settings: value.settings,
|
||||
});
|
||||
}
|
||||
|
||||
async updatePluginManifest(id: string, manifest: LobeChatPluginManifest) {
|
||||
await lambdaClient.plugin.updatePlugin.mutate({ id, manifest });
|
||||
}
|
||||
|
||||
async removeAllPlugins() {
|
||||
await lambdaClient.plugin.removeAllPlugins.mutate();
|
||||
}
|
||||
|
||||
async updatePluginSettings(id: string, settings: any, signal?: AbortSignal) {
|
||||
await lambdaClient.plugin.updatePlugin.mutate({ id, settings }, { signal });
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +1,6 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
//
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
// import { ServerService } from './server';
|
||||
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
|
||||
// export const sessionService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
export const sessionService = new ClientService();
|
||||
export const sessionService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
148
src/services/session/server.ts
Normal file
148
src/services/session/server.ts
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
/* eslint-disable @typescript-eslint/no-unused-vars */
|
||||
import { DeepPartial } from 'utility-types';
|
||||
|
||||
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { authSelectors } from '@/store/user/selectors';
|
||||
import { LobeAgentChatConfig, LobeAgentConfig } from '@/types/agent';
|
||||
import { MetaData } from '@/types/meta';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
import {
|
||||
ChatSessionList,
|
||||
LobeAgentSession,
|
||||
LobeSessionType,
|
||||
LobeSessions,
|
||||
SessionGroupId,
|
||||
SessionGroupItem,
|
||||
SessionGroups,
|
||||
} from '@/types/session';
|
||||
|
||||
import { ISessionService } from './type';
|
||||
|
||||
export class ServerService implements ISessionService {
|
||||
async hasSessions() {
|
||||
return (await this.countSessions()) === 0;
|
||||
}
|
||||
|
||||
createSession(type: LobeSessionType, data: Partial<LobeAgentSession>): Promise<string> {
|
||||
const { config, group, meta, ...session } = data;
|
||||
|
||||
return lambdaClient.session.createSession.mutate({
|
||||
config: { ...config, ...meta } as any,
|
||||
session: { ...session, groupId: group },
|
||||
type,
|
||||
});
|
||||
}
|
||||
|
||||
async batchCreateSessions(importSessions: LobeSessions): Promise<BatchTaskResult> {
|
||||
// TODO: remove any
|
||||
const data = await lambdaClient.session.batchCreateSessions.mutate(importSessions as any);
|
||||
console.log(data);
|
||||
return data;
|
||||
}
|
||||
|
||||
cloneSession(id: string, newTitle: string): Promise<string | undefined> {
|
||||
return lambdaClient.session.cloneSession.mutate({ id, newTitle });
|
||||
}
|
||||
|
||||
getGroupedSessions(): Promise<ChatSessionList> {
|
||||
return lambdaClient.session.getGroupedSessions.query();
|
||||
}
|
||||
|
||||
countSessions(): Promise<number> {
|
||||
return lambdaClient.session.countSessions.query();
|
||||
}
|
||||
|
||||
updateSession(
|
||||
id: string,
|
||||
data: Partial<{ group?: SessionGroupId; meta?: any; pinned?: boolean }>,
|
||||
): Promise<any> {
|
||||
const { group, pinned, meta } = data;
|
||||
return lambdaClient.session.updateSession.mutate({
|
||||
id,
|
||||
value: { groupId: group === 'default' ? null : group, pinned, ...meta },
|
||||
});
|
||||
}
|
||||
|
||||
async getSessionConfig(id: string): Promise<LobeAgentConfig> {
|
||||
const isLogin = authSelectors.isLogin(useUserStore.getState());
|
||||
if (!isLogin) return DEFAULT_AGENT_CONFIG;
|
||||
|
||||
// TODO: Need to be fixed
|
||||
// @ts-ignore
|
||||
return lambdaClient.session.getSessionConfig.query({ id });
|
||||
}
|
||||
|
||||
updateSessionConfig(
|
||||
id: string,
|
||||
config: DeepPartial<LobeAgentConfig>,
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> {
|
||||
return lambdaClient.session.updateSessionConfig.mutate({ id, value: config }, { signal });
|
||||
}
|
||||
|
||||
updateSessionMeta(id: string, meta: Partial<MetaData>, signal?: AbortSignal): Promise<any> {
|
||||
return lambdaClient.session.updateSessionConfig.mutate({ id, value: meta }, { signal });
|
||||
}
|
||||
|
||||
updateSessionChatConfig(
|
||||
id: string,
|
||||
value: DeepPartial<LobeAgentChatConfig>,
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> {
|
||||
return lambdaClient.session.updateSessionChatConfig.mutate({ id, value }, { signal });
|
||||
}
|
||||
|
||||
getSessionsByType(type: 'agent' | 'group' | 'all' = 'all'): Promise<LobeSessions> {
|
||||
// TODO: need be fixed
|
||||
// @ts-ignore
|
||||
return lambdaClient.session.getSessions.query({});
|
||||
}
|
||||
|
||||
searchSessions(keywords: string): Promise<LobeSessions> {
|
||||
return lambdaClient.session.searchSessions.query({ keywords });
|
||||
}
|
||||
|
||||
removeSession(id: string): Promise<any> {
|
||||
return lambdaClient.session.removeSession.mutate({ id });
|
||||
}
|
||||
|
||||
removeAllSessions(): Promise<any> {
|
||||
return lambdaClient.session.removeAllSessions.mutate();
|
||||
}
|
||||
|
||||
// ************************************** //
|
||||
// *********** SessionGroup *********** //
|
||||
// ************************************** //
|
||||
|
||||
createSessionGroup(name: string, sort?: number): Promise<string> {
|
||||
return lambdaClient.sessionGroup.createSessionGroup.mutate({ name, sort });
|
||||
}
|
||||
|
||||
getSessionGroups(): Promise<SessionGroupItem[]> {
|
||||
return lambdaClient.sessionGroup.getSessionGroup.query();
|
||||
}
|
||||
|
||||
batchCreateSessionGroups(groups: SessionGroups): Promise<BatchTaskResult> {
|
||||
return Promise.resolve({ added: 0, ids: [], skips: [], success: true });
|
||||
}
|
||||
|
||||
removeSessionGroup(id: string, removeChildren?: boolean): Promise<any> {
|
||||
return lambdaClient.sessionGroup.removeSessionGroup.mutate({ id, removeChildren });
|
||||
}
|
||||
|
||||
removeSessionGroups(): Promise<any> {
|
||||
return lambdaClient.sessionGroup.removeAllSessionGroups.mutate();
|
||||
}
|
||||
|
||||
updateSessionGroup(id: string, value: Partial<SessionGroupItem>): Promise<any> {
|
||||
// TODO: need be fixed
|
||||
// @ts-ignore
|
||||
return lambdaClient.sessionGroup.updateSessionGroup.mutate({ id, value });
|
||||
}
|
||||
|
||||
updateSessionGroupOrder(sortMap: { id: string; sort: number }[]): Promise<any> {
|
||||
return lambdaClient.sessionGroup.updateSessionGroupOrder.mutate({ sortMap });
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +1,6 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
//
|
||||
// import { ClientService } from './client';
|
||||
// import { ServerService } from './server';
|
||||
//
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
//
|
||||
// export const topicService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
export const topicService = new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
export const topicService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
68
src/services/topic/server.ts
Normal file
68
src/services/topic/server.ts
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import { INBOX_SESSION_ID } from '@/const/session';
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { CreateTopicParams, ITopicService, QueryTopicParams } from '@/services/topic/type';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
import { ChatTopic } from '@/types/topic';
|
||||
|
||||
export class ServerService implements ITopicService {
|
||||
createTopic(params: CreateTopicParams): Promise<string> {
|
||||
return lambdaClient.topic.createTopic.mutate({
|
||||
...params,
|
||||
sessionId: this.toDbSessionId(params.sessionId),
|
||||
});
|
||||
}
|
||||
|
||||
batchCreateTopics(importTopics: ChatTopic[]): Promise<BatchTaskResult> {
|
||||
return lambdaClient.topic.batchCreateTopics.mutate(importTopics);
|
||||
}
|
||||
|
||||
cloneTopic(id: string, newTitle?: string | undefined): Promise<string> {
|
||||
return lambdaClient.topic.cloneTopic.mutate({ id, newTitle });
|
||||
}
|
||||
|
||||
getTopics(params: QueryTopicParams): Promise<ChatTopic[]> {
|
||||
return lambdaClient.topic.getTopics.query({
|
||||
...params,
|
||||
sessionId: this.toDbSessionId(params.sessionId),
|
||||
}) as any;
|
||||
}
|
||||
|
||||
getAllTopics(): Promise<ChatTopic[]> {
|
||||
return lambdaClient.topic.getAllTopics.query() as any;
|
||||
}
|
||||
|
||||
async countTopics() {
|
||||
return lambdaClient.topic.countTopics.query();
|
||||
}
|
||||
|
||||
searchTopics(keywords: string, sessionId?: string | undefined): Promise<ChatTopic[]> {
|
||||
return lambdaClient.topic.searchTopics.query({
|
||||
keywords,
|
||||
sessionId: this.toDbSessionId(sessionId),
|
||||
}) as any;
|
||||
}
|
||||
|
||||
updateTopic(id: string, data: Partial<ChatTopic>): Promise<any> {
|
||||
return lambdaClient.topic.updateTopic.mutate({ id, value: data });
|
||||
}
|
||||
|
||||
removeTopic(id: string): Promise<any> {
|
||||
return lambdaClient.topic.removeTopic.mutate({ id });
|
||||
}
|
||||
|
||||
removeTopics(sessionId: string): Promise<any> {
|
||||
return lambdaClient.topic.batchDeleteBySessionId.mutate({ id: this.toDbSessionId(sessionId) });
|
||||
}
|
||||
|
||||
batchRemoveTopics(topics: string[]): Promise<any> {
|
||||
return lambdaClient.topic.batchDelete.mutate({ ids: topics });
|
||||
}
|
||||
|
||||
removeAllTopic(): Promise<any> {
|
||||
return lambdaClient.topic.removeAllTopics.mutate();
|
||||
}
|
||||
|
||||
private toDbSessionId(sessionId?: string | null) {
|
||||
return sessionId === INBOX_SESSION_ID ? null : sessionId;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +1,6 @@
|
|||
// import { getClientConfig } from '@/config/client';
|
||||
//
|
||||
// import { ClientService } from './client';
|
||||
// import { ServerService } from './server';
|
||||
//
|
||||
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
|
||||
//
|
||||
// export const userService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
export const userService = new ClientService();
|
||||
import { ClientService } from './client';
|
||||
import { ServerService } from './server';
|
||||
|
||||
export const userService = isServerMode ? new ServerService() : new ClientService();
|
||||
|
|
|
|||
28
src/services/user/server.ts
Normal file
28
src/services/user/server.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import { DeepPartial } from 'utility-types';
|
||||
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { IUserService } from '@/services/user/type';
|
||||
import { UserInitializationState, UserPreference } from '@/types/user';
|
||||
import { UserSettings } from '@/types/user/settings';
|
||||
|
||||
export class ServerService implements IUserService {
|
||||
getUserState = async (): Promise<UserInitializationState> => {
|
||||
return lambdaClient.user.getUserState.query();
|
||||
};
|
||||
|
||||
async makeUserOnboarded() {
|
||||
return lambdaClient.user.makeUserOnboarded.mutate();
|
||||
}
|
||||
|
||||
async updatePreference(preference: UserPreference) {
|
||||
return lambdaClient.user.updatePreference.mutate(preference);
|
||||
}
|
||||
|
||||
updateUserSettings = async (value: DeepPartial<UserSettings>, signal?: AbortSignal) => {
|
||||
return lambdaClient.user.updateSettings.mutate(value, { signal });
|
||||
};
|
||||
|
||||
resetUserSettings = async () => {
|
||||
return lambdaClient.user.resetSettings.mutate();
|
||||
};
|
||||
}
|
||||
7
tests/setup-db.ts
Normal file
7
tests/setup-db.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
// import env
|
||||
import { Crypto } from '@peculiar/webcrypto';
|
||||
import * as dotenv from 'dotenv';
|
||||
|
||||
dotenv.config();
|
||||
|
||||
global.crypto = new Crypto();
|
||||
|
|
@ -22,12 +22,13 @@ export default defineConfig({
|
|||
],
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'json', 'lcov', 'text-summary'],
|
||||
reportsDirectory: './coverage/app',
|
||||
},
|
||||
deps: {
|
||||
inline: ['vitest-canvas-mock'],
|
||||
},
|
||||
// threads: false,
|
||||
environment: 'happy-dom',
|
||||
exclude: ['**/node_modules/**', '**/dist/**', '**/build/**', 'src/database/server/**/**'],
|
||||
globals: true,
|
||||
setupFiles: './tests/setup.ts',
|
||||
},
|
||||
|
|
|
|||
23
vitest.server.config.ts
Normal file
23
vitest.server.config.ts
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import { resolve } from 'node:path';
|
||||
import { defineConfig } from 'vitest/config';
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
alias: {
|
||||
'@': resolve(__dirname, './src'),
|
||||
},
|
||||
coverage: {
|
||||
all: false,
|
||||
exclude: ['src/database/server/core/dbForTest.ts'],
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'json', 'lcov', 'text-summary'],
|
||||
reportsDirectory: './coverage/server',
|
||||
},
|
||||
environment: 'node',
|
||||
include: ['src/database/server/**/**/*.test.ts'],
|
||||
poolOptions: {
|
||||
threads: { singleThread: true },
|
||||
},
|
||||
setupFiles: './tests/setup-db.ts',
|
||||
},
|
||||
});
|
||||
Loading…
Reference in a new issue