feat: Support Cloudflare Workers AI (#3402)

* Delete .nvmrc

* feat: Add Cloudflare as a model provider

This commit adds support for Cloudflare as a model provider. It includes changes to the `ModelProvider` enum, the `UserKeyVaults` interface, the `getServerGlobalConfig` function, the `DEFAULT_LLM_CONFIG` constant, the `getLLMConfig` function, the `AgentRuntime` class, and the `DEFAULT_MODEL_PROVIDER_LIST` constant.

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix icon

* fix

* Create .nvmrc

* Delete src/config/modelProviders/.nvmrc

* CF -> CLOUDFLARE

* revert

* chore: Update agentRuntime.ts and auth.ts to support Cloudflare account ID in payload

* Add provider setting

* fix

* Update cloudflare.ts

* fix

* Update cloudflare.ts

* accountID

* fix

* i18n

* save changes

* commit check

* disable function calling for now

* does not catch errors when fetching models

* ready to add base url

* commit check

* revert change

* revert string boolean check

* fix type error on Vercel.
refer to https://github.com/vercel/next.js/issues/38736#issuecomment-1278917422

* i18n by groq/llama-3.1-8b-instant

* rename env var

* add test

* Revert changes that are not relevant to Cloudflare and result in merge conflicts.

* add test for models()

* move helper code to standalone file

* add test for helper methods

* remove encoder

* Merge main into cf-chat-m

* remove brand

* remove template comment

* add provider card

* Update lobe-icons

* Fix setting layout

* minor modification of model list

---------

Co-authored-by: sxjeru <sxjeru@gmail.com>
This commit is contained in:
BrandonStudio 2024-11-12 01:21:16 +08:00 committed by GitHub
parent b4514cf4f0
commit efb7adf89a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 1672 additions and 3 deletions

View file

@ -112,6 +112,11 @@ OPENAI_API_KEY=sk-xxxxxxxxx
# QWEN_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
### Cloudflare Workers AI ####
# CLOUDFLARE_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
### SiliconCloud AI ####
# SILICONCLOUD_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

View file

@ -51,6 +51,18 @@
"title": "استخدام معلومات المصادقة الخاصة بـ Bedrock المخصصة"
}
},
"cloudflare": {
"apiKey": {
"desc": "يرجى ملء Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "أدخل رقم حساب Cloudflare أو عنوان URL API المخصص",
"placeholder": "رقم حساب Cloudflare / عنوان URL API المخصص",
"title": "رقم حساب Cloudflare / عنوان URL API"
}
},
"github": {
"personalAccessToken": {
"desc": "أدخل رمز الوصول الشخصي الخاص بك على Github، انقر [هنا](https://github.com/settings/tokens) لإنشاء واحد",

View file

@ -51,6 +51,18 @@
"title": "Използване на персонализирана информация за удостоверяване на Bedrock"
}
},
"cloudflare": {
"apiKey": {
"desc": "Моля, въведете Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Въведете ID на Cloudflare или личен API адрес",
"placeholder": "ID на Cloudflare / личен API адрес",
"title": "ID на Cloudflare / API адрес"
}
},
"github": {
"personalAccessToken": {
"desc": "Въведете вашия GitHub PAT, кликнете [тук](https://github.com/settings/tokens), за да създадете",

View file

@ -51,6 +51,18 @@
"title": "Verwenden Sie benutzerdefinierte Bedrock-Authentifizierungsinformationen"
}
},
"cloudflare": {
"apiKey": {
"desc": "Bitte füllen Sie die Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Eingeben Sie die Cloudflare-Kundenkennung oder die benutzerdefinierte API-Adresse",
"placeholder": "Cloudflare-Kundenkennung / benutzerdefinierte API-Adresse",
"title": "Cloudflare-Kundenkennung / API-Adresse"
}
},
"github": {
"personalAccessToken": {
"desc": "Geben Sie Ihr GitHub-PAT ein und klicken Sie [hier](https://github.com/settings/tokens), um eines zu erstellen.",

View file

@ -51,6 +51,18 @@
"title": "Use Custom Bedrock Authentication Information"
}
},
"cloudflare": {
"apiKey": {
"desc": "Please enter Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Enter your Cloudflare account ID or custom API address",
"placeholder": "Cloudflare Account ID / custom API URL",
"title": "Cloudflare Account ID / API Address"
}
},
"github": {
"personalAccessToken": {
"desc": "Enter your GitHub PAT. Click [here](https://github.com/settings/tokens) to create one.",

View file

@ -51,6 +51,18 @@
"title": "Usar información de autenticación de Bedrock personalizada"
}
},
"cloudflare": {
"apiKey": {
"desc": "Por favor complete la Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Ingrese el ID de cuenta de Cloudflare o la dirección URL personalizada de API",
"placeholder": "ID de cuenta de Cloudflare / URL de API personalizada",
"title": "ID de cuenta de Cloudflare / dirección URL de API"
}
},
"github": {
"personalAccessToken": {
"desc": "Introduce tu PAT de Github, haz clic [aquí](https://github.com/settings/tokens) para crear uno",

View file

@ -51,6 +51,18 @@
"title": "Utiliser des informations d'authentification Bedrock personnalisées"
}
},
"cloudflare": {
"apiKey": {
"desc": "Veuillez remplir l'Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Saisir l'ID de compte Cloudflare ou l'adresse API personnalisée",
"placeholder": "ID de compte Cloudflare / URL API personnalisée",
"title": "ID de compte Cloudflare / adresse API"
}
},
"github": {
"personalAccessToken": {
"desc": "Entrez votre PAT GitHub, cliquez [ici](https://github.com/settings/tokens) pour en créer un.",

View file

@ -51,6 +51,18 @@
"title": "Usa le informazioni di autenticazione Bedrock personalizzate"
}
},
"cloudflare": {
"apiKey": {
"desc": "Compila l'Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Inserisci l'ID dell'account Cloudflare o l'indirizzo API personalizzato",
"placeholder": "ID account Cloudflare / URL API personalizzato",
"title": "ID account Cloudflare / indirizzo API"
}
},
"github": {
"personalAccessToken": {
"desc": "Inserisci il tuo PAT di Github, clicca [qui](https://github.com/settings/tokens) per crearne uno",

View file

@ -51,6 +51,18 @@
"title": "使用カスタム Bedrock 認証情報"
}
},
"cloudflare": {
"apiKey": {
"desc": "Cloudflare API Key を入力してください",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Cloudflare アカウント ID またはカスタム API アドレスを入力してください。",
"placeholder": "Cloudflare アカウント ID / カスタム API URL",
"title": "Cloudflare アカウント ID / API アドレス"
}
},
"github": {
"personalAccessToken": {
"desc": "あなたのGithub PATを入力してください。[こちら](https://github.com/settings/tokens)をクリックして作成します",

View file

@ -51,6 +51,18 @@
"title": "사용자 정의 Bedrock 인증 정보 사용"
}
},
"cloudflare": {
"apiKey": {
"desc": "Cloudflare API Key 를 작성해 주세요.",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "클라우드 플레어 계정 ID 또는 사용자 지정 API 주소 입력",
"placeholder": "클라우드 플레어 계정 ID / 사용자 지정 API 주소",
"title": "클라우드 플레어 계정 ID / API 주소"
}
},
"github": {
"personalAccessToken": {
"desc": "당신의 Github PAT를 입력하세요. [여기](https://github.com/settings/tokens)를 클릭하여 생성하세요.",

View file

@ -51,6 +51,18 @@
"title": "Gebruik aangepaste Bedrock-verificatiegegevens"
}
},
"cloudflare": {
"apiKey": {
"desc": "Voer Cloudflare API Key in",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Voer uw Cloudflare-account ID of een custom API-URL in",
"placeholder": "Cloudflare-account ID / custom API-URL",
"title": "Cloudflare-account ID / API-URL"
}
},
"github": {
"personalAccessToken": {
"desc": "Vul je Github PAT in, klik [hier](https://github.com/settings/tokens) om er een te maken",

View file

@ -51,6 +51,18 @@
"title": "Użyj niestandardowych informacji uwierzytelniających Bedrock"
}
},
"cloudflare": {
"apiKey": {
"desc": "Wprowadź klucz Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Wprowadź ID konta Cloudflare lub adres API niestandardowy",
"placeholder": "ID konta Cloudflare / adres API niestandardowy",
"title": "ID konta Cloudflare / adres API"
}
},
"github": {
"personalAccessToken": {
"desc": "Wprowadź swój osobisty token dostępu GitHub (PAT), kliknij [tutaj](https://github.com/settings/tokens), aby go utworzyć",

View file

@ -51,6 +51,18 @@
"title": "Usar informações de autenticação Bedrock personalizadas"
}
},
"cloudflare": {
"apiKey": {
"desc": "Insira o Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Insira o ID da conta do Cloudflare ou o endereço da API personalizado",
"placeholder": "ID da conta do Cloudflare / URL da API personalizada",
"title": "ID da conta do Cloudflare / Endereço da API"
}
},
"github": {
"personalAccessToken": {
"desc": "Insira seu PAT do Github, clique [aqui](https://github.com/settings/tokens) para criar",

View file

@ -51,6 +51,18 @@
"title": "Использовать пользовательскую информацию аутентификации Bedrock"
}
},
"cloudflare": {
"apiKey": {
"desc": "Пожалуйста, заполните Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Введите ID аккаунта Cloudflare или адрес API по умолчанию",
"placeholder": "ID аккаунта Cloudflare / адрес API по умолчанию",
"title": "ID аккаунта Cloudflare / адрес API"
}
},
"github": {
"personalAccessToken": {
"desc": "Введите ваш персональный токен доступа GitHub (PAT), нажмите [здесь](https://github.com/settings/tokens), чтобы создать его",

View file

@ -51,6 +51,18 @@
"title": "Özel Bedrock Kimlik Bilgilerini Kullan"
}
},
"cloudflare": {
"apiKey": {
"desc": "Lütfen doldurun Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Cloudflare hesabınızın ID'sini veya özel API adresinizi girin",
"placeholder": "Cloudflare Hesap ID / Özel API Adresi",
"title": "Cloudflare Hesap ID / API Adresi"
}
},
"github": {
"personalAccessToken": {
"desc": "Github PAT'nizi girin, [buraya](https://github.com/settings/tokens) tıklayarak oluşturun",

View file

@ -51,6 +51,18 @@
"title": "Sử dụng Thông tin Xác thực Bedrock tùy chỉnh"
}
},
"cloudflare": {
"apiKey": {
"desc": "Vui lòng nhập Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "Nhập ID tài khoản Cloudflare hoặc địa chỉ API tùy chỉnh",
"placeholder": "ID tài khoản Cloudflare / địa chỉ API tùy chỉnh",
"title": "ID tài khoản Cloudflare / địa chỉ API"
}
},
"github": {
"personalAccessToken": {
"desc": "Nhập mã truy cập cá nhân Github của bạn, nhấp vào [đây](https://github.com/settings/tokens) để tạo",

View file

@ -51,6 +51,18 @@
"title": "使用自定义 Bedrock 鉴权信息"
}
},
"cloudflare": {
"apiKey": {
"desc": "请填写 Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "填入 Cloudflare 账户 ID 或 自定义 API 地址",
"placeholder": "Cloudflare Account ID / custom API URL",
"title": "Cloudflare 账户 ID / API 地址"
}
},
"github": {
"personalAccessToken": {
"desc": "填入你的 Github PAT点击 [这里](https://github.com/settings/tokens) 创建",

View file

@ -51,6 +51,18 @@
"title": "使用自定義 Bedrock 驗證資訊"
}
},
"cloudflare": {
"apiKey": {
"desc": "請填入 Cloudflare API Key",
"placeholder": "Cloudflare API Key",
"title": "Cloudflare API Key"
},
"baseURLOrAccountID": {
"desc": "填入 Cloudflare 帳戶 ID 或 自定義 API 位址",
"placeholder": "Cloudflare 帳戶 ID / 自定義 API 位址",
"title": "Cloudflare 帳戶 ID / API 位址"
}
},
"github": {
"personalAccessToken": {
"desc": "填入你的 Github 個人存取權杖,點擊[這裡](https://github.com/settings/tokens) 創建",

View file

@ -123,7 +123,7 @@
"@langchain/community": "^0.3.0",
"@lobehub/chat-plugin-sdk": "^1.32.4",
"@lobehub/chat-plugins-gateway": "^1.9.0",
"@lobehub/icons": "^1.37.0",
"@lobehub/icons": "^1.38.1",
"@lobehub/tts": "^1.25.1",
"@lobehub/ui": "^1.152.0",
"@neondatabase/serverless": "^0.10.1",

View file

@ -0,0 +1,43 @@
'use client';
import { Input } from 'antd';
import { useTranslation } from 'react-i18next';
import { CloudflareProviderCard } from '@/config/modelProviders';
import { GlobalLLMProviderKey } from '@/types/user/settings';
import { KeyVaultsConfigKey } from '../../const';
import { ProviderItem } from '../../type';
const providerKey: GlobalLLMProviderKey = 'cloudflare';
export const useCloudflareProvider = (): ProviderItem => {
const { t } = useTranslation('modelProvider');
return {
...CloudflareProviderCard,
apiKeyItems: [
{
children: (
<Input.Password
autoComplete={'new-password'}
placeholder={t(`${providerKey}.apiKey.placeholder`)}
/>
),
desc: t(`${providerKey}.apiKey.desc`),
label: t(`${providerKey}.apiKey.title`),
name: [KeyVaultsConfigKey, providerKey, 'apiKey'],
},
{
children: (
<Input
placeholder={t(`${providerKey}.baseURLOrAccountID.placeholder`)}
/>
),
desc: t(`${providerKey}.baseURLOrAccountID.desc`),
label: t(`${providerKey}.baseURLOrAccountID.title`),
name: [KeyVaultsConfigKey, providerKey, 'baseURLOrAccountID'],
},
],
};
};

View file

@ -30,6 +30,7 @@ import {
import { ProviderItem } from '../type';
import { useAzureProvider } from './Azure';
import { useBedrockProvider } from './Bedrock';
import { useCloudflareProvider } from './Cloudflare';
import { useGithubProvider } from './Github';
import { useHuggingFaceProvider } from './HuggingFace';
import { useOllamaProvider } from './Ollama';
@ -42,6 +43,7 @@ export const useProviderList = (): ProviderItem[] => {
const OllamaProvider = useOllamaProvider();
const OpenAIProvider = useOpenAIProvider();
const BedrockProvider = useBedrockProvider();
const CloudflareProvider = useCloudflareProvider();
const GithubProvider = useGithubProvider();
const HuggingFaceProvider = useHuggingFaceProvider();
const WenxinProvider = useWenxinProvider();
@ -58,6 +60,7 @@ export const useProviderList = (): ProviderItem[] => {
DeepSeekProviderCard,
HuggingFaceProvider,
OpenRouterProviderCard,
CloudflareProvider,
GithubProvider,
NovitaProviderCard,
TogetherAIProviderCard,
@ -87,6 +90,7 @@ export const useProviderList = (): ProviderItem[] => {
OllamaProvider,
OpenAIProvider,
BedrockProvider,
CloudflareProvider,
GithubProvider,
WenxinProvider,
HuggingFaceProvider,

View file

@ -115,6 +115,10 @@ export const getLLMConfig = () => {
TAICHU_API_KEY: z.string().optional(),
TAICHU_MODEL_LIST: z.string().optional(),
ENABLED_CLOUDFLARE: z.boolean(),
CLOUDFLARE_API_KEY: z.string().optional(),
CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID: z.string().optional(),
ENABLED_AI360: z.boolean(),
AI360_API_KEY: z.string().optional(),
AI360_MODEL_LIST: z.string().optional(),
@ -261,6 +265,11 @@ export const getLLMConfig = () => {
TAICHU_API_KEY: process.env.TAICHU_API_KEY,
TAICHU_MODEL_LIST: process.env.TAICHU_MODEL_LIST,
ENABLED_CLOUDFLARE:
!!process.env.CLOUDFLARE_API_KEY && !!process.env.CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID,
CLOUDFLARE_API_KEY: process.env.CLOUDFLARE_API_KEY,
CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID: process.env.CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID,
ENABLED_AI360: !!process.env.AI360_API_KEY,
AI360_API_KEY: process.env.AI360_API_KEY,
AI360_MODEL_LIST: process.env.AI360_MODEL_LIST,

View file

@ -0,0 +1,89 @@
import { ModelProviderCard } from '@/types/llm';
// ref https://developers.cloudflare.com/workers-ai/models/#text-generation
// api https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility
/**
 * Cloudflare Workers AI provider card.
 *
 * Lists the text-generation models exposed through the Workers AI
 * OpenAI-compatible endpoint. `enabled: true` marks a model as on by
 * default; `tokens` is the model's context window where known.
 * Fix: removed a stray trailing tab (`\t`) that leaked into the
 * user-visible description of meta-llama-3-8b-instruct.
 */
const Cloudflare: ModelProviderCard = {
  chatModels: [
    {
      displayName: 'deepseek-coder-6.7b-instruct-awq',
      enabled: true,
      id: '@hf/thebloke/deepseek-coder-6.7b-instruct-awq',
      tokens: 16_384,
    },
    {
      displayName: 'gemma-7b-it',
      enabled: true,
      id: '@hf/google/gemma-7b-it',
      tokens: 2048,
    },
    {
      displayName: 'hermes-2-pro-mistral-7b',
      enabled: true,
      // functionCall: true,
      id: '@hf/nousresearch/hermes-2-pro-mistral-7b',
      tokens: 4096,
    },
    {
      displayName: 'llama-3-8b-instruct-awq',
      id: '@cf/meta/llama-3-8b-instruct-awq',
      tokens: 8192,
    },
    {
      displayName: 'mistral-7b-instruct-v0.2',
      id: '@hf/mistral/mistral-7b-instruct-v0.2',
      tokens: 4096,
    },
    {
      displayName: 'neural-chat-7b-v3-1-awq',
      enabled: true,
      id: '@hf/thebloke/neural-chat-7b-v3-1-awq',
      tokens: 32_768,
    },
    {
      displayName: 'openchat-3.5-0106',
      id: '@cf/openchat/openchat-3.5-0106',
      tokens: 8192,
    },
    {
      displayName: 'openhermes-2.5-mistral-7b-awq',
      enabled: true,
      id: '@hf/thebloke/openhermes-2.5-mistral-7b-awq',
      tokens: 32_768,
    },
    {
      displayName: 'qwen1.5-14b-chat-awq',
      enabled: true,
      id: '@cf/qwen/qwen1.5-14b-chat-awq',
      tokens: 32_768,
    },
    {
      displayName: 'starling-lm-7b-beta',
      enabled: true,
      id: '@hf/nexusflow/starling-lm-7b-beta',
      tokens: 4096,
    },
    {
      displayName: 'zephyr-7b-beta-awq',
      enabled: true,
      id: '@hf/thebloke/zephyr-7b-beta-awq',
      tokens: 32_768,
    },
    {
      description:
        'Generation over generation, Meta Llama 3 demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning.',
      displayName: 'meta-llama-3-8b-instruct',
      enabled: true,
      functionCall: false,
      id: '@hf/meta-llama/meta-llama-3-8b-instruct',
    },
  ],
  // Model used for connectivity checks in the settings UI.
  checkModel: '@hf/meta-llama/meta-llama-3-8b-instruct',
  id: 'cloudflare',
  modelList: {
    showModelFetcher: true,
  },
  name: 'Cloudflare Workers AI',
  url: 'https://developers.cloudflare.com/workers-ai/models',
};
export default Cloudflare;

View file

@ -6,6 +6,7 @@ import AnthropicProvider from './anthropic';
import AzureProvider from './azure';
import BaichuanProvider from './baichuan';
import BedrockProvider from './bedrock';
import CloudflareProvider from './cloudflare';
import DeepSeekProvider from './deepseek';
import FireworksAIProvider from './fireworksai';
import GithubProvider from './github';
@ -57,6 +58,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
NovitaProvider.chatModels,
BaichuanProvider.chatModels,
TaichuProvider.chatModels,
CloudflareProvider.chatModels,
Ai360Provider.chatModels,
SiliconCloudProvider.chatModels,
UpstageProvider.chatModels,
@ -99,6 +101,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
MinimaxProvider,
Ai360Provider,
TaichuProvider,
CloudflareProvider,
SiliconCloudProvider,
];
@ -117,6 +120,7 @@ export { default as AnthropicProviderCard } from './anthropic';
export { default as AzureProviderCard } from './azure';
export { default as BaichuanProviderCard } from './baichuan';
export { default as BedrockProviderCard } from './bedrock';
export { default as CloudflareProviderCard } from './cloudflare';
export { default as DeepSeekProviderCard } from './deepseek';
export { default as FireworksAIProviderCard } from './fireworksai';
export { default as GithubProviderCard } from './github';

View file

@ -37,6 +37,8 @@ export interface JWTPayload {
awsSecretAccessKey?: string;
awsSessionToken?: string;
cloudflareBaseURLOrAccountID?: string;
wenxinAccessKey?: string;
wenxinSecretKey?: string;

View file

@ -4,6 +4,7 @@ import {
AnthropicProviderCard,
BaichuanProviderCard,
BedrockProviderCard,
CloudflareProviderCard,
DeepSeekProviderCard,
FireworksAIProviderCard,
GithubProviderCard,
@ -59,6 +60,10 @@ export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
enabled: false,
enabledModels: filterEnabledModels(BedrockProviderCard),
},
cloudflare: {
enabled: false,
enabledModels: filterEnabledModels(CloudflareProviderCard),
},
deepseek: {
enabled: false,
enabledModels: filterEnabledModels(DeepSeekProviderCard),

View file

@ -9,6 +9,7 @@ import { LobeAnthropicAI } from './anthropic';
import { LobeAzureOpenAI } from './azureOpenai';
import { LobeBaichuanAI } from './baichuan';
import { LobeBedrockAI, LobeBedrockAIParams } from './bedrock';
import { LobeCloudflareAI, LobeCloudflareParams } from './cloudflare';
import { LobeDeepSeekAI } from './deepseek';
import { LobeFireworksAI } from './fireworksai';
import { LobeGithubAI } from './github';
@ -131,6 +132,7 @@ class AgentRuntime {
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
baichuan: Partial<ClientOptions>;
bedrock: Partial<LobeBedrockAIParams>;
cloudflare: Partial<LobeCloudflareParams>;
deepseek: Partial<ClientOptions>;
fireworksai: Partial<ClientOptions>;
github: Partial<ClientOptions>;
@ -321,8 +323,12 @@ class AgentRuntime {
runtimeModel = await LobeSenseNovaAI.fromAPIKey(params.sensenova);
break;
}
}
case ModelProvider.Cloudflare: {
runtimeModel = new LobeCloudflareAI(params.cloudflare ?? {});
break;
}
}
return new AgentRuntime(runtimeModel);
}
}

View file

@ -0,0 +1,648 @@
// @vitest-environment node
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { ChatCompletionTool } from '@/libs/agent-runtime';
import * as debugStreamModule from '../utils/debugStream';
import { LobeCloudflareAI } from './index';
const provider = 'cloudflare';
const bizErrorType = 'ProviderBizError';
const invalidErrorType = 'InvalidProviderAPIKey';
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});
let instance: LobeCloudflareAI;
const textEncoder = new TextEncoder();
afterEach(() => {
vi.restoreAllMocks();
});
describe('LobeCloudflareAI', () => {
const accountID = '80009000a000b000c000d000e000f000';
describe('init', () => {
it('should correctly initialize with API key and Account ID', async () => {
const instance = new LobeCloudflareAI({
apiKey: 'test_api_key',
baseURLOrAccountID: accountID,
});
expect(instance).toBeInstanceOf(LobeCloudflareAI);
expect(instance.baseURL).toBe(
`https://api.cloudflare.com/client/v4/accounts/${accountID}/ai/run/`,
);
expect(instance.accountID).toBe(accountID);
});
it('should correctly initialize with API key and Gateway URL', async () => {
const baseURL = `https://gateway.ai.cloudflare.com/v1/${accountID}/test-gateway/workers-ai`;
const instance = new LobeCloudflareAI({
apiKey: 'test_api_key',
baseURLOrAccountID: baseURL,
});
expect(instance).toBeInstanceOf(LobeCloudflareAI);
expect(instance.baseURL).toBe(baseURL + '/'); // baseURL MUST end with '/'.
expect(instance.accountID).toBe(accountID);
});
});
describe('chat', () => {
beforeEach(() => {
instance = new LobeCloudflareAI({
apiKey: 'test_api_key',
baseURLOrAccountID: accountID,
});
// Mock fetch
vi.spyOn(globalThis, 'fetch').mockResolvedValue(
new Response(
new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(textEncoder.encode('data: {"response": "Hello, world!"}\n\n'));
controller.close();
},
}),
),
);
});
it('should return a Response on successful API call', async () => {
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
});
// Assert
expect(result).toBeInstanceOf(Response);
});
it('should handle text messages correctly', async () => {
// Arrange
const textEncoder = new TextEncoder();
const mockResponse = new Response(
new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(textEncoder.encode('data: {"response": "Hello, world!"}\n\n'));
controller.close();
},
}),
);
(globalThis.fetch as Mock).mockResolvedValue(mockResponse);
// Act
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
top_p: 1,
});
// Assert
expect(globalThis.fetch).toHaveBeenCalledWith(
// url
expect.objectContaining({
pathname: `/client/v4/accounts/${accountID}/ai/run/@hf/meta-llama/meta-llama-3-8b-instruct`,
}),
// body
expect.objectContaining({
body: expect.any(String),
method: 'POST',
}),
);
const fetchCallArgs = (globalThis.fetch as Mock).mock.calls[0];
const body = JSON.parse(fetchCallArgs[1].body);
expect(body).toEqual(
expect.objectContaining({
//max_tokens: 4096,
messages: [{ content: 'Hello', role: 'user' }],
//stream: true,
temperature: 0,
top_p: 1,
}),
);
expect(result).toBeInstanceOf(Response);
});
it('should handle system prompt correctly', async () => {
// Arrange
const textEncoder = new TextEncoder();
const mockResponse = new Response(
new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(textEncoder.encode('data: {"response": "Hello, world!"}\n\n'));
controller.close();
},
}),
);
(globalThis.fetch as Mock).mockResolvedValue(mockResponse);
// Act
const result = await instance.chat({
messages: [
{ content: 'You are an awesome greeter', role: 'system' },
{ content: 'Hello', role: 'user' },
],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
});
// Assert
expect(globalThis.fetch).toHaveBeenCalledWith(
// url
expect.objectContaining({
pathname: `/client/v4/accounts/${accountID}/ai/run/@hf/meta-llama/meta-llama-3-8b-instruct`,
}),
// body
expect.objectContaining({
body: expect.any(String),
method: 'POST',
}),
);
const fetchCallArgs = (globalThis.fetch as Mock).mock.calls[0];
const body = JSON.parse(fetchCallArgs[1].body);
expect(body).toEqual(
expect.objectContaining({
//max_tokens: 4096,
messages: [
{ content: 'You are an awesome greeter', role: 'system' },
{ content: 'Hello', role: 'user' },
],
//stream: true,
temperature: 0,
}),
);
expect(result).toBeInstanceOf(Response);
});
it('should call Cloudflare API with supported opions', async () => {
// Arrange
const mockResponse = new Response(
new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(textEncoder.encode('data: {"response": "Hello, world!"}\n\n'));
controller.close();
},
}),
);
(globalThis.fetch as Mock).mockResolvedValue(mockResponse);
// Act
const result = await instance.chat({
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0.5,
top_p: 1,
});
// Assert
expect(globalThis.fetch).toHaveBeenCalledWith(
// url
expect.objectContaining({
pathname: `/client/v4/accounts/${accountID}/ai/run/@hf/meta-llama/meta-llama-3-8b-instruct`,
}),
// body
expect.objectContaining({
body: expect.any(String),
method: 'POST',
}),
);
const fetchCallArgs = (globalThis.fetch as Mock).mock.calls[0];
const body = JSON.parse(fetchCallArgs[1].body);
expect(body).toEqual(
expect.objectContaining({
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
//stream: true,
temperature: 0.5,
top_p: 1,
}),
);
expect(result).toBeInstanceOf(Response);
});
it('should call debugStream in DEBUG mode', async () => {
// Arrange
const mockProdStream = new ReadableStream({
start(controller) {
controller.enqueue('Hello, world!');
controller.close();
},
}) as any;
const mockDebugStream = new ReadableStream({
start(controller) {
controller.enqueue('Debug stream content');
controller.close();
},
}) as any;
mockDebugStream.toReadableStream = () => mockDebugStream;
(globalThis.fetch as Mock).mockResolvedValue({
body: {
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
},
});
const originalDebugValue = process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION;
process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION = '1';
vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
// Act
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
});
// Assert
expect(debugStreamModule.debugStream).toHaveBeenCalled();
// Cleanup
process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION = originalDebugValue;
});
describe('chat with tools', () => {
it('should call client.beta.tools.messages.create when tools are provided', async () => {
// Arrange
const tools: ChatCompletionTool[] = [
{ function: { name: 'tool1', description: 'desc1' }, type: 'function' },
];
// Act
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 1,
tools,
});
// Assert
expect(globalThis.fetch).toHaveBeenCalled();
const fetchCallArgs = (globalThis.fetch as Mock).mock.calls[0];
const body = JSON.parse(fetchCallArgs[1].body);
expect(body).toEqual(
expect.objectContaining({
tools: tools.map((t) => t.function),
}),
);
});
});
describe('Error', () => {
it('should throw ProviderBizError error on 400 error', async () => {
// Arrange
const apiError = {
status: 400,
error: {
type: 'error',
error: {
type: 'authentication_error',
message: 'invalid x-api-key',
},
},
};
(globalThis.fetch as Mock).mockRejectedValue(apiError);
try {
// Act
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
});
} catch (e) {
// Assert
expect(e).toEqual({
endpoint: expect.stringMatching(/https:\/\/.+/),
error: apiError,
errorType: bizErrorType,
provider,
});
}
});
it('should throw InvalidProviderAPIKey if no accountID is provided', async () => {
try {
new LobeCloudflareAI({
apiKey: 'test',
});
} catch (e) {
expect(e).toEqual({ errorType: invalidErrorType });
}
});
it('should throw InvalidProviderAPIKey if no apiKey is provided', async () => {
try {
new LobeCloudflareAI({
baseURLOrAccountID: accountID,
});
} catch (e) {
expect(e).toEqual({ errorType: invalidErrorType });
}
});
it('should not throw Error when apiKey is not provided but baseURL is provided', async () => {
const customInstance = new LobeCloudflareAI({
baseURLOrAccountID: 'https://custom.cloudflare.url/',
});
expect(customInstance).toBeInstanceOf(LobeCloudflareAI);
expect(customInstance.apiKey).toBeUndefined();
expect(customInstance.baseURL).toBe('https://custom.cloudflare.url/');
});
});
describe('Error handling', () => {
it('should throw ProviderBizError on other error status codes', async () => {
// Arrange
const apiError = { status: 400 };
(globalThis.fetch as Mock).mockRejectedValue(apiError);
// Act & Assert
await expect(
instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 1,
}),
).rejects.toEqual({
endpoint: expect.stringMatching(/https:\/\/.+/),
error: apiError,
errorType: bizErrorType,
provider,
});
});
it('should desensitize accountID in error message', async () => {
// Arrange
const apiError = { status: 400 };
const customInstance = new LobeCloudflareAI({
apiKey: 'test',
baseURLOrAccountID: accountID,
});
(globalThis.fetch as Mock).mockRejectedValue(apiError);
// Act & Assert
await expect(
customInstance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
}),
).rejects.toEqual({
endpoint: expect.not.stringContaining(accountID),
error: apiError,
errorType: bizErrorType,
provider,
});
});
});
describe('Options', () => {
it('should pass signal to API call', async () => {
// Arrange
const controller = new AbortController();
// Act
await instance.chat(
{
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 1,
},
{ signal: controller.signal },
);
// Assert
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(URL),
expect.objectContaining({ signal: controller.signal }),
);
});
it('should apply callback to the returned stream', async () => {
// Arrange
const callback = vi.fn();
// Act
await instance.chat(
{
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 0,
},
{
callback: { onStart: callback },
},
);
// Assert
expect(callback).toHaveBeenCalled();
});
it('should set headers on the response', async () => {
// Arrange
const headers = { 'X-Test-Header': 'test' };
// Act
const result = await instance.chat(
{
messages: [{ content: 'Hello', role: 'user' }],
model: '@hf/meta-llama/meta-llama-3-8b-instruct',
temperature: 1,
},
{ headers },
);
// Assert
expect(result.headers.get('X-Test-Header')).toBe('test');
});
});
describe('Edge cases', () => {
  it('should handle empty messages array', async () => {
    // An empty conversation must still produce a streaming Response, not a rejection.
    const emptyPayload = {
      messages: [],
      model: '@hf/meta-llama/meta-llama-3-8b-instruct',
      temperature: 1,
    };

    await expect(instance.chat(emptyPayload)).resolves.toBeInstanceOf(Response);
  });
});
});
// Covers LobeCloudflareAI.models(): endpoint/auth of the request and the
// filtering/mapping of Cloudflare's model-search manifests into model cards.
describe('models', () => {
  it('should send request', async () => {
    // Arrange
    const apiKey = 'test_api_key';
    const instance = new LobeCloudflareAI({ apiKey, baseURLOrAccountID: accountID });
    vi.spyOn(globalThis, 'fetch').mockResolvedValue(
      new Response(
        JSON.stringify({
          result: [
            {
              description: 'Model 1',
              name: 'model1',
              task: { name: 'Text Generation' },
              properties: [{ property_id: 'beta', value: 'false' }],
            },
            {
              description: 'Model 2',
              name: 'model2',
              task: { name: 'Text Generation' },
              properties: [{ property_id: 'beta', value: 'true' }],
            },
          ],
        }),
      ),
    );
    // Act
    const result = await instance.models();
    // Assert: the account-scoped search endpoint is hit with bearer auth.
    expect(globalThis.fetch).toHaveBeenCalledWith(
      `https://api.cloudflare.com/client/v4/accounts/${accountID}/ai/models/search`,
      {
        headers: {
          'Authorization': `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
        method: 'GET',
      },
    );
    expect(result).toHaveLength(2);
  });
  it('should set id to name', async () => {
    // Arrange: manifests carry both `id` and `name`; the card must use `name`.
    const instance = new LobeCloudflareAI({
      apiKey: 'test_api_key',
      baseURLOrAccountID: accountID,
    });
    vi.spyOn(globalThis, 'fetch').mockResolvedValue(
      new Response(
        JSON.stringify({
          result: [
            {
              id: 'id1',
              name: 'name1',
              task: { name: 'Text Generation' },
            },
          ],
        }),
      ),
    );
    // Act
    const result = await instance.models();
    // Assert
    expect(result).toEqual([
      expect.objectContaining({
        displayName: 'name1',
        id: 'name1',
      }),
    ]);
  });
  it('should filter text generation models', async () => {
    // Arrange: non-"Text Generation" tasks must be dropped from the result.
    const instance = new LobeCloudflareAI({
      apiKey: 'test_api_key',
      baseURLOrAccountID: accountID,
    });
    vi.spyOn(globalThis, 'fetch').mockResolvedValue(
      new Response(
        JSON.stringify({
          result: [
            {
              id: '1',
              name: 'model1',
              task: { name: 'Text Generation' },
            },
            {
              id: '2',
              name: 'model2',
              task: { name: 'Text Classification' },
            },
          ],
        }),
      ),
    );
    // Act
    const result = await instance.models();
    // Assert
    expect(result).toEqual([
      expect.objectContaining({
        displayName: 'model1',
        id: 'model1',
      }),
    ]);
  });
  it('should enable non-beta models and mark beta models', async () => {
    // Arrange: beta flag (a string 'true'/'false') toggles `enabled` and the label.
    const instance = new LobeCloudflareAI({
      apiKey: 'test_api_key',
      baseURLOrAccountID: accountID,
    });
    vi.spyOn(globalThis, 'fetch').mockResolvedValue(
      new Response(
        JSON.stringify({
          result: [
            {
              id: '1',
              name: 'model1',
              task: { name: 'Text Generation' },
              properties: [{ property_id: 'beta', value: 'false' }],
            },
            {
              id: '2',
              name: 'model2',
              task: { name: 'Text Generation' },
              properties: [{ property_id: 'beta', value: 'true' }],
            },
          ],
        }),
      ),
    );
    // Act
    const result = await instance.models();
    // Assert: beta models get a " (Beta)" suffix and start disabled.
    expect(result).toEqual([
      expect.objectContaining({
        displayName: 'model1',
        enabled: true,
        id: 'model1',
      }),
      expect.objectContaining({
        displayName: 'model2 (Beta)',
        enabled: false,
        id: 'model2',
      }),
    ]);
  });
});
});

View file

@ -0,0 +1,123 @@
import { ChatModelCard } from '@/types/llm';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
CloudflareStreamTransformer,
DEFAULT_BASE_URL_PREFIX,
convertModelManifest,
desensitizeCloudflareUrl,
fillUrl,
} from '../utils/cloudflareHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import { createCallbacksTransformer } from '../utils/streams';
/**
 * Construction options for {@link LobeCloudflareAI}.
 */
export interface LobeCloudflareParams {
  /** Cloudflare API token; optional when a custom endpoint handles auth itself. */
  apiKey?: string;
  /**
   * Either a full endpoint URL (detected by an `http` prefix) or a bare
   * 32-hex-character Cloudflare account ID.
   */
  baseURLOrAccountID?: string;
}
/**
 * Runtime adapter for Cloudflare Workers AI.
 *
 * Accepts either a bare account ID (official endpoint is derived via `fillUrl`,
 * API key required) or a full custom endpoint URL.
 */
export class LobeCloudflareAI implements LobeRuntimeAI {
  baseURL: string;
  accountID: string;
  apiKey?: string;

  /**
   * @throws InvalidProviderAPIKey when `baseURLOrAccountID` is missing, or when a
   *   bare account ID is given without an `apiKey`.
   */
  constructor({ apiKey, baseURLOrAccountID }: LobeCloudflareParams) {
    if (!baseURLOrAccountID) {
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
    }
    if (baseURLOrAccountID.startsWith('http')) {
      // Normalize to a trailing slash so `new URL(model, baseURL)` appends the model path.
      this.baseURL = baseURLOrAccountID.endsWith('/')
        ? baseURLOrAccountID
        : baseURLOrAccountID + '/';
      // Best effort: recover a 32-hex-char account ID from the custom URL so that
      // models() can still target the account-scoped search endpoint.
      // NOTE(review): when no such segment exists, replaceAll leaves the entire URL in
      // accountID — confirm models() callers tolerate that.
      this.accountID = baseURLOrAccountID.replaceAll(/^.*\/([\dA-Fa-f]{32})\/.*$/g, '$1');
    } else {
      // Bare account ID: the official endpoint always requires an API key.
      if (!apiKey) {
        throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
      }
      this.accountID = baseURLOrAccountID;
      this.baseURL = fillUrl(baseURLOrAccountID);
    }
    this.apiKey = apiKey;
  }

  /**
   * Streams a chat completion from Workers AI as an SSE Response.
   *
   * @param payload - Chat payload; `tools` are unwrapped to bare function manifests.
   * @param options - Optional signal, callbacks, and extra request headers.
   * @throws ProviderBizError with a desensitized endpoint on any failure.
   */
  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions): Promise<Response> {
    try {
      const { model, tools, ...restPayload } = payload;
      // Cloudflare expects bare function manifests, not OpenAI-style tool wrappers.
      const functions = tools?.map((tool) => tool.function);
      // FIX: copy the caller's headers instead of mutating them. Previously the
      // Authorization header was injected into `options.headers` itself and later
      // reflected onto the client-visible Response headers below — leaking the API key.
      const headers: Record<string, string> = { ...options?.headers };
      if (this.apiKey) {
        headers['Authorization'] = `Bearer ${this.apiKey}`;
      }
      const url = new URL(model, this.baseURL);
      const response = await fetch(url, {
        body: JSON.stringify({ tools: functions, ...restPayload }),
        headers: { 'Content-Type': 'application/json', ...headers },
        method: 'POST',
        signal: options?.signal,
      });
      const desensitizedEndpoint = desensitizeCloudflareUrl(url.toString());
      switch (response.status) {
        case 400: {
          throw AgentRuntimeError.chat({
            endpoint: desensitizedEndpoint,
            error: response,
            errorType: AgentRuntimeErrorType.ProviderBizError,
            provider: ModelProvider.Cloudflare,
          });
        }
      }
      // Only tee when debugging: tee-ing doubles buffering, so keep the hot path lean.
      let responseBody: ReadableStream;
      if (process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION === '1') {
        const [prod, useForDebug] = response.body!.tee();
        // FIX: `.catch()` with no handler does not handle the rejection — log it instead.
        debugStream(useForDebug).catch(console.error);
        responseBody = prod;
      } else {
        responseBody = response.body!;
      }
      return StreamingResponse(
        responseBody
          .pipeThrough(new TransformStream(new CloudflareStreamTransformer()))
          .pipeThrough(createCallbacksTransformer(options?.callback)),
        { headers: options?.headers },
      );
    } catch (error) {
      // NOTE(review): an AgentRuntimeError thrown by the 400 branch above is re-wrapped
      // here (payload nested inside `error`) with the coarser baseURL endpoint — confirm
      // this double-wrapping is intended before changing it.
      const desensitizedEndpoint = desensitizeCloudflareUrl(this.baseURL);
      throw AgentRuntimeError.chat({
        endpoint: desensitizedEndpoint,
        error: error as any,
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Cloudflare,
      });
    }
  }

  /**
   * Lists the account's "Text Generation" models, sorted by display name.
   * Fetch/parse errors deliberately propagate to the caller (no try/catch here).
   */
  async models(): Promise<ChatModelCard[]> {
    const url = `${DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/${this.accountID}/ai/models/search`;
    const response = await fetch(url, {
      headers: {
        'Authorization': `Bearer ${this.apiKey}`,
        'Content-Type': 'application/json',
      },
      method: 'GET',
    });
    const j = await response.json();
    // Only text-generation tasks are usable as chat models.
    const models: any[] = j['result'].filter(
      (model: any) => model['task']['name'] === 'Text Generation',
    );
    const chatModels: ChatModelCard[] = models
      .map((model) => convertModelManifest(model))
      .sort((a, b) => a.displayName.localeCompare(b.displayName));
    return chatModels;
  }
}

View file

@ -28,6 +28,7 @@ export enum ModelProvider {
Azure = 'azure',
Baichuan = 'baichuan',
Bedrock = 'bedrock',
Cloudflare = 'cloudflare',
DeepSeek = 'deepseek',
FireworksAI = 'fireworksai',
Github = 'github',

View file

@ -0,0 +1,339 @@
// @vitest-environment node
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import * as desensitizeTool from '../utils/desensitizeUrl';
import {
CloudflareStreamTransformer,
desensitizeCloudflareUrl,
fillUrl,
getModelBeta,
getModelDisplayName,
getModelFunctionCalling,
getModelTokens,
} from './cloudflareHelpers';
//const {
// getModelBeta,
// getModelDisplayName,
// getModelFunctionCalling,
// getModelTokens,
//} = require('./cloudflareHelpers');
//const cloudflareHelpers = require('./cloudflareHelpers');
//const getModelBeta = cloudflareHelpers.__get__('getModelBeta');
//const getModelDisplayName = cloudflareHelpers.__get__('getModelDisplayName');
//const getModelFunctionCalling = cloudflareHelpers.__get__('getModelFunctionCalling');
//const getModelTokens = cloudflareHelpers.__get__('getModelTokens');
// Restore all spies/mocks between tests so fetch/desensitizeUrl stubs do not leak.
afterEach(() => {
  vi.restoreAllMocks();
});
describe('cloudflareHelpers', () => {
// Covers SSE chunk parsing and the cross-chunk buffering logic of the transformer.
// FIX: removed three unused `const textDecoder = new TextDecoder();` locals.
describe('CloudflareStreamTransformer', () => {
  let transformer: CloudflareStreamTransformer;

  beforeEach(() => {
    transformer = new CloudflareStreamTransformer();
  });

  describe('parseChunk', () => {
    let chunks: string[];
    let controller: TransformStreamDefaultController;

    beforeEach(() => {
      chunks = [];
      // A bare prototype instance suffices: only `enqueue` is used, and it is stubbed.
      controller = Object.create(TransformStreamDefaultController.prototype);
      vi.spyOn(controller, 'enqueue').mockImplementation((chunk) => {
        chunks.push(chunk);
      });
    });

    it('should parse chunk', () => {
      // Arrange
      const chunk = 'data: {"key": "value", "response": "response1"}';
      // Act
      transformer['parseChunk'](chunk, controller);
      // Assert: one `event:` line plus one JSON-encoded `data:` line are emitted.
      expect(chunks.length).toBe(2);
      expect(chunks[0]).toBe('event: text\n');
      expect(chunks[1]).toBe('data: "response1"\n\n');
    });

    it('should not replace `data` in text', () => {
      // Arrange: "data: " inside the payload text must survive untouched.
      const chunk = 'data: {"key": "value", "response": "data: a"}';
      // Act
      transformer['parseChunk'](chunk, controller);
      // Assert
      expect(chunks.length).toBe(2);
      expect(chunks[0]).toBe('event: text\n');
      expect(chunks[1]).toBe('data: "data: a"\n\n');
    });
  });

  describe('transform', () => {
    const textEncoder = new TextEncoder();
    let chunks: string[];

    beforeEach(() => {
      chunks = [];
      // Stub parseChunk so these tests verify only the splitting/buffering behavior.
      vi.spyOn(
        transformer as any as {
          parseChunk: (chunk: string, controller: TransformStreamDefaultController) => void;
        },
        'parseChunk',
      ).mockImplementation((chunk: string, _) => {
        chunks.push(chunk);
      });
    });

    it('should split single chunk', async () => {
      // Arrange
      const chunk = textEncoder.encode('data: {"key": "value", "response": "response1"}\n\n');
      // Act
      await transformer.transform(chunk, undefined!);
      // Assert
      expect(chunks.length).toBe(1);
      expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}');
    });

    it('should split multiple chunks', async () => {
      // Arrange
      const chunk = textEncoder.encode(
        'data: {"key": "value", "response": "response1"}\n\n' +
          'data: {"key": "value", "response": "response2"}\n\n',
      );
      // Act
      await transformer.transform(chunk, undefined!);
      // Assert
      expect(chunks.length).toBe(2);
      expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}');
      expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}');
    });

    it('should ignore empty chunk', async () => {
      // Arrange
      const chunk = textEncoder.encode('\n\n');
      // Act
      await transformer.transform(chunk, undefined!);
      // Assert
      expect(chunks.join()).toBe('');
    });

    it('should split and concat delayed chunks', async () => {
      // Arrange: events split across three network chunks must be re-assembled.
      const chunk1 = textEncoder.encode('data: {"key": "value", "respo');
      const chunk2 = textEncoder.encode('nse": "response1"}\n\ndata: {"key": "val');
      const chunk3 = textEncoder.encode('ue", "response": "response2"}\n\n');
      // Act & Assert
      await transformer.transform(chunk1, undefined!);
      expect(transformer['parseChunk']).not.toHaveBeenCalled();
      expect(chunks.length).toBe(0);
      expect(transformer['buffer']).toBe('data: {"key": "value", "respo');
      await transformer.transform(chunk2, undefined!);
      expect(chunks.length).toBe(1);
      expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}');
      expect(transformer['buffer']).toBe('data: {"key": "val');
      await transformer.transform(chunk3, undefined!);
      expect(chunks.length).toBe(2);
      expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}');
      expect(transformer['buffer']).toBe('');
    });

    it('should ignore standalone [DONE]', async () => {
      // Arrange
      const chunk = textEncoder.encode('data: [DONE]\n\n');
      // Act
      await transformer.transform(chunk, undefined!);
      // Assert: the terminator is swallowed without being parsed or buffered.
      expect(transformer['parseChunk']).not.toHaveBeenCalled();
      expect(chunks.length).toBe(0);
      expect(transformer['buffer']).toBe('');
    });

    it('should ignore [DONE] in chunk', async () => {
      // Arrange
      const chunk = textEncoder.encode(
        'data: {"key": "value", "response": "response1"}\n\ndata: [DONE]\n\n',
      );
      // Act
      await transformer.transform(chunk, undefined!);
      // Assert: events before the terminator are parsed; the terminator is dropped.
      expect(chunks.length).toBe(1);
      expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}');
      expect(transformer['buffer']).toBe('');
    });
  });
});
describe('fillUrl', () => {
  it('should return URL with account id', () => {
    // The official run endpoint embeds the account ID and keeps a trailing slash.
    const result = fillUrl('80009000a000b000c000d000e000f000');
    expect(result).toBe(
      'https://api.cloudflare.com/client/v4/accounts/80009000a000b000c000d000e000f000/ai/run/',
    );
  });
});
// Verifies URL desensitization used when surfacing endpoints in error payloads.
describe('maskAccountId', () => {
  // NOTE(review): this inner suite is titled `desensitizeAccountId` but exercises
  // `desensitizeCloudflareUrl`, and its two cases duplicate the suite below —
  // consider renaming or deduplicating.
  describe('desensitizeAccountId', () => {
    it('should replace account id with **** in official API endpoint', () => {
      const url =
        'https://api.cloudflare.com/client/v4/accounts/80009000a000b000c000d000e000f000/ai/run/';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe('https://api.cloudflare.com/client/v4/accounts/****/ai/run/');
    });
    it('should replace account id with **** in custom API endpoint', () => {
      const url =
        'https://api.cloudflare.com/custom/prefix/80009000a000b000c000d000e000f000/custom/suffix/';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe('https://api.cloudflare.com/custom/prefix/****/custom/suffix/');
    });
  });
  describe('desensitizeCloudflareUrl', () => {
    it('should mask account id in official API endpoint', () => {
      // Official host is kept intact; only the 32-hex account ID is masked.
      const url =
        'https://api.cloudflare.com/client/v4/accounts/80009000a000b000c000d000e000f000/ai/run/';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe('https://api.cloudflare.com/client/v4/accounts/****/ai/run/');
    });
    it('should call desensitizeUrl for custom API endpoint', () => {
      // Non-official hosts are delegated to the generic masker; stub it to assert the call.
      const url = 'https://custom.url/path';
      vi.spyOn(desensitizeTool, 'desensitizeUrl').mockImplementation(
        (_) => 'https://custom.mocked.url',
      );
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(desensitizeTool.desensitizeUrl).toHaveBeenCalledWith('https://custom.url');
      expect(maskedUrl).toBe('https://custom.mocked.url/path');
    });
    it('should mask account id in custom API endpoint', () => {
      // Custom hosts get both host masking and account-ID masking.
      const url =
        'https://custom.url/custom/prefix/80009000a000b000c000d000e000f000/custom/suffix/';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe('https://cu****om.url/custom/prefix/****/custom/suffix/');
    });
    it('should mask account id in custom API endpoint with query params', () => {
      // The query string must be preserved verbatim after masking.
      const url =
        'https://custom.url/custom/prefix/80009000a000b000c000d000e000f000/custom/suffix/?query=param';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe(
        'https://cu****om.url/custom/prefix/****/custom/suffix/?query=param',
      );
    });
    it('should mask account id in custom API endpoint with port', () => {
      // Ports count as sensitive for custom hosts and are masked too.
      const url =
        'https://custom.url:8080/custom/prefix/80009000a000b000c000d000e000f000/custom/suffix/';
      const maskedUrl = desensitizeCloudflareUrl(url);
      expect(maskedUrl).toBe('https://cu****om.url:****/custom/prefix/****/custom/suffix/');
    });
  });
});
// Covers the manifest-property readers used to build chat model cards.
describe('modelManifest', () => {
  describe('getModelBeta', () => {
    it('should get beta property', () => {
      // Property values are strings ('true'/'false'), not booleans.
      const model = { properties: [{ property_id: 'beta', value: 'true' }] };
      const beta = getModelBeta(model);
      expect(beta).toBe(true);
    });
    it('should return false if beta property is false', () => {
      const model = { properties: [{ property_id: 'beta', value: 'false' }] };
      const beta = getModelBeta(model);
      expect(beta).toBe(false);
    });
    it('should return false if beta property is not present', () => {
      const model = { properties: [] };
      const beta = getModelBeta(model);
      expect(beta).toBe(false);
    });
  });
  describe('getModelDisplayName', () => {
    it('should return display name with beta suffix', () => {
      const model = { name: 'model', properties: [{ property_id: 'beta', value: 'true' }] };
      const name = getModelDisplayName(model, true);
      expect(name).toBe('model (Beta)');
    });
    it('should return display name without beta suffix', () => {
      const model = { name: 'model', properties: [] };
      const name = getModelDisplayName(model, false);
      expect(name).toBe('model');
    });
    it('should return model["name"]', () => {
      // `name` wins over `id` for display purposes.
      const model = { id: 'modelID', name: 'modelName' };
      const name = getModelDisplayName(model, false);
      expect(name).toBe('modelName');
    });
    it('should return last part of model["name"]', () => {
      // Only the final path segment of a namespaced name is shown.
      const model = { name: '@provider/modelFamily/modelName' };
      const name = getModelDisplayName(model, false);
      expect(name).toBe('modelName');
    });
  });
  describe('getModelFunctionCalling', () => {
    it('should return true if function_calling property is true', () => {
      const model = { properties: [{ property_id: 'function_calling', value: 'true' }] };
      const functionCalling = getModelFunctionCalling(model);
      expect(functionCalling).toBe(true);
    });
    it('should return false if function_calling property is false', () => {
      const model = { properties: [{ property_id: 'function_calling', value: 'false' }] };
      const functionCalling = getModelFunctionCalling(model);
      expect(functionCalling).toBe(false);
    });
    it('should return false if function_calling property is not set', () => {
      const model = { properties: [] };
      const functionCalling = getModelFunctionCalling(model);
      expect(functionCalling).toBe(false);
    });
  });
  describe('getModelTokens', () => {
    it('should return tokens property value', () => {
      // The context size arrives as a decimal string and is parsed to a number.
      const model = { properties: [{ property_id: 'max_total_tokens', value: '100' }] };
      const tokens = getModelTokens(model);
      expect(tokens).toBe(100);
    });
    it('should return undefined if tokens property is not present', () => {
      const model = { properties: [] };
      const tokens = getModelTokens(model);
      expect(tokens).toBeUndefined();
    });
  });
});
});

View file

@ -0,0 +1,134 @@
import { desensitizeUrl } from '../utils/desensitizeUrl';
/**
 * TransformStream transformer that converts Cloudflare's SSE chat stream
 * (`data: {...}` events separated by blank lines) into the internal
 * `event: text` / `data: "<response>"` protocol, buffering partial events
 * that straddle network chunk boundaries.
 */
class CloudflareStreamTransformer {
  private textDecoder = new TextDecoder();
  private buffer: string = '';

  /** Converts one complete SSE event line into the two internal protocol lines. */
  private parseChunk(chunk: string, controller: TransformStreamDefaultController) {
    const payload = JSON.parse(chunk.replace(/^data: /, ''));
    controller.enqueue('event: text\n');
    controller.enqueue(`data: ${JSON.stringify(payload.response)}\n\n`);
  }

  public async transform(chunk: Uint8Array, controller: TransformStreamDefaultController) {
    // Prepend any partial event carried over from the previous network chunk.
    let pending = this.textDecoder.decode(chunk);
    if (this.buffer.trim() !== '') {
      pending = this.buffer + pending;
      this.buffer = '';
    }

    // Events are separated by a blank line; the final piece may be incomplete.
    const events = pending.split('\n\n');
    const tail = events.pop()!;

    for (const event of events) {
      if (/\[DONE]/.test(event.trim())) {
        // Stream terminator: drop everything that follows, including the tail.
        return;
      }
      this.parseChunk(event, controller);
    }

    if (tail.trim() !== '') {
      this.buffer += tail; // Keep the partial event for the next call.
    } // A blank tail is simply dropped.
  }
}
// Key under which Cloudflare model manifests store property identifiers.
const CF_PROPERTY_NAME = 'property_id';

// Official Cloudflare API origin; also used to recognize "official endpoint" URLs.
const DEFAULT_BASE_URL_PREFIX = 'https://api.cloudflare.com';

/**
 * Builds the official Workers AI run endpoint for an account.
 * The trailing slash matters: model names are appended via `new URL(model, base)`.
 */
function fillUrl(accountID: string): string {
  return [DEFAULT_BASE_URL_PREFIX, 'client/v4/accounts', accountID, 'ai/run/'].join('/');
}
function desensitizeAccountId(path: string): string {
return path.replace(/\/[\dA-Fa-f]{32}\//, '/****/');
}
function desensitizeCloudflareUrl(url: string): string {
const urlObj = new URL(url);
let { protocol, hostname, port, pathname, search } = urlObj;
if (url.startsWith(DEFAULT_BASE_URL_PREFIX)) {
return `${protocol}//${hostname}${port ? `:${port}` : ''}${desensitizeAccountId(pathname)}${search}`;
} else {
const desensitizedUrl = desensitizeUrl(`${protocol}//${hostname}${port ? `:${port}` : ''}`);
if (desensitizedUrl.endsWith('/') && pathname.startsWith('/')) {
pathname = pathname.slice(1);
}
return `${desensitizedUrl}${desensitizeAccountId(pathname)}${search}`;
}
}
/**
 * Whether a model manifest is flagged as beta (`beta` property equal to 'true').
 * Returns false when the flag is absent, ambiguous, or the manifest is malformed.
 */
function getModelBeta(model: any): boolean {
  try {
    const matches = model['properties'].filter(
      (property: any) => property[CF_PROPERTY_NAME] === 'beta',
    );
    // Manifest property values arrive as strings, hence the string comparison.
    return matches.length === 1 ? matches[0]['value'] === 'true' : false;
  } catch {
    return false;
  }
}
/**
 * Derives a human-readable name from a manifest: the last path segment of
 * `name` (e.g. `@cf/meta/llama` → `llama`), with a " (Beta)" suffix when beta.
 */
function getModelDisplayName(model: any, beta: boolean): string {
  const segments: string[] = model['name'].split('/');
  const baseName = segments[segments.length - 1];
  return beta ? `${baseName} (Beta)` : baseName;
}
// Currently unused at runtime: function calling is disabled in convertModelManifest.
// eslint-disable-next-line @typescript-eslint/no-unused-vars, unused-imports/no-unused-vars
function getModelFunctionCalling(model: any): boolean {
  try {
    const matches = model['properties'].filter(
      (property: any) => property[CF_PROPERTY_NAME] === 'function_calling',
    );
    // Manifest property values are strings ('true'/'false').
    return matches.length === 1 ? matches[0]['value'] === 'true' : false;
  } catch {
    return false;
  }
}
/**
 * Reads the `max_total_tokens` property (context window size) from a manifest.
 *
 * @param model - Raw model manifest entry from the Cloudflare model-search API.
 * @returns The token limit as a number, or `undefined` when the property is
 *   absent, duplicated, or the manifest is malformed.
 */
function getModelTokens(model: any): number | undefined {
  try {
    const tokensProperty = model['properties'].filter(
      (property: any) => property[CF_PROPERTY_NAME] === 'max_total_tokens',
    );
    if (tokensProperty.length === 1) {
      // Explicit radix 10: property values are decimal strings; avoids any
      // accidental hex interpretation of unusual inputs.
      return parseInt(tokensProperty[0]['value'], 10);
    }
    return undefined;
  } catch {
    return undefined;
  }
}
/**
 * Maps a raw Cloudflare model manifest entry onto the app's chat-model card shape.
 * Beta models get a " (Beta)" display suffix and start disabled.
 */
function convertModelManifest(model: any) {
  const modelBeta = getModelBeta(model);
  return {
    description: model['description'],
    displayName: getModelDisplayName(model, modelBeta),
    enabled: !modelBeta,
    // Function calling is deliberately disabled for now; the manifest flag is
    // still readable via getModelFunctionCalling when this gets re-enabled.
    functionCall: false, //getModelFunctionCalling(model),
    id: model['name'],
    tokens: getModelTokens(model),
  };
}
export {
CloudflareStreamTransformer,
convertModelManifest,
DEFAULT_BASE_URL_PREFIX,
desensitizeCloudflareUrl,
fillUrl,
getModelBeta,
getModelDisplayName,
getModelFunctionCalling,
getModelTokens,
};

View file

@ -21,7 +21,7 @@ export default {
},
bedrock: {
accessKeyId: {
desc: '填入AWS Access Key Id',
desc: '填入 AWS Access Key Id',
placeholder: 'AWS Access Key Id',
title: 'AWS Access Key Id',
},
@ -52,6 +52,18 @@ export default {
title: '使用自定义 Bedrock 鉴权信息',
},
},
cloudflare: {
apiKey: {
desc: '请填写 Cloudflare API Key',
placeholder: 'Cloudflare API Key',
title: 'Cloudflare API Key',
},
baseURLOrAccountID: {
desc: '填入 Cloudflare 账户 ID 或 自定义 API 地址',
placeholder: 'Cloudflare Account ID / custom API URL',
title: 'Cloudflare 账户 ID / API 地址',
}
},
github: {
personalAccessToken: {
desc: '填入你的 Github PAT点击 [这里](https://github.com/settings/tokens) 创建',

View file

@ -99,6 +99,9 @@ export const getServerGlobalConfig = () => {
BAICHUAN_MODEL_LIST,
ENABLED_TAICHU,
ENABLED_CLOUDFLARE,
TAICHU_MODEL_LIST,
ENABLED_AI21,
@ -202,6 +205,7 @@ export const getServerGlobalConfig = () => {
modelString: AWS_BEDROCK_MODEL_LIST,
}),
},
cloudflare: { enabled: ENABLED_CLOUDFLARE },
deepseek: {
enabled: ENABLED_DEEPSEEK,
enabledModels: extractEnabledModels(DEEPSEEK_MODEL_LIST),

View file

@ -210,6 +210,17 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
return { apiKey };
}
case ModelProvider.Cloudflare: {
const { CLOUDFLARE_API_KEY, CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID } = getLLMConfig();
const apiKey = apiKeyManager.pick(payload?.apiKey || CLOUDFLARE_API_KEY);
const baseURLOrAccountID =
payload.apiKey && payload.cloudflareBaseURLOrAccountID
? payload.cloudflareBaseURLOrAccountID
: CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID;
return { apiKey, baseURLOrAccountID };
}
case ModelProvider.Ai360: {
const { AI360_API_KEY } = getLLMConfig();

View file

@ -69,6 +69,15 @@ export const getProviderAuthPayload = (provider: string) => {
return { endpoint: config?.baseURL };
}
case ModelProvider.Cloudflare: {
const config = keyVaultsConfigSelectors.cloudflareConfig(useUserStore.getState());
return {
apiKey: config?.apiKey,
cloudflareBaseURLOrAccountID: config?.baseURLOrAccountID,
};
}
default: {
const config = keyVaultsConfigSelectors.getVaultByProvider(provider as GlobalLLMProviderKey)(
useUserStore.getState(),

View file

@ -175,6 +175,13 @@ export function initializeWithClientStore(provider: string, payload: any) {
case ModelProvider.ZeroOne: {
break;
}
case ModelProvider.Cloudflare: {
providerOptions = {
apikey: providerAuthPayload?.apiKey,
baseURLOrAccountID: providerAuthPayload?.cloudflareBaseURLOrAccountID,
};
break;
}
}
/**

View file

@ -18,6 +18,7 @@ const wenxinConfig = (s: UserStore) => keyVaultsSettings(s).wenxin || {};
const ollamaConfig = (s: UserStore) => keyVaultsSettings(s).ollama || {};
const sensenovaConfig = (s: UserStore) => keyVaultsSettings(s).sensenova || {};
const azureConfig = (s: UserStore) => keyVaultsSettings(s).azure || {};
const cloudflareConfig = (s: UserStore) => keyVaultsSettings(s).cloudflare || {};
const getVaultByProvider = (provider: GlobalLLMProviderKey) => (s: UserStore) =>
(keyVaultsSettings(s)[provider] || {}) as OpenAICompatibleKeyVault &
AzureOpenAIKeyVault &
@ -38,6 +39,7 @@ const password = (s: UserStore) => keyVaultsSettings(s).password || '';
export const keyVaultsConfigSelectors = {
azureConfig,
bedrockConfig,
cloudflareConfig,
getVaultByProvider,
isProviderApiKeyNotEmpty,
isProviderEndpointNotEmpty,

View file

@ -69,6 +69,7 @@ const openAIConfig = (s: UserStore) => currentLLMSettings(s).openai;
const bedrockConfig = (s: UserStore) => currentLLMSettings(s).bedrock;
const ollamaConfig = (s: UserStore) => currentLLMSettings(s).ollama;
const azureConfig = (s: UserStore) => currentLLMSettings(s).azure;
const cloudflareConfig = (s: UserStore) => currentLLMSettings(s).cloudflare;
const sensenovaConfig = (s: UserStore) => currentLLMSettings(s).sensenova;
const isAzureEnabled = (s: UserStore) => currentLLMSettings(s).azure.enabled;
@ -76,6 +77,7 @@ const isAzureEnabled = (s: UserStore) => currentLLMSettings(s).azure.enabled;
export const modelConfigSelectors = {
azureConfig,
bedrockConfig,
cloudflareConfig,
currentEditingCustomModelCard,
getCustomModelCard,

View file

@ -16,6 +16,11 @@ export interface AWSBedrockKeyVault {
sessionToken?: string;
}
export interface CloudflareKeyVault {
apiKey?: string;
baseURLOrAccountID?: string;
}
export interface SenseNovaKeyVault {
sensenovaAccessKeyID?: string;
sensenovaAccessKeySecret?: string;
@ -33,6 +38,7 @@ export interface UserKeyVaults {
azure?: AzureOpenAIKeyVault;
baichuan?: OpenAICompatibleKeyVault;
bedrock?: AWSBedrockKeyVault;
cloudflare?: CloudflareKeyVault;
deepseek?: OpenAICompatibleKeyVault;
fireworksai?: OpenAICompatibleKeyVault;
github?: OpenAICompatibleKeyVault;