simplify gemini

This commit is contained in:
Andrew 2024-10-30 22:45:44 -07:00
parent 8cd1ea305c
commit a8010eec15
4 changed files with 54 additions and 67 deletions

View file

@ -7,12 +7,10 @@
"": {
"name": "void",
"version": "0.0.1",
"dependencies": {
"@google/generative-ai": "^0.21.0"
},
"devDependencies": {
"@anthropic-ai/sdk": "^0.29.2",
"@eslint/js": "^9.9.1",
"@google/generative-ai": "^0.21.0",
"@monaco-editor/react": "^4.6.0",
"@rrweb/types": "^2.0.0-alpha.17",
"@types/diff": "^5.2.2",
@ -744,6 +742,7 @@
"version": "0.21.0",
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.21.0.tgz",
"integrity": "sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==",
"dev": true,
"license": "Apache-2.0",
"engines": {
"node": ">=18.0.0"

View file

@ -111,6 +111,7 @@
"devDependencies": {
"@anthropic-ai/sdk": "^0.29.2",
"@eslint/js": "^9.9.1",
"@google/generative-ai": "^0.21.0",
"@monaco-editor/react": "^4.6.0",
"@rrweb/types": "^2.0.0-alpha.17",
"@types/diff": "^5.2.2",
@ -151,8 +152,5 @@
"typescript": "5.5.4",
"typescript-eslint": "^8.3.0",
"uuid": "^10.0.0"
},
"dependencies": {
"@google/generative-ai": "^0.21.0"
}
}

View file

@ -1,7 +1,7 @@
import Anthropic from '@anthropic-ai/sdk';
import OpenAI from 'openai';
import { Ollama } from 'ollama/browser'
import { GoogleGenerativeAI } from '@google/generative-ai';
import { Content, GoogleGenerativeAI, GoogleGenerativeAIError, GoogleGenerativeAIFetchError } from '@google/generative-ai';
import { VoidConfig } from '../webviews/common/contextForConfig'
export type AbortRef = { current: (() => void) | null }
@ -64,7 +64,7 @@ const sendAnthropicMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFi
system: systemMessage,
messages: anthropicMessages,
model: voidConfig.anthropic.model,
max_tokens: parseInt(voidConfig.default.maxTokens),
max_tokens: parseMaxTokensStr(voidConfig.default.maxTokens)!, // this might be undefined, but it will just throw an error for the user
});
let did_abort = false
@ -104,6 +104,7 @@ const sendAnthropicMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFi
// Gemini
const sendGeminiMsg: SendLLMMessageFnTypeInternal = async ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
let didAbort = false
let fullText = ''
@ -111,63 +112,52 @@ const sendGeminiMsg: SendLLMMessageFnTypeInternal = async ({ messages, onText, o
didAbort = true
}
try {
const genAI = new GoogleGenerativeAI(voidConfig.Gemini.apikey);
// Force the model to be exactly what's configured
const modelName = voidConfig.Gemini.model;
const model = genAI.getGenerativeModel({ model: modelName });
const genAI = new GoogleGenerativeAI(voidConfig.gemini.apikey);
const model = genAI.getGenerativeModel({ model: voidConfig.gemini.model });
// Combine system messages with the first user message
let systemContent = messages
.filter(msg => msg.role === 'system')
.map(msg => msg.content)
.join('\n');
// Filter out system messages and modify first user message if needed
let geminiMessages = messages.filter(msg => msg.role !== 'system');
if (systemContent && geminiMessages.length > 0 && geminiMessages[0].role === 'user') {
geminiMessages[0] = {
...geminiMessages[0],
content: `${systemContent}\n\n${geminiMessages[0].content}`
};
}
// remove system messages that get sent to Gemini
// str of all system messages
let systemMessage = messages
.filter(msg => msg.role === 'system')
.map(msg => msg.content)
.join('\n');
// Convert remaining messages to Gemini format
const history = geminiMessages.map(msg => ({
role: msg.role === 'assistant' ? 'model' : msg.role,
parts: [{ text: msg.content }]
}));
// Convert messages to Gemini format
const geminiMessages: Content[] = messages
.filter(msg => msg.role !== 'system')
.map((msg, i) => ({
parts: [{ text: msg.content }],
role: msg.role === 'assistant' ? 'model' : 'user'
}))
const chat = model.startChat({
history: history.slice(0, -1), // Exclude last message
generationConfig: {
maxOutputTokens: parseInt(voidConfig.default.maxTokens)
// Removed model from generationConfig since it's not a valid property
// Model is already set when creating the model instance above
model.generateContentStream({ contents: geminiMessages, systemInstruction: systemMessage, })
.then(async response => {
abortRef.current = () => {
// response.stream.return(fullText)
didAbort = true;
}
});
const lastMessage = messages[messages.length - 1].content;
const result = await chat.sendMessageStream(lastMessage);
for await (const chunk of result.stream) {
if (didAbort) return;
const newText = chunk.text();
fullText += newText;
onText(newText, fullText);
}
onFinalMessage(fullText);
} catch (error: unknown) {
if (error instanceof Error && error.message?.includes('API key')) {
onError('Invalid API key.');
} else if (error instanceof Error) {
onError(error.message || 'Error connecting to Gemini');
} else {
onError('Error connecting to Gemini');
}
}
};
for await (const chunk of response.stream) {
if (didAbort) return;
const newText = chunk.text();
fullText += newText;
onText(newText, fullText);
}
onFinalMessage(fullText);
})
.catch((error) => {
if (error instanceof GoogleGenerativeAIFetchError) {
if (error.status === 400) {
onError('Invalid API key.');
}
else {
onError(`${error.name}:\n${error.message}`);
}
}
else {
onError(error);
}
})
}
// OpenAI, OpenRouter, OpenAICompatible
const sendOpenAIMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
@ -231,7 +221,7 @@ const sendOpenAIMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinal
onError('Invalid API key.');
}
else {
onError(error.message);
onError(`${error.name}:\n${error.message}`);
}
}
else {
@ -258,7 +248,7 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText,
model: voidConfig.ollama.model,
messages: messages,
stream: true,
options: { num_predict: parseInt(voidConfig.default.maxTokens) } // this is max_tokens
options: { num_predict: parseMaxTokensStr(voidConfig.default.maxTokens) } // this is max_tokens
})
.then(async stream => {
abortRef.current = () => {
@ -363,12 +353,12 @@ export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ messages, onText,
case 'openRouter':
case 'openAICompatible':
return sendOpenAIMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'gemini':
return sendGeminiMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'ollama':
return sendOllamaMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'greptile':
return sendGreptileMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'gemini':
return sendGeminiMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
default:
onError(`Error: whichApi was ${voidConfig.default.whichApi}, which is not recognized!`)
}

View file

@ -21,12 +21,12 @@ const configString = (description: string, defaultVal: string) => {
export const configFields = [
'anthropic',
'openAI',
'gemini',
'greptile',
'ollama',
'openRouter',
'openAICompatible',
'azure',
'Gemini'
] as const
@ -165,7 +165,7 @@ const voidConfigInfo: Record<
// }
// },
},
Gemini: {
gemini: {
apikey: configString('Google API key.', ''),
model: configEnum(
'Gemini model to use.',