From 339aff5d3143bca928f3607c8e12c3ce5d68b37f Mon Sep 17 00:00:00 2001 From: mp Date: Tue, 12 Nov 2024 01:36:32 -0800 Subject: [PATCH] Refactor sendLLMMessage and add FIM mode --- extensions/void/src/common/getPrompt.ts | 105 ++++++++++++++++++ extensions/void/src/common/sendLLMMessage.ts | 81 +++++++++++--- extensions/void/src/common/systemPrompts.ts | 16 ++- .../void/src/extension/AutcompleteProvider.ts | 71 ++++++------ extensions/void/src/extension/ctrlK.ts | 5 +- 5 files changed, 222 insertions(+), 56 deletions(-) create mode 100644 extensions/void/src/common/getPrompt.ts diff --git a/extensions/void/src/common/getPrompt.ts b/extensions/void/src/common/getPrompt.ts new file mode 100644 index 00000000..f1c7567b --- /dev/null +++ b/extensions/void/src/common/getPrompt.ts @@ -0,0 +1,105 @@ +import { configFields, VoidConfig } from "../webviews/common/contextForConfig" +import { FimInfo } from "./sendLLMMessage" + + +type GetFIMPrompt = ({ voidConfig, fimInfo }: { voidConfig: VoidConfig, fimInfo: FimInfo, }) => string + +export const getFIMSystem: GetFIMPrompt = ({ voidConfig, fimInfo }) => { + + switch (voidConfig.default.whichApi) { + case 'ollama': + return '' + case 'anthropic': + case 'openAI': + case 'gemini': + case 'greptile': + case 'openRouter': + case 'openAICompatible': + case 'azure': + default: + return `You are given the START and END to a piece of code. Please FILL IN THE MIDDLE between the START and END. + +Instruction summary: +1. Return the MIDDLE of the code between the START and END. +2. Do not give an explanation, description, or any other code besides the middle. +2. Do not return duplicate code from either START or END. +3. Make sure the MIDDLE piece of code has balanced brackets that match the START and END. +4. The MIDDLE begins on the same line as START. Please include a newline character if you want to begin on the next line. + +# EXAMPLE + +## START: +\`\`\` python +def add(a,b): + return a + b +def subtract(a,b): + return a - b +\`\`\` +## END: +\`\`\` python +def divide(a,b): + return a / b +\`\`\` +## EXPECTED OUTPUT: +\`\`\` python + +def multiply(a,b): + return a * b +\`\`\` + +# EXAMPLE +## START: +\`\`\` javascript +const x = 1 + +const y +\`\`\` +## END: +\`\`\` javascript + +const z = 3 +\`\`\` +## EXPECTED OUTPUT: +\`\`\` javascript += 2 +\`\`\` +` + } + + +} + + +export const getFIMPrompt: GetFIMPrompt = ({ voidConfig, fimInfo }) => { + + // if no prefix or suffix, return empty string + if (!fimInfo.prefix.trim() && !fimInfo.suffix.trim()) return '' + + // TODO may want to trim the prefix and suffix + switch (voidConfig.default.whichApi) { + case 'ollama': + if (voidConfig.ollama.model === 'codestral') { + return `[SUFFIX]${fimInfo.suffix}[PREFIX] ${fimInfo.prefix}` + } + return '' + case 'anthropic': + case 'openAI': + case 'gemini': + case 'greptile': + case 'openRouter': + case 'openAICompatible': + case 'azure': + default: + return `## START: +\`\`\` +${fimInfo.prefix} +\`\`\` +## END: +\`\`\` +${fimInfo.suffix} +\`\`\` +` + + } +} + diff --git a/extensions/void/src/common/sendLLMMessage.ts b/extensions/void/src/common/sendLLMMessage.ts index 5a6e8f0f..a86f5c49 100644 --- a/extensions/void/src/common/sendLLMMessage.ts +++ b/extensions/void/src/common/sendLLMMessage.ts @@ -3,6 +3,7 @@ import OpenAI from 'openai'; import { Ollama } from 'ollama/browser' import { Content, GoogleGenerativeAI, GoogleGenerativeAIError, GoogleGenerativeAIFetchError } from '@google/generative-ai'; import { VoidConfig } from '../webviews/common/contextForConfig' +import { getFIMPrompt, getFIMSystem } from './getPrompt'; export type AbortRef = { current: (() => void) | null } @@ -21,23 +22,32 @@ export type LLMMessage = { } type SendLLMMessageFnTypeInternal = (params: { + mode: 'chat' | 'fim', messages: LLMMessage[], onText: OnText, onFinalMessage: OnFinalMessage, onError: (error: string) => void, - voidConfig: VoidConfig, abortRef: AbortRef, + voidConfig: VoidConfig, }) => void -type SendLLMMessageFnTypeExternal = (params: { - messages: LLMMessage[], + +type SendLLMMessageFnTypeExternal = (params: ( + | { mode?: 'chat', messages: LLMMessage[], fimInfo?: undefined, } + | { mode: 'fim', fimInfo: FimInfo, messages?: undefined, } +) & { onText: OnText, - onFinalMessage: (fullText: string) => void, + onFinalMessage: OnFinalMessage, onError: (error: string) => void, - voidConfig: VoidConfig | null, abortRef: AbortRef, + voidConfig: VoidConfig | null, // these may be absent }) => void +export type FimInfo = { + prefix: string, + suffix: string, +} + const parseMaxTokensStr = (maxTokensStr: string) => { // parse the string but only if the full string is a valid number, eg parseInt('100abc') should return NaN let int = isNaN(Number(maxTokensStr)) ? undefined : parseInt(maxTokensStr) @@ -232,7 +242,7 @@ const sendOpenAIMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinal }; // Ollama -export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => { +export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => { let didAbort = false let fullText = "" @@ -243,6 +253,10 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, const ollama = new Ollama({ host: voidConfig.ollama.endpoint }) + type GenerateResponse = Awaited> + type ChatResponse = Awaited> + + // First check if model exists ollama.list() .then(async models => { @@ -256,6 +270,18 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, return Promise.reject(); } + if (mode === 'fim') { + + // the fim prompt is the last message + let prompt = messages[messages.length - 1].content + return ollama.generate({ + model: voidConfig.ollama.model, + prompt: prompt, + stream: true, + raw: true, + }) + } + return ollama.chat({ model: voidConfig.ollama.model, messages: messages, @@ -271,7 +297,11 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, } for await (const chunk of stream) { if (didAbort) return; - const newText = chunk.message.content; + + const newText = (mode === 'fim' + ? (chunk as GenerateResponse).response + : (chunk as ChatResponse).message.content + ) fullText += newText; onText(newText, fullText); } @@ -357,26 +387,49 @@ const sendGreptileMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFin } -export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => { - if (!voidConfig) return; +export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ mode, messages, fimInfo, onText, onFinalMessage, onError, voidConfig, abortRef }) => { + if (!voidConfig) + return onError('No config file found for LLM.'); + + // handle defaults + if (!mode) mode = 'chat' + if (!messages) messages = [] + + // build messages + if (mode === 'chat') { + // nothing needed + } else if (mode === 'fim') { + fimInfo = fimInfo! + + const system = getFIMSystem({ voidConfig, fimInfo }) + const prompt = getFIMPrompt({ voidConfig, fimInfo }) + messages = ([ + { role: 'system', content: system }, + { role: 'user', content: prompt } + ] as const) + .filter(m => m.content.trim() !== '') + } // trim message content (Anthropic and other providers give an error if there is trailing whitespace) messages = messages.map(m => ({ ...m, content: m.content.trim() })) + if (messages.length === 0) + return onError('No messages provided to LLM.'); switch (voidConfig.default.whichApi) { case 'anthropic': - return sendAnthropicMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }); + return sendAnthropicMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }); case 'openAI': case 'openRouter': case 'openAICompatible': - return sendOpenAIMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }); + return sendOpenAIMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }); case 'gemini': - return sendGeminiMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }); + return sendGeminiMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }); case 'ollama': - return sendOllamaMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }); + return sendOllamaMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }); case 'greptile': - return sendGreptileMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }); + return sendGreptileMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }); default: onError(`Error: whichApi was ${voidConfig.default.whichApi}, which is not recognized!`) } + } diff --git a/extensions/void/src/common/systemPrompts.ts b/extensions/void/src/common/systemPrompts.ts index edbfa03b..276c8570 100644 --- a/extensions/void/src/common/systemPrompts.ts +++ b/extensions/void/src/common/systemPrompts.ts @@ -1,4 +1,14 @@ + + +// used for ctrl+l +const partialGenerationInstructions = `` + + +// used for ctrl+k, autocomplete +const fimInstructions = `` + + const generateDiffInstructions = ` You are a coding assistant. You are given a list of relevant files \`files\`, a selection that the user is making \`selection\`, and instructions to follow \`instructions\`. @@ -397,12 +407,6 @@ COMPLETION export default Sidebar;\`\`\` ` -// used for ctrl+l -const partialGenerationInstructions = `` - - -// used for ctrl+k, autocomplete -const fimInstructions = `` diff --git a/extensions/void/src/extension/AutcompleteProvider.ts b/extensions/void/src/extension/AutcompleteProvider.ts index 28db0bcf..336d587c 100644 --- a/extensions/void/src/extension/AutcompleteProvider.ts +++ b/extensions/void/src/extension/AutcompleteProvider.ts @@ -1,6 +1,7 @@ import * as vscode from 'vscode'; import { AbortRef, LLMMessage, sendLLMMessage } from '../common/sendLLMMessage'; import { getVoidConfigFromPartial, VoidConfig } from '../webviews/common/contextForConfig'; +import { result } from 'lodash'; type AutocompletionStatus = 'pending' | 'finished' | 'error'; type Autocompletion = { @@ -17,6 +18,30 @@ type Autocompletion = { const DEBOUNCE_TIME = 500 const TIMEOUT_TIME = 60000 +// postprocesses the result +const postprocessResult = (result: string) => { + + // remove leading whitespace from result + return result.trimStart() + +} + + +const extractCodeFromResult = (result: string) => { + + // extract the code between triple backticks + const parts = result.split(/```/); + + // if there is no ``` then return the raw result + if (parts.length === 1) { + return result; + } + + // else return the code between the triple backticks + return parts[1] + +} + // trims the end of the prefix to improve cache hit rate const trimPrefix = (prefix: string) => { const trimmedPrefix = prefix.trimEnd() @@ -31,7 +56,13 @@ const trimPrefix = (prefix: string) => { return trimmedPrefix } -// finds the text in the autocompletion to display +// finds the text in the autocompletion to display, assuming the prefix is already matched +// example: +// originalPrefix = abcd +// generatedMiddle = efgh +// originalSuffix = ijkl +// the user has typed "ef" so prefix = abcdef +// we want to return the rest of the generatedMiddle, which is "gh" const toInlineCompletion = ({ prefix, autocompletion }: { prefix: string, autocompletion: Autocompletion }): vscode.InlineCompletionItem => { const originalPrefix = autocompletion.prefix const generatedMiddle = autocompletion.result @@ -44,18 +75,13 @@ const toInlineCompletion = ({ prefix, autocompletion }: { prefix: string, autoco console.log('generatedMiddle ', generatedMiddle) console.log('trimmedOriginalPrefix ', trimmedOriginalPrefix) console.log('trimmedCurrentPrefix ', trimmedCurrentPrefix) - console.log('lastMatchupIndex ', lastMatchupIndex) + console.log('index: ', lastMatchupIndex) if (lastMatchupIndex < 0) { return new vscode.InlineCompletionItem('') } - // example: - // originalPrefix = abcd - // generatedMiddle = efgh - // originalSuffix = ijkl - // the user has typed "ef" so prefix = abcdef - // we want to return the rest of the generatedMiddle, which is "gh" const completionStr = generatedMiddle.substring(lastMatchupIndex) + console.log('completionStr: ', completionStr) return new vscode.InlineCompletionItem(completionStr) @@ -131,8 +157,6 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider console.log('AAA1') const inlineCompletion = toInlineCompletion({ autocompletion: cachedAutocompletion, prefix, }) - - return [inlineCompletion] } else if (cachedAutocompletion.status === 'pending') { @@ -189,30 +213,12 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider result: '', } - - let messages: LLMMessage[] = [] - switch (voidConfig.default.whichApi) { - case 'ollama': - messages = [ - { role: 'user', content: `[SUFFIX]${suffix}[PREFIX]${prefix} Fill in the middle between the prefix and suffix. Return only the middle. [MIDDLE]` } - ] - break; - case 'anthropic': - case 'openAI': - messages = [ - { role: 'system', content: 'Fill in the prefix up to the suffix. Return only the result and be very concise.' }, - { role: 'user', content: `[SUFFIX]${suffix}[PREFIX]${prefix}` }, - ] - break; - default: - throw new Error(`We do not recommend using autocomplete with your selected provider (${voidConfig.default.whichApi}).`); - } - // set parameters of `newAutocompletion` appropriately newAutocompletion.promise = new Promise((resolve, reject) => { sendLLMMessage({ - messages: messages, + mode: 'fim', + fimInfo: { prefix, suffix }, onText: async (tokenStr, completionStr) => { // TODO filter out bad responses here newAutocompletion.result = completionStr @@ -226,9 +232,10 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider // newAutocompletion.abortRef = { current: () => { } } newAutocompletion.status = 'finished' // newAutocompletion.promise = undefined - newAutocompletion.result = finalMessage + newAutocompletion.result = postprocessResult(extractCodeFromResult(finalMessage)) + + resolve(newAutocompletion.result) - resolve(finalMessage) }, onError: (e) => { newAutocompletion.endTime = Date.now() diff --git a/extensions/void/src/extension/ctrlK.ts b/extensions/void/src/extension/ctrlK.ts index 2d7e35cd..8b3f2ab6 100644 --- a/extensions/void/src/extension/ctrlK.ts +++ b/extensions/void/src/extension/ctrlK.ts @@ -1,8 +1,6 @@ import * as vscode from 'vscode'; import { AbortRef, OnFinalMessage, OnText, sendLLMMessage } from "../common/sendLLMMessage" import { VoidConfig } from '../webviews/common/contextForConfig'; -import { searchDiffChunkInstructions, writeFileWithDiffInstructions } from '../common/systemPrompts'; -import { throttle } from 'lodash'; import { readFileContentOfUri } from './extensionLib/readFileContentOfUri'; const applyCtrlK = async ({ fileUri, startLine, endLine, instructions, voidConfig, abortRef }: { fileUri: vscode.Uri, startLine: number, endLine: number, instructions: string, voidConfig: VoidConfig, abortRef: AbortRef }) => { @@ -22,14 +20,13 @@ const applyCtrlK = async ({ fileUri, startLine, endLine, instructions, voidConfi The user wants to apply the following instructions to the selection: ${instructions} - Instructions: 1. Follow the user's instructions 2. You may ONLY CHANGE the selection, and nothing else in the file 3. Make sure all brackets in the new selection are balanced the same was as in the original selection 4. Be careful not to duplicate or remove variables, comments, or other syntax by mistake -Please rewrite the complete the following code, following the user's instructions. +Please rewrite the complete the following code, following the instructions. \`\`\`
${prefix}
${suffix}