diff --git a/apps/cli/src/api/client.ts b/apps/cli/src/api/client.ts index 5a8b0cbc9d..c54c15302d 100644 --- a/apps/cli/src/api/client.ts +++ b/apps/cli/src/api/client.ts @@ -5,8 +5,8 @@ import type { LambdaRouter } from '@/server/routers/lambda'; import type { ToolsRouter } from '@/server/routers/tools'; import { getValidToken } from '../auth/refresh'; -import { OFFICIAL_SERVER_URL } from '../constants/urls'; -import { loadSettings } from '../settings'; +import { CLI_API_KEY_ENV } from '../constants/auth'; +import { resolveServerUrl } from '../settings'; import { log } from '../utils/logger'; export type TrpcClient = ReturnType>; @@ -19,31 +19,46 @@ async function getAuthAndServer() { // LOBEHUB_JWT + LOBEHUB_SERVER env vars (used by server-side sandbox execution) const envJwt = process.env.LOBEHUB_JWT; if (envJwt) { - const serverUrl = process.env.LOBEHUB_SERVER || OFFICIAL_SERVER_URL; - return { accessToken: envJwt, serverUrl: serverUrl.replace(/\/$/, '') }; + const serverUrl = resolveServerUrl(); + + return { + headers: { 'Oidc-Auth': envJwt }, + serverUrl, + }; + } + + const envApiKey = process.env[CLI_API_KEY_ENV]; + if (envApiKey) { + const serverUrl = resolveServerUrl(); + + return { + headers: { 'X-API-Key': envApiKey }, + serverUrl, + }; } const result = await getValidToken(); if (!result) { - log.error("No authentication found. Run 'lh login' first."); + log.error(`No authentication found. Run 'lh login' first, or set ${CLI_API_KEY_ENV}.`); process.exit(1); } - const accessToken = result.credentials.accessToken; - const serverUrl = loadSettings()?.serverUrl || OFFICIAL_SERVER_URL; + const serverUrl = resolveServerUrl(); - return { accessToken, serverUrl: serverUrl.replace(/\/$/, '') }; + return { + headers: { 'Oidc-Auth': result.credentials.accessToken }, + serverUrl, + }; } export async function getTrpcClient(): Promise { if (_client) return _client; - const { accessToken, serverUrl } = await getAuthAndServer(); - + const { headers, serverUrl } = await getAuthAndServer(); _client = createTRPCClient({ links: [ httpLink({ - headers: { 'Oidc-Auth': accessToken }, + headers, transformer: superjson, url: `${serverUrl}/trpc/lambda`, }), @@ -56,12 +71,11 @@ export async function getTrpcClient(): Promise { export async function getToolsTrpcClient(): Promise { if (_toolsClient) return _toolsClient; - const { accessToken, serverUrl } = await getAuthAndServer(); - + const { headers, serverUrl } = await getAuthAndServer(); _toolsClient = createTRPCClient({ links: [ httpLink({ - headers: { 'Oidc-Auth': accessToken }, + headers, transformer: superjson, url: `${serverUrl}/trpc/tools`, }), diff --git a/apps/cli/src/api/http.ts b/apps/cli/src/api/http.ts index 954f190215..43b47082a5 100644 --- a/apps/cli/src/api/http.ts +++ b/apps/cli/src/api/http.ts @@ -1,6 +1,6 @@ import { getValidToken } from '../auth/refresh'; -import { OFFICIAL_SERVER_URL } from '../constants/urls'; -import { loadSettings } from '../settings'; +import { CLI_API_KEY_ENV } from '../constants/auth'; +import { resolveServerUrl } from '../settings'; import { log } from '../utils/logger'; // Must match the server's SECRET_XOR_KEY (src/envs/auth.ts) @@ -33,12 +33,19 @@ export interface AuthInfo { export async function getAuthInfo(): Promise { const result = await getValidToken(); if (!result) { + if (process.env[CLI_API_KEY_ENV]) { + log.error( + `API key auth from ${CLI_API_KEY_ENV} is not supported for /webapi/* routes. Run OIDC login instead.`, + ); + process.exit(1); + } + log.error("No authentication found. Run 'lh login' first."); process.exit(1); } const accessToken = result!.credentials.accessToken; - const serverUrl = loadSettings()?.serverUrl || OFFICIAL_SERVER_URL; + const serverUrl = resolveServerUrl(); return { accessToken, @@ -47,6 +54,6 @@ export async function getAuthInfo(): Promise { 'Oidc-Auth': accessToken, 'X-lobe-chat-auth': obfuscatePayloadWithXOR({}), }, - serverUrl: serverUrl.replace(/\/$/, ''), + serverUrl, }; } diff --git a/apps/cli/src/auth/apiKey.ts b/apps/cli/src/auth/apiKey.ts new file mode 100644 index 0000000000..e6eac3be92 --- /dev/null +++ b/apps/cli/src/auth/apiKey.ts @@ -0,0 +1,41 @@ +import { normalizeUrl, resolveServerUrl } from '../settings'; + +interface CurrentUserResponse { + data?: { + id?: string; + userId?: string; + }; + error?: string; + message?: string; + success?: boolean; +} + +export async function getUserIdFromApiKey(apiKey: string, serverUrl?: string): Promise { + const normalizedServerUrl = normalizeUrl(serverUrl) || resolveServerUrl(); + + const response = await fetch(`${normalizedServerUrl}/api/v1/users/me`, { + headers: { + Authorization: `Bearer ${apiKey}`, + }, + }); + + let body: CurrentUserResponse | undefined; + try { + body = (await response.json()) as CurrentUserResponse; + } catch { + throw new Error(`Failed to parse response from ${normalizedServerUrl}/api/v1/users/me.`); + } + + if (!response.ok || body?.success === false) { + throw new Error( + body?.error || body?.message || `Request failed with status ${response.status}.`, + ); + } + + const userId = body?.data?.id || body?.data?.userId; + if (!userId) { + throw new Error('Current user response did not include a user id.'); + } + + return userId; +} diff --git a/apps/cli/src/auth/refresh.ts b/apps/cli/src/auth/refresh.ts index 52deb66740..69439673d4 100644 --- a/apps/cli/src/auth/refresh.ts +++ b/apps/cli/src/auth/refresh.ts @@ -1,5 +1,4 @@ -import { OFFICIAL_SERVER_URL } from '../constants/urls'; -import { loadSettings } from '../settings'; +import { resolveServerUrl } from '../settings'; import { loadCredentials, saveCredentials, type StoredCredentials } from './credentials'; const CLIENT_ID = 'lobehub-cli'; @@ -20,7 +19,7 @@ export async function getValidToken(): Promise<{ credentials: StoredCredentials // Token expired — try refresh if (!credentials.refreshToken) return null; - const serverUrl = loadSettings()?.serverUrl || OFFICIAL_SERVER_URL; + const serverUrl = resolveServerUrl(); const refreshed = await refreshAccessToken(serverUrl, credentials.refreshToken); if (!refreshed) return null; diff --git a/apps/cli/src/auth/resolveToken.test.ts b/apps/cli/src/auth/resolveToken.test.ts index fa33236fa1..2a1ca84c38 100644 --- a/apps/cli/src/auth/resolveToken.test.ts +++ b/apps/cli/src/auth/resolveToken.test.ts @@ -1,12 +1,21 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { getUserIdFromApiKey } from './apiKey'; import { getValidToken } from './refresh'; import { resolveToken } from './resolveToken'; +vi.mock('./apiKey', () => ({ + getUserIdFromApiKey: vi.fn(), +})); vi.mock('./refresh', () => ({ getValidToken: vi.fn(), })); - +vi.mock('../settings', () => ({ + loadSettings: vi.fn().mockReturnValue({ serverUrl: 'https://app.lobehub.com' }), + resolveServerUrl: vi.fn(() => + (process.env.LOBEHUB_SERVER || 'https://app.lobehub.com').replace(/\/$/, ''), + ), +})); vi.mock('../utils/logger', () => ({ log: { debug: vi.fn(), @@ -25,14 +34,23 @@ function makeJwt(sub: string): string { describe('resolveToken', () => { let exitSpy: ReturnType; + const originalApiKey = process.env.LOBEHUB_CLI_API_KEY; + const originalJwt = process.env.LOBEHUB_JWT; + const originalServer = process.env.LOBEHUB_SERVER; beforeEach(() => { exitSpy = vi.spyOn(process, 'exit').mockImplementation(() => { throw new Error('process.exit'); }); + delete process.env.LOBEHUB_CLI_API_KEY; + delete process.env.LOBEHUB_JWT; + delete process.env.LOBEHUB_SERVER; }); afterEach(() => { + process.env.LOBEHUB_CLI_API_KEY = originalApiKey; + process.env.LOBEHUB_JWT = originalJwt; + process.env.LOBEHUB_SERVER = originalServer; exitSpy.mockRestore(); }); @@ -42,7 +60,12 @@ describe('resolveToken', () => { const result = await resolveToken({ token }); - expect(result).toEqual({ token, userId: 'user-123' }); + expect(result).toEqual({ + serverUrl: 'https://app.lobehub.com', + token, + tokenType: 'jwt', + userId: 'user-123', + }); }); it('should exit if JWT has no sub claim', async () => { @@ -67,7 +90,12 @@ describe('resolveToken', () => { userId: 'user-456', }); - expect(result).toEqual({ token: 'svc-token', userId: 'user-456' }); + expect(result).toEqual({ + serverUrl: 'https://app.lobehub.com', + token: 'svc-token', + tokenType: 'serviceToken', + userId: 'user-456', + }); }); it('should exit if --user-id is not provided', async () => { @@ -76,6 +104,37 @@ describe('resolveToken', () => { }); }); + describe('with environment api key', () => { + it('should return API key from environment', async () => { + process.env.LOBEHUB_CLI_API_KEY = 'sk-lh-test'; + vi.mocked(getUserIdFromApiKey).mockResolvedValue('user-789'); + + const result = await resolveToken({}); + + expect(getUserIdFromApiKey).toHaveBeenCalledWith('sk-lh-test', 'https://app.lobehub.com'); + expect(result).toEqual({ + serverUrl: 'https://app.lobehub.com', + token: 'sk-lh-test', + tokenType: 'apiKey', + userId: 'user-789', + }); + }); + + it('should prefer LOBEHUB_SERVER when validating the API key', async () => { + process.env.LOBEHUB_CLI_API_KEY = 'sk-lh-test'; + process.env.LOBEHUB_SERVER = 'https://self-hosted.example.com/'; + vi.mocked(getUserIdFromApiKey).mockResolvedValue('user-789'); + + const result = await resolveToken({}); + + expect(getUserIdFromApiKey).toHaveBeenCalledWith( + 'sk-lh-test', + 'https://self-hosted.example.com', + ); + expect(result.serverUrl).toBe('https://self-hosted.example.com'); + }); + }); + describe('with stored credentials', () => { it('should return stored credentials token', async () => { const token = makeJwt('stored-user'); @@ -87,7 +146,12 @@ describe('resolveToken', () => { const result = await resolveToken({}); - expect(result).toEqual({ token, userId: 'stored-user' }); + expect(result).toEqual({ + serverUrl: 'https://app.lobehub.com', + token, + tokenType: 'jwt', + userId: 'stored-user', + }); }); it('should exit if stored token has no sub', async () => { diff --git a/apps/cli/src/auth/resolveToken.ts b/apps/cli/src/auth/resolveToken.ts index 9e650b9c17..d84fe90ceb 100644 --- a/apps/cli/src/auth/resolveToken.ts +++ b/apps/cli/src/auth/resolveToken.ts @@ -1,4 +1,7 @@ +import { CLI_API_KEY_ENV } from '../constants/auth'; +import { resolveServerUrl } from '../settings'; import { log } from '../utils/logger'; +import { getUserIdFromApiKey } from './apiKey'; import { getValidToken } from './refresh'; interface ResolveTokenOptions { @@ -8,7 +11,9 @@ interface ResolveTokenOptions { } interface ResolvedAuth { + serverUrl: string; token: string; + tokenType: 'apiKey' | 'jwt' | 'serviceToken'; userId: string; } @@ -25,20 +30,21 @@ function parseJwtSub(token: string): string | undefined { } /** - * Resolve an access token from explicit options or stored credentials. + * Resolve an access token from explicit options, environment variables, or stored credentials. * Exits the process if no token can be resolved. */ export async function resolveToken(options: ResolveTokenOptions): Promise { // LOBEHUB_JWT env var takes highest priority (used by server-side sandbox execution) const envJwt = process.env.LOBEHUB_JWT; if (envJwt) { + const serverUrl = resolveServerUrl(); const userId = parseJwtSub(envJwt); if (!userId) { log.error('Could not extract userId from LOBEHUB_JWT.'); process.exit(1); } log.debug('Using LOBEHUB_JWT from environment'); - return { token: envJwt, userId }; + return { serverUrl, token: envJwt, tokenType: 'jwt', userId }; } // Explicit token takes priority @@ -48,7 +54,7 @@ export async function resolveToken(options: ResolveTokenOptions): Promise ({ - resolveToken: vi.fn().mockResolvedValue({ token: 'test-token', userId: 'test-user' }), + resolveToken: vi.fn().mockResolvedValue({ + serverUrl: 'https://app.lobehub.com', + token: 'test-token', + tokenType: 'jwt', + userId: 'test-user', + }), })); vi.mock('../settings', () => ({ loadSettings: vi.fn().mockReturnValue(null), + normalizeUrl: vi.fn((url?: string) => (url ? url.replace(/\/$/, '') : undefined)), saveSettings: vi.fn(), })); @@ -161,6 +167,12 @@ describe('connect command', () => { serverUrl: 'https://self-hosted.example.com', }); }); + it('should pass the resolved serverUrl to GatewayClient', async () => { + const program = createProgram(); + await program.parseAsync(['node', 'test', 'connect']); + + expect(clientOptions.serverUrl).toBe('https://app.lobehub.com'); + }); it('should handle tool call requests', async () => { const program = createProgram(); @@ -208,7 +220,12 @@ describe('connect command', () => { }); it('should handle auth_expired', async () => { - vi.mocked(resolveToken).mockResolvedValueOnce({ token: 'new-tok', userId: 'user' }); + vi.mocked(resolveToken).mockResolvedValueOnce({ + serverUrl: 'https://app.lobehub.com', + token: 'new-tok', + tokenType: 'jwt', + userId: 'user', + }); const program = createProgram(); await program.parseAsync(['node', 'test', 'connect']); @@ -220,6 +237,24 @@ describe('connect command', () => { expect(exitSpy).toHaveBeenCalledWith(1); }); + it('should ignore auth_expired for api key auth', async () => { + vi.mocked(resolveToken).mockResolvedValueOnce({ + serverUrl: 'https://self-hosted.example.com', + token: 'test-api-key', + tokenType: 'apiKey', + userId: 'user', + }); + + const program = createProgram(); + await program.parseAsync(['node', 'test', 'connect']); + + await clientEventHandlers['auth_expired']?.(); + + expect(log.error).not.toHaveBeenCalled(); + expect(cleanupAllProcesses).not.toHaveBeenCalled(); + expect(exitSpy).not.toHaveBeenCalled(); + }); + it('should handle error event', async () => { const program = createProgram(); await program.parseAsync(['node', 'test', 'connect']); diff --git a/apps/cli/src/commands/connect.ts b/apps/cli/src/commands/connect.ts index b29a38da16..2e34b893f7 100644 --- a/apps/cli/src/commands/connect.ts +++ b/apps/cli/src/commands/connect.ts @@ -11,6 +11,7 @@ import { GatewayClient } from '@lobechat/device-gateway-client'; import type { Command } from 'commander'; import { resolveToken } from '../auth/resolveToken'; +import { CLI_API_KEY_ENV } from '../constants/auth'; import { OFFICIAL_GATEWAY_URL } from '../constants/urls'; import { appendLog, @@ -23,7 +24,7 @@ import { stopDaemon, writeStatus, } from '../daemon/manager'; -import { loadSettings, saveSettings } from '../settings'; +import { loadSettings, normalizeUrl, saveSettings } from '../settings'; import { executeToolCall } from '../tools'; import { cleanupAllProcesses } from '../tools/shell'; import { log, setVerbose } from '../utils/logger'; @@ -174,7 +175,7 @@ function buildDaemonArgs(options: ConnectOptions): string[] { async function runConnect(options: ConnectOptions, isDaemonChild: boolean) { const auth = await resolveToken(options); const settings = loadSettings(); - const gatewayUrl = options.gateway?.replace(/\/$/, '') || settings?.gatewayUrl; + const gatewayUrl = normalizeUrl(options.gateway) || settings?.gatewayUrl; if (!gatewayUrl && settings?.serverUrl) { log.error( @@ -194,7 +195,9 @@ async function runConnect(options: ConnectOptions, isDaemonChild: boolean) { deviceId: options.deviceId, gatewayUrl: resolvedGatewayUrl, logger: isDaemonChild ? createDaemonLogger() : log, + serverUrl: auth.serverUrl, token: auth.token, + tokenType: auth.tokenType, userId: auth.userId, }); @@ -214,7 +217,7 @@ async function runConnect(options: ConnectOptions, isDaemonChild: boolean) { info(` Hostname : ${os.hostname()}`); info(` Platform : ${process.platform}`); info(` Gateway : ${resolvedGatewayUrl}`); - info(` Auth : jwt`); + info(` Auth : ${auth.tokenType}`); info(` Mode : ${isDaemonChild ? 'daemon' : 'foreground'}`); info('───────────────────'); @@ -285,13 +288,19 @@ async function runConnect(options: ConnectOptions, isDaemonChild: boolean) { // Handle auth failed client.on('auth_failed', (reason) => { error(`Authentication failed: ${reason}`); - error("Run 'lh login' to re-authenticate."); + error( + `Run 'lh login', or set ${CLI_API_KEY_ENV} and run 'lh login --server ' to configure API key authentication.`, + ); cleanup(); process.exit(1); }); // Handle auth expired client.on('auth_expired', async () => { + if (auth.tokenType === 'apiKey') { + return; + } + error('Authentication expired. Attempting to refresh...'); const refreshed = await resolveToken({}); if (refreshed) { diff --git a/apps/cli/src/commands/login.test.ts b/apps/cli/src/commands/login.test.ts index bf71ae1635..9bd079bd83 100644 --- a/apps/cli/src/commands/login.test.ts +++ b/apps/cli/src/commands/login.test.ts @@ -3,11 +3,15 @@ import fs from 'node:fs'; import { Command } from 'commander'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { getUserIdFromApiKey } from '../auth/apiKey'; import { saveCredentials } from '../auth/credentials'; import { loadSettings, saveSettings } from '../settings'; import { log } from '../utils/logger'; import { registerLoginCommand, resolveCommandExecutable } from './login'; +vi.mock('../auth/apiKey', () => ({ + getUserIdFromApiKey: vi.fn(), +})); vi.mock('../auth/credentials', () => ({ saveCredentials: vi.fn(), })); @@ -37,6 +41,7 @@ vi.mock('node:child_process', () => ({ describe('login command', () => { let exitSpy: ReturnType; + const originalApiKey = process.env.LOBEHUB_CLI_API_KEY; const originalPath = process.env.PATH; const originalPathext = process.env.PATHEXT; const originalSystemRoot = process.env.SystemRoot; @@ -46,11 +51,13 @@ describe('login command', () => { vi.stubGlobal('fetch', vi.fn()); exitSpy = vi.spyOn(process, 'exit').mockImplementation((() => {}) as any); vi.mocked(loadSettings).mockReturnValue(null); + delete process.env.LOBEHUB_CLI_API_KEY; }); afterEach(() => { vi.useRealTimers(); exitSpy.mockRestore(); + process.env.LOBEHUB_CLI_API_KEY = originalApiKey; process.env.PATH = originalPath; process.env.PATHEXT = originalPathext; process.env.SystemRoot = originalSystemRoot; @@ -102,8 +109,12 @@ describe('login command', () => { } as any; } + async function runLogin(program: Command, args: string[] = []) { + return program.parseAsync(['node', 'test', 'login', ...args]); + } + async function runLoginAndAdvanceTimers(program: Command, args: string[] = []) { - const parsePromise = program.parseAsync(['node', 'test', 'login', ...args]); + const parsePromise = runLogin(program, args); // Advance timers to let sleep resolve in the polling loop for (let i = 0; i < 10; i++) { await vi.advanceTimersByTimeAsync(2000); @@ -130,6 +141,19 @@ describe('login command', () => { expect(log.info).toHaveBeenCalledWith(expect.stringContaining('Login successful')); }); + it('should use environment api key without storing credentials', async () => { + process.env.LOBEHUB_CLI_API_KEY = 'sk-lh-env-test'; + vi.mocked(getUserIdFromApiKey).mockResolvedValue('user-123'); + + const program = createProgram(); + await runLogin(program); + + expect(getUserIdFromApiKey).toHaveBeenCalledWith('sk-lh-env-test', 'https://app.lobehub.com'); + expect(saveCredentials).not.toHaveBeenCalled(); + expect(saveSettings).toHaveBeenCalledWith({ serverUrl: 'https://app.lobehub.com' }); + expect(log.info).toHaveBeenCalledWith(expect.stringContaining('Login successful')); + }); + it('should persist custom server into settings', async () => { vi.mocked(fetch) .mockResolvedValueOnce(deviceAuthResponse()) @@ -159,6 +183,23 @@ describe('login command', () => { }); }); + it('should preserve existing gateway for environment api key on the same server', async () => { + process.env.LOBEHUB_CLI_API_KEY = 'sk-lh-env-test'; + vi.mocked(getUserIdFromApiKey).mockResolvedValue('user-123'); + vi.mocked(loadSettings).mockReturnValueOnce({ + gatewayUrl: 'https://gateway.example.com', + serverUrl: 'https://test.com', + }); + + const program = createProgram(); + await runLogin(program, ['--server', 'https://test.com/']); + + expect(saveSettings).toHaveBeenCalledWith({ + gatewayUrl: 'https://gateway.example.com', + serverUrl: 'https://test.com', + }); + }); + it('should clear existing gateway when logging into a different server', async () => { vi.mocked(loadSettings).mockReturnValueOnce({ gatewayUrl: 'https://gateway.example.com', diff --git a/apps/cli/src/commands/login.ts b/apps/cli/src/commands/login.ts index b96ced9636..f4432f8f8c 100644 --- a/apps/cli/src/commands/login.ts +++ b/apps/cli/src/commands/login.ts @@ -4,9 +4,11 @@ import path from 'node:path'; import type { Command } from 'commander'; +import { getUserIdFromApiKey } from '../auth/apiKey'; import { saveCredentials } from '../auth/credentials'; +import { CLI_API_KEY_ENV } from '../constants/auth'; import { OFFICIAL_SERVER_URL } from '../constants/urls'; -import { loadSettings, saveSettings } from '../settings'; +import { loadSettings, normalizeUrl, saveSettings } from '../settings'; import { log } from '../utils/logger'; const CLIENT_ID = 'lobehub-cli'; @@ -51,13 +53,43 @@ async function parseJsonResponse(res: Response, endpoint: string): Promise export function registerLoginCommand(program: Command) { program .command('login') - .description('Log in to LobeHub via browser (Device Code Flow)') + .description('Log in to LobeHub via browser (Device Code Flow) or configure API key server') .option('--server ', 'LobeHub server URL', OFFICIAL_SERVER_URL) .action(async (options: LoginOptions) => { - const serverUrl = options.server.replace(/\/$/, ''); + const serverUrl = normalizeUrl(options.server) || OFFICIAL_SERVER_URL; log.info('Starting login...'); + const apiKey = process.env[CLI_API_KEY_ENV]; + if (apiKey) { + try { + await getUserIdFromApiKey(apiKey, serverUrl); + + const existingSettings = loadSettings(); + const shouldPreserveGateway = existingSettings?.serverUrl === serverUrl; + + saveSettings( + shouldPreserveGateway + ? { + gatewayUrl: existingSettings.gatewayUrl, + serverUrl, + } + : { + // Gateway auth is tied to the login server's token issuer/JWKS. + // When server changes, clear old gateway to avoid stale cross-environment config. + serverUrl, + }, + ); + log.info('Login successful! Credentials saved.'); + return; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + log.error(`API key validation failed: ${message}`); + process.exit(1); + return; + } + } + // Step 1: Request device code let deviceAuth: DeviceAuthResponse; try { @@ -164,6 +196,7 @@ export function registerLoginCommand(program: Command) { : undefined, refreshToken: body.refresh_token, }); + const existingSettings = loadSettings(); const shouldPreserveGateway = existingSettings?.serverUrl === serverUrl; diff --git a/apps/cli/src/commands/status.test.ts b/apps/cli/src/commands/status.test.ts index 884b691fef..67071b6236 100644 --- a/apps/cli/src/commands/status.test.ts +++ b/apps/cli/src/commands/status.test.ts @@ -3,10 +3,16 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; // Mock resolveToken vi.mock('../auth/resolveToken', () => ({ - resolveToken: vi.fn().mockResolvedValue({ token: 'test-token', userId: 'test-user' }), + resolveToken: vi.fn().mockResolvedValue({ + serverUrl: 'https://app.lobehub.com', + token: 'test-token', + tokenType: 'jwt', + userId: 'test-user', + }), })); vi.mock('../settings', () => ({ loadSettings: vi.fn().mockReturnValue(null), + normalizeUrl: vi.fn((url?: string) => (url ? url.replace(/\/$/, '') : undefined)), saveSettings: vi.fn(), })); @@ -115,6 +121,16 @@ describe('status command', () => { serverUrl: 'https://self-hosted.example.com', }); }); + it('should pass the resolved serverUrl to GatewayClient', async () => { + const program = createProgram(); + const parsePromise = program.parseAsync(['node', 'test', 'status']); + await vi.advanceTimersByTimeAsync(0); + + clientEventHandlers['connected']?.(); + + await parsePromise; + expect(clientOptions.serverUrl).toBe('https://app.lobehub.com'); + }); it('should log CONNECTED on successful connection', async () => { const program = createProgram(); diff --git a/apps/cli/src/commands/status.ts b/apps/cli/src/commands/status.ts index 9465703bb0..985ced5a5a 100644 --- a/apps/cli/src/commands/status.ts +++ b/apps/cli/src/commands/status.ts @@ -3,7 +3,7 @@ import type { Command } from 'commander'; import { resolveToken } from '../auth/resolveToken'; import { OFFICIAL_GATEWAY_URL } from '../constants/urls'; -import { loadSettings, saveSettings } from '../settings'; +import { loadSettings, normalizeUrl, saveSettings } from '../settings'; import { log, setVerbose } from '../utils/logger'; interface StatusOptions { @@ -30,7 +30,7 @@ export function registerStatusCommand(program: Command) { const auth = await resolveToken(options); const settings = loadSettings(); - const gatewayUrl = options.gateway?.replace(/\/$/, '') || settings?.gatewayUrl; + const gatewayUrl = normalizeUrl(options.gateway) || settings?.gatewayUrl; if (!gatewayUrl && settings?.serverUrl) { log.error( @@ -50,7 +50,9 @@ export function registerStatusCommand(program: Command) { autoReconnect: false, gatewayUrl: gatewayUrl || OFFICIAL_GATEWAY_URL, logger: log, + serverUrl: auth.serverUrl, token: auth.token, + tokenType: auth.tokenType, userId: auth.userId, }); diff --git a/apps/cli/src/constants/auth.ts b/apps/cli/src/constants/auth.ts new file mode 100644 index 0000000000..de3602f00c --- /dev/null +++ b/apps/cli/src/constants/auth.ts @@ -0,0 +1 @@ +export const CLI_API_KEY_ENV = 'LOBEHUB_CLI_API_KEY'; diff --git a/apps/cli/src/settings/index.test.ts b/apps/cli/src/settings/index.test.ts index 1e18d2fc45..213298bd61 100644 --- a/apps/cli/src/settings/index.test.ts +++ b/apps/cli/src/settings/index.test.ts @@ -5,18 +5,19 @@ import path from 'node:path'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { log } from '../utils/logger'; -import { loadSettings, saveSettings } from './index'; +import { loadSettings, normalizeUrl, resolveServerUrl, saveSettings } from './index'; const tmpDir = path.join(os.tmpdir(), 'lobehub-cli-test-settings'); const settingsDir = path.join(tmpDir, '.lobehub'); const settingsFile = path.join(settingsDir, 'settings.json'); +const originalServer = process.env.LOBEHUB_SERVER; vi.mock('node:os', async (importOriginal) => { const actual = await importOriginal>(); return { ...actual, default: { - ...actual['default'], + ...actual.default, homedir: () => path.join(os.tmpdir(), 'lobehub-cli-test-settings'), }, }; @@ -31,10 +32,12 @@ vi.mock('../utils/logger', () => ({ describe('settings', () => { beforeEach(() => { fs.mkdirSync(tmpDir, { recursive: true }); + delete process.env.LOBEHUB_SERVER; }); afterEach(() => { fs.rmSync(tmpDir, { force: true, recursive: true }); + process.env.LOBEHUB_SERVER = originalServer; vi.clearAllMocks(); }); @@ -64,4 +67,28 @@ describe('settings', () => { expect(loadSettings()).toBeNull(); expect(log.warn).toHaveBeenCalledWith(expect.stringContaining('Please delete this file')); }); + + it('should normalize trailing slashes', () => { + expect(normalizeUrl('https://self-hosted.example.com/')).toBe( + 'https://self-hosted.example.com', + ); + expect(normalizeUrl(undefined)).toBeUndefined(); + }); + + it('should prefer LOBEHUB_SERVER over settings', () => { + saveSettings({ serverUrl: 'https://settings.example.com/' }); + process.env.LOBEHUB_SERVER = 'https://env.example.com/'; + + expect(resolveServerUrl()).toBe('https://env.example.com'); + }); + + it('should fall back to settings then official server', () => { + saveSettings({ serverUrl: 'https://settings.example.com/' }); + + expect(resolveServerUrl()).toBe('https://settings.example.com'); + + fs.unlinkSync(settingsFile); + + expect(resolveServerUrl()).toBe('https://app.lobehub.com'); + }); }); diff --git a/apps/cli/src/settings/index.ts b/apps/cli/src/settings/index.ts index f98ad52cf7..3cd78f4f90 100644 --- a/apps/cli/src/settings/index.ts +++ b/apps/cli/src/settings/index.ts @@ -14,10 +14,17 @@ const LOBEHUB_DIR_NAME = process.env.LOBEHUB_CLI_HOME || '.lobehub'; const SETTINGS_DIR = path.join(os.homedir(), LOBEHUB_DIR_NAME); const SETTINGS_FILE = path.join(SETTINGS_DIR, 'settings.json'); -function normalizeUrl(url: string | undefined): string | undefined { +export function normalizeUrl(url: string | undefined): string | undefined { return url ? url.replace(/\/$/, '') : undefined; } +export function resolveServerUrl(): string { + const envServerUrl = normalizeUrl(process.env.LOBEHUB_SERVER); + const settingsServerUrl = normalizeUrl(loadSettings()?.serverUrl); + + return envServerUrl || settingsServerUrl || OFFICIAL_SERVER_URL; +} + export function saveSettings(settings: StoredSettings): void { const serverUrl = normalizeUrl(settings.serverUrl); const gatewayUrl = normalizeUrl(settings.gatewayUrl); diff --git a/apps/device-gateway/src/DeviceGatewayDO.ts b/apps/device-gateway/src/DeviceGatewayDO.ts index b3d787bdde..791248fa6f 100644 --- a/apps/device-gateway/src/DeviceGatewayDO.ts +++ b/apps/device-gateway/src/DeviceGatewayDO.ts @@ -1,7 +1,7 @@ import { DurableObject } from 'cloudflare:workers'; import { Hono } from 'hono'; -import { verifyDesktopToken } from './auth'; +import { resolveSocketAuth, verifyApiKeyToken, verifyDesktopToken } from './auth'; import type { DeviceAttachment, Env } from './types'; const AUTH_TIMEOUT = 10_000; // 10s to authenticate after connect @@ -58,24 +58,25 @@ export class DeviceGatewayDO extends DurableObject { if (att.authenticated) return; // Already authenticated, ignore try { - const token = data.token as string; - if (!token) throw new Error('Missing token'); + const token = data.token as string | undefined; + const tokenType = data.tokenType as 'apiKey' | 'jwt' | 'serviceToken' | undefined; + const serverUrl = data.serverUrl as string | undefined; + const storedUserId = await this.ctx.storage.get('_userId'); - let verifiedUserId: string; - - if (token === this.env.SERVICE_TOKEN) { - // Service token auth (for CLI debugging) - const storedUserId = await this.ctx.storage.get('_userId'); - if (!storedUserId) throw new Error('Missing userId'); - verifiedUserId = storedUserId; - } else { - // JWT auth (normal desktop flow) - const result = await verifyDesktopToken(this.env, token); - verifiedUserId = result.userId; - } + const verifiedUserId = await resolveSocketAuth({ + serverUrl, + serviceToken: this.env.SERVICE_TOKEN, + storedUserId, + token, + tokenType, + verifyApiKey: verifyApiKeyToken, + verifyJwt: async (jwt) => { + const result = await verifyDesktopToken(this.env, jwt); + return { userId: result.userId }; + }, + }); // Verify userId matches the DO routing - const storedUserId = await this.ctx.storage.get('_userId'); if (storedUserId && verifiedUserId !== storedUserId) { throw new Error('userId mismatch'); } diff --git a/apps/device-gateway/src/auth.test.ts b/apps/device-gateway/src/auth.test.ts new file mode 100644 index 0000000000..54f879f6c9 --- /dev/null +++ b/apps/device-gateway/src/auth.test.ts @@ -0,0 +1,96 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { resolveSocketAuth } from './auth'; + +describe('resolveSocketAuth', () => { + it('rejects missing token', async () => { + const verifyApiKey = vi.fn(); + const verifyJwt = vi.fn(); + + await expect( + resolveSocketAuth({ + serviceToken: 'service-secret', + storedUserId: 'user-123', + verifyApiKey, + verifyJwt, + }), + ).rejects.toThrow('Missing token'); + + expect(verifyApiKey).not.toHaveBeenCalled(); + expect(verifyJwt).not.toHaveBeenCalled(); + }); + + it('rejects the real service token when storedUserId is missing', async () => { + const verifyApiKey = vi.fn(); + const verifyJwt = vi.fn(); + + await expect( + resolveSocketAuth({ + serviceToken: 'service-secret', + token: 'service-secret', + tokenType: 'serviceToken', + verifyApiKey, + verifyJwt, + }), + ).rejects.toThrow('Missing userId'); + + expect(verifyApiKey).not.toHaveBeenCalled(); + expect(verifyJwt).not.toHaveBeenCalled(); + }); + it('rejects clients that only self-declare serviceToken mode', async () => { + const verifyApiKey = vi.fn(); + const verifyJwt = vi.fn().mockRejectedValue(new Error('invalid jwt')); + + await expect( + resolveSocketAuth({ + serviceToken: 'service-secret', + storedUserId: 'user-123', + token: 'attacker-token', + tokenType: 'serviceToken', + verifyApiKey, + verifyJwt, + }), + ).rejects.toThrow('invalid jwt'); + + expect(verifyApiKey).not.toHaveBeenCalled(); + expect(verifyJwt).toHaveBeenCalledWith('attacker-token'); + }); + + it('treats a forged serviceToken claim with a valid JWT as JWT auth', async () => { + const verifyApiKey = vi.fn(); + const verifyJwt = vi.fn().mockResolvedValue({ userId: 'user-123' }); + + await expect( + resolveSocketAuth({ + serviceToken: 'service-secret', + storedUserId: 'user-123', + token: 'valid-jwt', + tokenType: 'serviceToken', + verifyApiKey, + verifyJwt, + }), + ).resolves.toBe('user-123'); + + expect(verifyApiKey).not.toHaveBeenCalled(); + expect(verifyJwt).toHaveBeenCalledWith('valid-jwt'); + }); + + it('accepts the real service token', async () => { + const verifyApiKey = vi.fn(); + const verifyJwt = vi.fn(); + + await expect( + resolveSocketAuth({ + serviceToken: 'service-secret', + storedUserId: 'user-123', + token: 'service-secret', + tokenType: 'serviceToken', + verifyApiKey, + verifyJwt, + }), + ).resolves.toBe('user-123'); + + expect(verifyApiKey).not.toHaveBeenCalled(); + expect(verifyJwt).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/device-gateway/src/auth.ts b/apps/device-gateway/src/auth.ts index f1d9af940c..39e766556b 100644 --- a/apps/device-gateway/src/auth.ts +++ b/apps/device-gateway/src/auth.ts @@ -4,6 +4,26 @@ import type { Env } from './types'; let cachedKey: CryptoKey | null = null; +interface CurrentUserResponse { + data?: { + id?: string; + userId?: string; + }; + error?: string; + message?: string; + success?: boolean; +} + +export interface ResolveSocketAuthOptions { + serverUrl?: string; + serviceToken: string; + storedUserId?: string; + token?: string; + tokenType?: 'apiKey' | 'jwt' | 'serviceToken'; + verifyApiKey: (serverUrl: string, token: string) => Promise<{ userId: string }>; + verifyJwt: (token: string) => Promise<{ userId: string }>; +} + async function getPublicKey(env: Env): Promise { if (cachedKey) return cachedKey; @@ -34,3 +54,57 @@ export async function verifyDesktopToken( userId: payload.sub, }; } + +export async function verifyApiKeyToken( + serverUrl: string, + token: string, +): Promise<{ userId: string }> { + const normalizedServerUrl = new URL(serverUrl).toString().replace(/\/$/, ''); + + const response = await fetch(`${normalizedServerUrl}/api/v1/users/me`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + + let body: CurrentUserResponse | undefined; + try { + body = (await response.json()) as CurrentUserResponse; + } catch { + throw new Error(`Failed to parse response from ${normalizedServerUrl}/api/v1/users/me.`); + } + + if (!response.ok || body?.success === false) { + throw new Error( + body?.error || body?.message || `Request failed with status ${response.status}.`, + ); + } + + const userId = body?.data?.id || body?.data?.userId; + if (!userId) { + throw new Error('Current user response did not include a user id.'); + } + + return { userId }; +} + +export async function resolveSocketAuth(options: ResolveSocketAuthOptions): Promise { + const { serverUrl, serviceToken, storedUserId, token, tokenType, verifyApiKey, verifyJwt } = + options; + + if (!token) throw new Error('Missing token'); + + if (tokenType === 'apiKey') { + if (!serverUrl) throw new Error('Missing serverUrl'); + const result = await verifyApiKey(serverUrl, token); + return result.userId; + } + + if (token === serviceToken) { + if (!storedUserId) throw new Error('Missing userId'); + return storedUserId; + } + + const result = await verifyJwt(token); + return result.userId; +} diff --git a/apps/device-gateway/src/types.ts b/apps/device-gateway/src/types.ts index 1a2a23a68f..7fb79821a4 100644 --- a/apps/device-gateway/src/types.ts +++ b/apps/device-gateway/src/types.ts @@ -20,7 +20,9 @@ export interface DeviceAttachment { // Desktop → CF export interface AuthMessage { + serverUrl?: string; token: string; + tokenType?: 'apiKey' | 'jwt' | 'serviceToken'; type: 'auth'; } diff --git a/packages/database/src/models/apiKey.ts b/packages/database/src/models/apiKey.ts index 9dc3222ccd..d0c114c428 100644 --- a/packages/database/src/models/apiKey.ts +++ b/packages/database/src/models/apiKey.ts @@ -1,14 +1,25 @@ +import { generateApiKey, isApiKeyExpired, validateApiKeyFormat } from '@lobechat/utils/apiKey'; +import { hashApiKey } from '@lobechat/utils/server'; import { and, desc, eq } from 'drizzle-orm'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; -import { generateApiKey, isApiKeyExpired, validateApiKeyFormat } from '@/utils/apiKey'; -import { hashApiKey } from '@/utils/server/apiKeyHash'; import type { ApiKeyItem, NewApiKeyItem } from '../schemas'; import { apiKeys } from '../schemas'; import type { LobeChatDatabase } from '../type'; export class ApiKeyModel { + static findByKey = async (db: LobeChatDatabase, key: string) => { + if (!validateApiKeyFormat(key)) { + return null; + } + const keyHash = hashApiKey(key); + + return db.query.apiKeys.findFirst({ + where: eq(apiKeys.keyHash, keyHash), + }); + }; + private userId: string; private db: LobeChatDatabase; private gateKeeperPromise: Promise | null = null; @@ -75,14 +86,7 @@ export class ApiKeyModel { }; findByKey = async (key: string) => { - if (!validateApiKeyFormat(key)) { - return null; - } - const keyHash = hashApiKey(key); - - return this.db.query.apiKeys.findFirst({ - where: eq(apiKeys.keyHash, keyHash), - }); + return ApiKeyModel.findByKey(this.db, key); }; validateKey = async (key: string) => { diff --git a/packages/device-gateway-client/src/client.test.ts b/packages/device-gateway-client/src/client.test.ts index 3bc03d81e6..3613507f7a 100644 --- a/packages/device-gateway-client/src/client.test.ts +++ b/packages/device-gateway-client/src/client.test.ts @@ -50,7 +50,9 @@ describe('GatewayClient', () => { autoReconnect: false, deviceId: 'test-device-id', gatewayUrl: 'https://gateway.test.com', + serverUrl: 'https://app.test.com', token: 'test-token', + tokenType: 'apiKey', userId: 'test-user', }); }); @@ -88,6 +90,16 @@ describe('GatewayClient', () => { expect(client.connectionStatus).toBe('authenticating'); expect(statusChanges).toContain('connecting'); expect(statusChanges).toContain('authenticating'); + + const ws = (client as any).ws; + expect(ws.send).toHaveBeenCalledWith( + JSON.stringify({ + serverUrl: 'https://app.test.com', + token: 'test-token', + tokenType: 'apiKey', + type: 'auth', + }), + ); }); it('should not reconnect if already connected', async () => { diff --git a/packages/device-gateway-client/src/client.ts b/packages/device-gateway-client/src/client.ts index 2aaf4ebb15..19f8c1bad4 100644 --- a/packages/device-gateway-client/src/client.ts +++ b/packages/device-gateway-client/src/client.ts @@ -44,7 +44,9 @@ export interface GatewayClientOptions { deviceId?: string; gatewayUrl?: string; logger?: GatewayClientLogger; + serverUrl?: string; token: string; + tokenType?: 'apiKey' | 'jwt' | 'serviceToken'; userId?: string; } @@ -58,15 +60,19 @@ export class GatewayClient extends EventEmitter { private deviceId: string; private gatewayUrl: string; private token: string; + private tokenType?: 'apiKey' | 'jwt' | 'serviceToken'; private userId?: string; + private serverUrl?: string; private logger: GatewayClientLogger; private autoReconnect: boolean; constructor(options: GatewayClientOptions) { super(); this.token = options.token; + this.tokenType = options.tokenType; this.gatewayUrl = options.gatewayUrl || DEFAULT_GATEWAY_URL; this.deviceId = options.deviceId || randomUUID(); + this.serverUrl = options.serverUrl; this.userId = options.userId; this.logger = options.logger || noopLogger; this.autoReconnect = options.autoReconnect ?? true; @@ -180,7 +186,12 @@ export class GatewayClient extends EventEmitter { this.setStatus('authenticating'); // Send token as first message instead of in URL - this.sendMessage({ type: 'auth', token: this.token }); + this.sendMessage({ + serverUrl: this.serverUrl, + token: this.token, + tokenType: this.tokenType, + type: 'auth', + }); }; private handleMessage = (data: WebSocket.Data) => { diff --git a/packages/device-gateway-client/src/types.ts b/packages/device-gateway-client/src/types.ts index 8dd674da10..3354999514 100644 --- a/packages/device-gateway-client/src/types.ts +++ b/packages/device-gateway-client/src/types.ts @@ -24,7 +24,9 @@ export interface DeviceSystemInfo { // Client → Server export interface AuthMessage { + serverUrl?: string; token: string; + tokenType?: 'apiKey' | 'jwt' | 'serviceToken'; type: 'auth'; } diff --git a/src/libs/trpc/lambda/context.test.ts b/src/libs/trpc/lambda/context.test.ts index 2a8bd754ee..452e457695 100644 --- a/src/libs/trpc/lambda/context.test.ts +++ b/src/libs/trpc/lambda/context.test.ts @@ -1,6 +1,70 @@ -import { describe, expect, it } from 'vitest'; +import { NextRequest } from 'next/server'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { createContextInner } from './context'; +import { ApiKeyModel } from '@/database/models/apiKey'; + +import { createContextInner, createLambdaContext } from './context'; + +const { + mockExtractTraceContext, + mockFindByKey, + mockGetSession, + mockUpdateLastUsed, + mockValidateOIDCJWT, +} = vi.hoisted(() => ({ + mockExtractTraceContext: vi.fn(), + mockFindByKey: vi.fn(), + mockGetSession: vi.fn(), + mockUpdateLastUsed: vi.fn(), + mockValidateOIDCJWT: vi.fn(), +})); + +vi.mock('@/auth', () => ({ + auth: { + api: { + getSession: mockGetSession, + }, + }, +})); + +vi.mock('@/database/core/db-adaptor', () => ({ + getServerDB: vi.fn().mockResolvedValue({}), +})); + +vi.mock('@/database/models/apiKey', () => ({ + ApiKeyModel: Object.assign( + vi.fn().mockImplementation((_db: unknown, userId: string) => ({ + updateLastUsed: userId ? mockUpdateLastUsed : vi.fn(), + })), + { + findByKey: mockFindByKey, + }, + ), +})); + +vi.mock('@/envs/auth', () => ({ + LOBE_CHAT_AUTH_HEADER: 'X-lobe-chat-auth', + LOBE_CHAT_OIDC_AUTH_HEADER: 'Oidc-Auth', + authEnv: { + ENABLE_OIDC: true, + }, +})); + +vi.mock('@/libs/observability/traceparent', () => ({ + extractTraceContext: mockExtractTraceContext, +})); + +vi.mock('@/libs/oidc-provider/jwt', () => ({ + validateOIDCJWT: mockValidateOIDCJWT, +})); + +vi.mock('@/utils/apiKey', async (importOriginal) => { + const actual = await importOriginal>(); + return { + ...actual, + isApiKeyExpired: vi.fn().mockReturnValue(false), + }; +}); describe('createContextInner', () => { it('should create context with default values when no params provided', async () => { @@ -101,3 +165,72 @@ describe('createContextInner', () => { expect(ctx.traceContext).toBe(traceContext); }); }); + +describe('createLambdaContext', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockExtractTraceContext.mockReturnValue(undefined); + mockGetSession.mockResolvedValue({ user: { id: 'session-user' } }); + mockValidateOIDCJWT.mockResolvedValue({ + tokenData: { sub: 'oidc-user' }, + userId: 'oidc-user', + }); + mockUpdateLastUsed.mockResolvedValue(undefined); + }); + + it('should authenticate with API key and skip session fallback', async () => { + const apiKeyRecord = { + accessedAt: new Date(), + createdAt: new Date(), + enabled: true, + expiresAt: null, + id: 'key-1', + key: 'encrypted-key', + keyHash: 'hashed-key', + lastUsedAt: null, + name: 'Test API Key', + updatedAt: new Date(), + userId: 'api-user', + } satisfies NonNullable>>; + + vi.mocked(ApiKeyModel.findByKey).mockResolvedValue(apiKeyRecord); + + const request = new NextRequest('https://example.com/trpc/lambda', { + headers: { + 'X-API-Key': 'sk-lh-aaaaaaaaaaaaaaaa', + }, + }); + + const context = await createLambdaContext(request); + + expect(context.userId).toBe('api-user'); + expect(mockGetSession).not.toHaveBeenCalled(); + expect(mockValidateOIDCJWT).not.toHaveBeenCalled(); + }); + + it('should reject invalid API key without falling back to OIDC or session', async () => { + vi.mocked(ApiKeyModel.findByKey).mockResolvedValue(null); + + const request = new NextRequest('https://example.com/trpc/lambda', { + headers: { + 'Oidc-Auth': 'oidc-token', + 'X-API-Key': 'sk-lh-bbbbbbbbbbbbbbbb', + }, + }); + + const context = await createLambdaContext(request); + + expect(context.userId).toBeNull(); + expect(mockValidateOIDCJWT).not.toHaveBeenCalled(); + expect(mockGetSession).not.toHaveBeenCalled(); + }); + + it('should use session auth when no API key header is present', async () => { + const request = new NextRequest('https://example.com/trpc/lambda'); + + const context = await createLambdaContext(request); + + expect(context.userId).toBe('session-user'); + expect(mockGetSession).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/libs/trpc/lambda/context.ts b/src/libs/trpc/lambda/context.ts index dc6082fa71..f1bbb51e5f 100644 --- a/src/libs/trpc/lambda/context.ts +++ b/src/libs/trpc/lambda/context.ts @@ -5,12 +5,16 @@ import debug from 'debug'; import { type NextRequest } from 'next/server'; import { auth } from '@/auth'; +import { getServerDB } from '@/database/core/db-adaptor'; +import { ApiKeyModel } from '@/database/models/apiKey'; import { authEnv, LOBE_CHAT_AUTH_HEADER, LOBE_CHAT_OIDC_AUTH_HEADER } from '@/envs/auth'; import { extractTraceContext } from '@/libs/observability/traceparent'; import { validateOIDCJWT } from '@/libs/oidc-provider/jwt'; +import { isApiKeyExpired, validateApiKeyFormat } from '@/utils/apiKey'; // Create context logger namespace const log = debug('lobe-trpc:lambda:context'); +const LOBE_CHAT_API_KEY_HEADER = 'X-API-Key'; const extractClientIp = (request: NextRequest): string | undefined => { const forwardedFor = request.headers.get('x-forwarded-for'); @@ -25,6 +29,31 @@ const extractClientIp = (request: NextRequest): string | undefined => { return undefined; }; +const validateApiKeyUserId = async (apiKey: string): Promise => { + if (!validateApiKeyFormat(apiKey)) return null; + + try { + const db = await getServerDB(); + const apiKeyRecord = await ApiKeyModel.findByKey(db, apiKey); + + if (!apiKeyRecord) return null; + if (!apiKeyRecord.enabled) return null; + if (isApiKeyExpired(apiKeyRecord.expiresAt)) return null; + + const userApiKeyModel = new ApiKeyModel(db, apiKeyRecord.userId); + void userApiKeyModel.updateLastUsed(apiKeyRecord.id).catch((error) => { + log('Failed to update API key last used timestamp: %O', error); + console.error('Failed to update API key last used timestamp:', error); + }); + + return apiKeyRecord.userId; + } catch (error) { + log('API key authentication failed: %O', error); + console.error('API key authentication failed, trying other methods:', error); + return null; + } +}; + export interface OIDCAuth { // Other OIDC information that might be needed (optional, as payload contains all info) [key: string]: any; @@ -117,6 +146,31 @@ export const createLambdaContext = async (request: NextRequest): Promise