refactor to have openAI conventions

This commit is contained in:
Andrew Pareles 2024-11-17 02:19:21 -08:00
parent 91c65b6060
commit 5140c79dcf

View file

@ -3,7 +3,7 @@ import OpenAI from 'openai';
import { Ollama } from 'ollama/browser' import { Ollama } from 'ollama/browser'
import { Content, GoogleGenerativeAI, GoogleGenerativeAIError, GoogleGenerativeAIFetchError } from '@google/generative-ai'; import { Content, GoogleGenerativeAI, GoogleGenerativeAIError, GoogleGenerativeAIFetchError } from '@google/generative-ai';
import { VoidConfig } from '../webviews/common/contextForConfig' import { VoidConfig } from '../webviews/common/contextForConfig'
import Groq from 'groq-sdk'; import Groq, { GroqError } from 'groq-sdk';
export type AbortRef = { current: (() => void) | null } export type AbortRef = { current: (() => void) | null }
@ -358,6 +358,7 @@ const sendGreptileMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFin
} }
// Groq
const sendGroqMsg: SendLLMMessageFnTypeInternal = async ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => { const sendGroqMsg: SendLLMMessageFnTypeInternal = async ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
let didAbort = false; let didAbort = false;
let fullText = ''; let fullText = '';
@ -366,39 +367,32 @@ const sendGroqMsg: SendLLMMessageFnTypeInternal = async ({ messages, onText, onF
didAbort = true; didAbort = true;
}; };
const max_tokens = parseMaxTokensStr(voidConfig.default.maxTokens)
const options = { model: voidConfig.groq.model, messages: messages, stream: true, max_tokens: max_tokens, } as const
const groq = new Groq({ apiKey: voidConfig.groq.apikey, dangerouslyAllowBrowser: true }); const groq = new Groq({ apiKey: voidConfig.groq.apikey, dangerouslyAllowBrowser: true });
try { groq.chat.completions
const stream = await groq.chat.completions.create({ .create(options)
messages: messages, .then(async response => {
model: voidConfig.groq.model, for await (const chunk of response) {
stream: true, if (didAbort) return;
temperature: 0.7, const newText = chunk.choices[0]?.delta?.content || '';
max_tokens: parseMaxTokensStr(voidConfig.default.maxTokens),
});
for await (const chunk of stream) {
if (didAbort) {
break;
}
const newText = chunk.choices[0]?.delta?.content || '';
if (newText) {
fullText += newText; fullText += newText;
onText(newText, fullText); onText(newText, fullText);
} }
}
if (!didAbort) {
onFinalMessage(fullText); onFinalMessage(fullText);
} })
} catch (error: any) { // when error/fail - this catches errors of both .create() and .then(for await)
if (error?.status === 401) { .catch(error => {
onError('Invalid API key.'); if (error instanceof GroqError) {
} else { onError(`${error.name}:\n${error.message}`);
onError(error.message || 'An error occurred while connecting to Groq.'); }
} else {
} onError(error);
}
})
}; };
export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => { export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {