diff --git a/Dockerfile b/Dockerfile
index 2bb9652e8e..32ef8919c1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -200,6 +200,8 @@ ENV \
     UPSTAGE_API_KEY="" UPSTAGE_MODEL_LIST="" \
     # Wenxin
     WENXIN_ACCESS_KEY="" WENXIN_SECRET_KEY="" WENXIN_MODEL_LIST="" \
+    # xAI
+    XAI_API_KEY="" XAI_MODEL_LIST="" \
     # 01.AI
     ZEROONE_API_KEY="" ZEROONE_MODEL_LIST="" \
     # Zhipu
diff --git a/Dockerfile.database b/Dockerfile.database
index 6db815efaf..4699983c0d 100644
--- a/Dockerfile.database
+++ b/Dockerfile.database
@@ -235,6 +235,8 @@ ENV \
     UPSTAGE_API_KEY="" UPSTAGE_MODEL_LIST="" \
     # Wenxin
     WENXIN_ACCESS_KEY="" WENXIN_SECRET_KEY="" WENXIN_MODEL_LIST="" \
+    # xAI
+    XAI_API_KEY="" XAI_MODEL_LIST="" \
     # 01.AI
     ZEROONE_API_KEY="" ZEROONE_MODEL_LIST="" \
     # Zhipu
diff --git a/src/app/(main)/settings/llm/ProviderList/providers.tsx b/src/app/(main)/settings/llm/ProviderList/providers.tsx
index 314abb774f..d67e56fa12 100644
--- a/src/app/(main)/settings/llm/ProviderList/providers.tsx
+++ b/src/app/(main)/settings/llm/ProviderList/providers.tsx
@@ -23,6 +23,7 @@ import {
   TaichuProviderCard,
   TogetherAIProviderCard,
   UpstageProviderCard,
+  XAIProviderCard,
   ZeroOneProviderCard,
   ZhiPuProviderCard,
 } from '@/config/modelProviders';
@@ -70,6 +71,7 @@ export const useProviderList = (): ProviderItem[] => {
     MistralProviderCard,
     Ai21ProviderCard,
     UpstageProviderCard,
+    XAIProviderCard,
     QwenProviderCard,
     WenxinProvider,
     HunyuanProviderCard,
diff --git a/src/config/llm.ts b/src/config/llm.ts
index 4a8f64ab4c..328f86b28e 100644
--- a/src/config/llm.ts
+++ b/src/config/llm.ts
@@ -153,6 +153,10 @@ export const getLLMConfig = () => {
       SENSENOVA_ACCESS_KEY_ID: z.string().optional(),
       SENSENOVA_ACCESS_KEY_SECRET: z.string().optional(),
       SENSENOVA_MODEL_LIST: z.string().optional(),
+
+      ENABLED_XAI: z.boolean(),
+      XAI_API_KEY: z.string().optional(),
+      XAI_MODEL_LIST: z.string().optional(),
     },
     runtimeEnv: {
       API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE,
@@ -304,6 +308,10 @@ export const getLLMConfig = () => {
       SENSENOVA_ACCESS_KEY_ID: process.env.SENSENOVA_ACCESS_KEY_ID,
       SENSENOVA_ACCESS_KEY_SECRET: process.env.SENSENOVA_ACCESS_KEY_SECRET,
       SENSENOVA_MODEL_LIST: process.env.SENSENOVA_MODEL_LIST,
+
+      ENABLED_XAI: !!process.env.XAI_API_KEY,
+      XAI_API_KEY: process.env.XAI_API_KEY,
+      XAI_MODEL_LIST: process.env.XAI_MODEL_LIST,
     },
   });
 };
diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts
index 3b0c8eb1e5..b535d5c514 100644
--- a/src/config/modelProviders/index.ts
+++ b/src/config/modelProviders/index.ts
@@ -31,6 +31,7 @@ import TaichuProvider from './taichu';
 import TogetherAIProvider from './togetherai';
 import UpstageProvider from './upstage';
 import WenxinProvider from './wenxin';
+import XAIProvider from './xai';
 import ZeroOneProvider from './zeroone';
 import ZhiPuProvider from './zhipu';
@@ -53,6 +54,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
   PerplexityProvider.chatModels,
   AnthropicProvider.chatModels,
   HuggingFaceProvider.chatModels,
+  XAIProvider.chatModels,
   ZeroOneProvider.chatModels,
   StepfunProvider.chatModels,
   NovitaProvider.chatModels,
@@ -88,6 +90,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
   MistralProvider,
   Ai21Provider,
   UpstageProvider,
+  XAIProvider,
   QwenProvider,
   WenxinProvider,
   HunyuanProvider,
@@ -145,5 +148,6 @@ export { default as TaichuProviderCard } from './taichu';
 export { default as TogetherAIProviderCard } from './togetherai';
 export { default as UpstageProviderCard } from './upstage';
 export { default as WenxinProviderCard } from './wenxin';
+export { default as XAIProviderCard } from './xai';
 export { default as ZeroOneProviderCard } from './zeroone';
 export { default as ZhiPuProviderCard } from './zhipu';
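Note: `ENABLED_XAI` is derived from key presence rather than from a standalone flag, so setting `XAI_API_KEY` is what switches the provider on. A minimal sketch of consuming that config (the `console.log` call site is illustrative, not part of this PR):

```ts
import { getLLMConfig } from '@/config/llm';

const { ENABLED_XAI, XAI_MODEL_LIST } = getLLMConfig();

// ENABLED_XAI is true exactly when XAI_API_KEY is a non-empty string,
// mirroring `ENABLED_XAI: !!process.env.XAI_API_KEY` in the runtimeEnv block.
if (ENABLED_XAI) {
  console.log('xAI enabled, model list override:', XAI_MODEL_LIST ?? '(none)');
}
```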
diff --git a/src/config/modelProviders/xai.ts b/src/config/modelProviders/xai.ts
new file mode 100644
index 0000000000..94eca70a05
--- /dev/null
+++ b/src/config/modelProviders/xai.ts
@@ -0,0 +1,29 @@
+import { ModelProviderCard } from '@/types/llm';
+
+// ref: https://x.ai/about
+const XAI: ModelProviderCard = {
+  chatModels: [
+    {
+      description: 'Performance comparable to Grok 2, with greater efficiency, speed, and capability.',
+      displayName: 'Grok Beta',
+      enabled: true,
+      functionCall: true,
+      id: 'grok-beta',
+      pricing: {
+        input: 5,
+        output: 15,
+      },
+      tokens: 131_072,
+    },
+  ],
+  checkModel: 'grok-beta',
+  description:
+    'xAI is a company dedicated to building artificial intelligence to accelerate human scientific discovery. Our mission is to advance our collective understanding of the universe.',
+  id: 'xai',
+  modelList: { showModelFetcher: true },
+  modelsUrl: 'https://docs.x.ai/docs#models',
+  name: 'xAI',
+  url: 'https://console.x.ai',
+};
+
+export default XAI;
diff --git a/src/const/settings/llm.ts b/src/const/settings/llm.ts
index e69e99d79c..4aa0edb245 100644
--- a/src/const/settings/llm.ts
+++ b/src/const/settings/llm.ts
@@ -29,6 +29,7 @@ import {
   TogetherAIProviderCard,
   UpstageProviderCard,
   WenxinProviderCard,
+  XAIProviderCard,
   ZeroOneProviderCard,
   ZhiPuProviderCard,
   filterEnabledModels,
@@ -161,6 +162,10 @@ export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
     enabled: false,
     enabledModels: filterEnabledModels(WenxinProviderCard),
   },
+  xai: {
+    enabled: false,
+    enabledModels: filterEnabledModels(XAIProviderCard),
+  },
   zeroone: {
     enabled: false,
     enabledModels: filterEnabledModels(ZeroOneProviderCard),
diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts
index 5930c5f8dc..7f54ec1abd 100644
--- a/src/libs/agent-runtime/AgentRuntime.ts
+++ b/src/libs/agent-runtime/AgentRuntime.ts
@@ -42,6 +42,7 @@ import {
   TextToSpeechPayload,
 } from './types';
 import { LobeUpstageAI } from './upstage';
+import { LobeXAI } from './xai';
 import { LobeZeroOneAI } from './zeroone';
 import { LobeZhipuAI } from './zhipu';
@@ -156,6 +157,7 @@ class AgentRuntime {
       taichu: Partial<ClientOptions>;
       togetherai: Partial<ClientOptions>;
       upstage: Partial<ClientOptions>;
+      xai: Partial<ClientOptions>;
       zeroone: Partial<ClientOptions>;
       zhipu: Partial<ClientOptions>;
     }>,
@@ -324,6 +326,11 @@
         break;
       }

+      case ModelProvider.XAI: {
+        runtimeModel = new LobeXAI(params.xai);
+        break;
+      }
+
       case ModelProvider.Cloudflare: {
         runtimeModel = new LobeCloudflareAI(params.cloudflare ?? {});
         break;
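Note: only the new `switch` arm is visible in the hunk above. Assuming the class keeps its existing `initializeWithProviderOptions` entry point (not shown in this diff), a server-side caller would reach the new branch roughly like this; the key value is a placeholder:

```ts
import AgentRuntime from '@/libs/agent-runtime/AgentRuntime';
import { ModelProvider } from '@/libs/agent-runtime';

// Selecting 'xai' routes through the new case, which forwards params.xai
// (an OpenAI-compatible ClientOptions subset) to the LobeXAI constructor.
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.XAI, {
  xai: { apiKey: process.env.XAI_API_KEY },
});
```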
diff --git a/src/libs/agent-runtime/types/type.ts b/src/libs/agent-runtime/types/type.ts
index 35e55189b1..5a00cbc3dd 100644
--- a/src/libs/agent-runtime/types/type.ts
+++ b/src/libs/agent-runtime/types/type.ts
@@ -53,6 +53,7 @@ export enum ModelProvider {
   TogetherAI = 'togetherai',
   Upstage = 'upstage',
   Wenxin = 'wenxin',
+  XAI = 'xai',
   ZeroOne = 'zeroone',
   ZhiPu = 'zhipu',
 }
diff --git a/src/libs/agent-runtime/xai/index.test.ts b/src/libs/agent-runtime/xai/index.test.ts
new file mode 100644
index 0000000000..ab54dd391c
--- /dev/null
+++ b/src/libs/agent-runtime/xai/index.test.ts
@@ -0,0 +1,255 @@
+// @vitest-environment node
+import OpenAI from 'openai';
+import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+
+import {
+  ChatStreamCallbacks,
+  LobeOpenAICompatibleRuntime,
+  ModelProvider,
+} from '@/libs/agent-runtime';
+
+import * as debugStreamModule from '../utils/debugStream';
+import { LobeXAI } from './index';
+
+const provider = ModelProvider.XAI;
+const defaultBaseURL = 'https://api.x.ai/v1';
+
+const bizErrorType = 'ProviderBizError';
+const invalidErrorType = 'InvalidProviderAPIKey';
+
+// Mock console.error to avoid polluting test output
+vi.spyOn(console, 'error').mockImplementation(() => {});
+
+let instance: LobeOpenAICompatibleRuntime;
+
+beforeEach(() => {
+  instance = new LobeXAI({ apiKey: 'test' });
+
+  // Use vi.spyOn to mock the chat.completions.create method
+  vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
+    new ReadableStream() as any,
+  );
+});
+
+afterEach(() => {
+  vi.clearAllMocks();
+});
+
+describe('LobeXAI', () => {
+  describe('init', () => {
+    it('should correctly initialize with an API key', async () => {
+      const instance = new LobeXAI({ apiKey: 'test_api_key' });
+      expect(instance).toBeInstanceOf(LobeXAI);
+      expect(instance.baseURL).toEqual(defaultBaseURL);
+    });
+  });
+
+  describe('chat', () => {
+    describe('Error', () => {
+      it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => {
+        // Arrange
+        const apiError = new OpenAI.APIError(
+          400,
+          {
+            status: 400,
+            error: {
+              message: 'Bad Request',
+            },
+          },
+          'Error message',
+          {},
+        );
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'grok-beta',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: {
+              error: { message: 'Bad Request' },
+              status: 400,
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
+        try {
+          new LobeXAI({});
+        } catch (e) {
+          expect(e).toEqual({ errorType: invalidErrorType });
+        }
+      });
+
+      it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: {
+            message: 'api is undefined',
+          },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'grok-beta',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
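+      // The desensitization case below expects a custom baseURL host to be
+      // masked (api.abc.com -> api.***.com) before it surfaces in the thrown error.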
+      it('should return OpenAIBizError with a cause response with a desensitized URL', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: { message: 'api is undefined' },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        instance = new LobeXAI({
+          apiKey: 'test',
+
+          baseURL: 'https://api.abc.com/v1',
+        });
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'grok-beta',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: 'https://api.***.com/v1',
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should throw an InvalidXAIAPIKey error type on 401 status code', async () => {
+        // Mock the API call to simulate a 401 error
+        const error = new Error('Unauthorized') as any;
+        error.status = 401;
+        vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);
+
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'grok-beta',
+            temperature: 0,
+          });
+        } catch (e) {
+          // Expect the chat method to throw an error with InvalidXAIAPIKey
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: new Error('Unauthorized'),
+            errorType: invalidErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should return AgentRuntimeError for non-OpenAI errors', async () => {
+        // Arrange
+        const genericError = new Error('Generic Error');
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'grok-beta',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            errorType: 'AgentRuntimeError',
+            provider,
+            error: {
+              name: genericError.name,
+              cause: genericError.cause,
+              message: genericError.message,
+              stack: genericError.stack,
+            },
+          });
+        }
+      });
+    });
+
+    describe('DEBUG', () => {
+      it('should call debugStream and return StreamingTextResponse when DEBUG_XAI_CHAT_COMPLETION is 1', async () => {
+        // Arrange
+        const mockProdStream = new ReadableStream() as any; // mocked prod stream
+        const mockDebugStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Debug stream content');
+            controller.close();
+          },
+        }) as any;
+        mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method
+
+        // Mock the return value of chat.completions.create, including a mocked tee method
+        (instance['client'].chat.completions.create as Mock).mockResolvedValue({
+          tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
+        });
+
+        // Save the original environment variable value
+        const originalDebugValue = process.env.DEBUG_XAI_CHAT_COMPLETION;
+
+        // Mock the environment variable
+        process.env.DEBUG_XAI_CHAT_COMPLETION = '1';
+        vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
+
+        // Run the chat call; it should invoke debugStream when the flag is set
+        await instance.chat({
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'grok-beta',
+          stream: true,
+          temperature: 0,
+        });
+
+        // Verify that debugStream was called
+        expect(debugStreamModule.debugStream).toHaveBeenCalled();
+
+        // Restore the original environment variable value
+        process.env.DEBUG_XAI_CHAT_COMPLETION = originalDebugValue;
+      });
+    });
+  });
+});
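Note: the DEBUG test above relies on `tee()`-style stream splitting: one branch is drained for logging while the other is returned to the caller. A generic Web Streams sketch of that pattern (not the factory's actual code):

```ts
// One upstream source, split into a production branch and a debug branch.
const upstream = new ReadableStream<string>({
  start(controller) {
    controller.enqueue('chunk');
    controller.close();
  },
});

const [prod, debug] = upstream.tee();
// The debug branch is consumed for logging; `prod` stays untouched for the response.
void debug.pipeTo(new WritableStream({ write: (chunk) => console.debug(chunk) }));
```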
diff --git a/src/libs/agent-runtime/xai/index.ts b/src/libs/agent-runtime/xai/index.ts
new file mode 100644
index 0000000000..ed52caa342
--- /dev/null
+++ b/src/libs/agent-runtime/xai/index.ts
@@ -0,0 +1,10 @@
+import { ModelProvider } from '../types';
+import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
+
+export const LobeXAI = LobeOpenAICompatibleFactory({
+  baseURL: 'https://api.x.ai/v1',
+  debug: {
+    chatCompletion: () => process.env.DEBUG_XAI_CHAT_COMPLETION === '1',
+  },
+  provider: ModelProvider.XAI,
+});
diff --git a/src/server/globalConfig/index.ts b/src/server/globalConfig/index.ts
index 437486ce0a..6fec71178d 100644
--- a/src/server/globalConfig/index.ts
+++ b/src/server/globalConfig/index.ts
@@ -33,6 +33,7 @@ import {
   TogetherAIProviderCard,
   UpstageProviderCard,
   WenxinProviderCard,
+  XAIProviderCard,
   ZeroOneProviderCard,
   ZhiPuProviderCard,
 } from '@/config/modelProviders';
@@ -146,6 +147,9 @@ export const getServerGlobalConfig = () => {
     ENABLED_HUGGINGFACE,
     HUGGINGFACE_MODEL_LIST,
+
+    ENABLED_XAI,
+    XAI_MODEL_LIST,
   } = getLLMConfig();

   const config: GlobalServerConfig = {
@@ -399,6 +403,14 @@ export const getServerGlobalConfig = () => {
         modelString: WENXIN_MODEL_LIST,
       }),
     },
+    xai: {
+      enabled: ENABLED_XAI,
+      enabledModels: extractEnabledModels(XAI_MODEL_LIST),
+      serverModelCards: transformToChatModelCards({
+        defaultChatModels: XAIProviderCard.chatModels,
+        modelString: XAI_MODEL_LIST,
+      }),
+    },
     zeroone: {
       enabled: ENABLED_ZEROONE,
       enabledModels: extractEnabledModels(ZEROONE_MODEL_LIST),
diff --git a/src/server/modules/AgentRuntime/index.ts b/src/server/modules/AgentRuntime/index.ts
index f440a10588..6d2ead94a0 100644
--- a/src/server/modules/AgentRuntime/index.ts
+++ b/src/server/modules/AgentRuntime/index.ts
@@ -286,6 +286,13 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
       const apiKey = sensenovaAccessKeyID + ':' + sensenovaAccessKeySecret;

+      return { apiKey };
+    }
+    case ModelProvider.XAI: {
+      const { XAI_API_KEY } = getLLMConfig();
+
+      const apiKey = apiKeyManager.pick(payload?.apiKey || XAI_API_KEY);
+
       return { apiKey };
     }
   }
diff --git a/src/types/user/settings/keyVaults.ts b/src/types/user/settings/keyVaults.ts
index ad0e5221d5..dec9a9c1ab 100644
--- a/src/types/user/settings/keyVaults.ts
+++ b/src/types/user/settings/keyVaults.ts
@@ -65,6 +65,7 @@ export interface UserKeyVaults {
   togetherai?: OpenAICompatibleKeyVault;
   upstage?: OpenAICompatibleKeyVault;
   wenxin?: WenxinKeyVault;
+  xai?: OpenAICompatibleKeyVault;
   zeroone?: OpenAICompatibleKeyVault;
   zhipu?: OpenAICompatibleKeyVault;
 }
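Note: `xai` reuses `OpenAICompatibleKeyVault`, so client-stored credentials follow the same shape as the other OpenAI-compatible providers. Assuming that vault type carries the usual `apiKey`/`baseURL` pair, a user-scoped entry would look like this (values are placeholders):

```ts
import type { UserKeyVaults } from '@/types/user/settings/keyVaults';

const vaults: UserKeyVaults = {
  xai: {
    apiKey: 'xai-your-key', // placeholder value
    baseURL: 'https://api.x.ai/v1', // optional override; matches the runtime's default
  },
};
```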