Refactor sendLLMMessage and add FIM mode

This commit is contained in:
mp 2024-11-12 01:36:32 -08:00
parent aab065139e
commit 339aff5d31
5 changed files with 222 additions and 56 deletions

View file

@ -0,0 +1,105 @@
import { configFields, VoidConfig } from "../webviews/common/contextForConfig"
import { FimInfo } from "./sendLLMMessage"
type GetFIMPrompt = ({ voidConfig, fimInfo }: { voidConfig: VoidConfig, fimInfo: FimInfo, }) => string
export const getFIMSystem: GetFIMPrompt = ({ voidConfig, fimInfo }) => {
switch (voidConfig.default.whichApi) {
case 'ollama':
return ''
case 'anthropic':
case 'openAI':
case 'gemini':
case 'greptile':
case 'openRouter':
case 'openAICompatible':
case 'azure':
default:
return `You are given the START and END to a piece of code. Please FILL IN THE MIDDLE between the START and END.
Instruction summary:
1. Return the MIDDLE of the code between the START and END.
2. Do not give an explanation, description, or any other code besides the middle.
2. Do not return duplicate code from either START or END.
3. Make sure the MIDDLE piece of code has balanced brackets that match the START and END.
4. The MIDDLE begins on the same line as START. Please include a newline character if you want to begin on the next line.
# EXAMPLE
## START:
\`\`\` python
def add(a,b):
return a + b
def subtract(a,b):
return a - b
\`\`\`
## END:
\`\`\` python
def divide(a,b):
return a / b
\`\`\`
## EXPECTED OUTPUT:
\`\`\` python
def multiply(a,b):
return a * b
\`\`\`
# EXAMPLE
## START:
\`\`\` javascript
const x = 1
const y
\`\`\`
## END:
\`\`\` javascript
const z = 3
\`\`\`
## EXPECTED OUTPUT:
\`\`\` javascript
= 2
\`\`\`
`
}
}
export const getFIMPrompt: GetFIMPrompt = ({ voidConfig, fimInfo }) => {
// if no prefix or suffix, return empty string
if (!fimInfo.prefix.trim() && !fimInfo.suffix.trim()) return ''
// TODO may want to trim the prefix and suffix
switch (voidConfig.default.whichApi) {
case 'ollama':
if (voidConfig.ollama.model === 'codestral') {
return `[SUFFIX]${fimInfo.suffix}[PREFIX] ${fimInfo.prefix}`
}
return ''
case 'anthropic':
case 'openAI':
case 'gemini':
case 'greptile':
case 'openRouter':
case 'openAICompatible':
case 'azure':
default:
return `## START:
\`\`\`
${fimInfo.prefix}
\`\`\`
## END:
\`\`\`
${fimInfo.suffix}
\`\`\`
`
}
}

View file

@ -3,6 +3,7 @@ import OpenAI from 'openai';
import { Ollama } from 'ollama/browser'
import { Content, GoogleGenerativeAI, GoogleGenerativeAIError, GoogleGenerativeAIFetchError } from '@google/generative-ai';
import { VoidConfig } from '../webviews/common/contextForConfig'
import { getFIMPrompt, getFIMSystem } from './getPrompt';
export type AbortRef = { current: (() => void) | null }
@ -21,23 +22,32 @@ export type LLMMessage = {
}
type SendLLMMessageFnTypeInternal = (params: {
mode: 'chat' | 'fim',
messages: LLMMessage[],
onText: OnText,
onFinalMessage: OnFinalMessage,
onError: (error: string) => void,
voidConfig: VoidConfig,
abortRef: AbortRef,
voidConfig: VoidConfig,
}) => void
type SendLLMMessageFnTypeExternal = (params: {
messages: LLMMessage[],
type SendLLMMessageFnTypeExternal = (params: (
| { mode?: 'chat', messages: LLMMessage[], fimInfo?: undefined, }
| { mode: 'fim', fimInfo: FimInfo, messages?: undefined, }
) & {
onText: OnText,
onFinalMessage: (fullText: string) => void,
onFinalMessage: OnFinalMessage,
onError: (error: string) => void,
voidConfig: VoidConfig | null,
abortRef: AbortRef,
voidConfig: VoidConfig | null, // these may be absent
}) => void
export type FimInfo = {
prefix: string,
suffix: string,
}
const parseMaxTokensStr = (maxTokensStr: string) => {
// parse the string but only if the full string is a valid number, eg parseInt('100abc') should return NaN
let int = isNaN(Number(maxTokensStr)) ? undefined : parseInt(maxTokensStr)
@ -232,7 +242,7 @@ const sendOpenAIMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinal
};
// Ollama
export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
let didAbort = false
let fullText = ""
@ -243,6 +253,10 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText,
const ollama = new Ollama({ host: voidConfig.ollama.endpoint })
type GenerateResponse = Awaited<ReturnType<(typeof ollama.generate)>>
type ChatResponse = Awaited<ReturnType<(typeof ollama.chat)>>
// First check if model exists
ollama.list()
.then(async models => {
@ -256,6 +270,18 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText,
return Promise.reject();
}
if (mode === 'fim') {
// the fim prompt is the last message
let prompt = messages[messages.length - 1].content
return ollama.generate({
model: voidConfig.ollama.model,
prompt: prompt,
stream: true,
raw: true,
})
}
return ollama.chat({
model: voidConfig.ollama.model,
messages: messages,
@ -271,7 +297,11 @@ export const sendOllamaMsg: SendLLMMessageFnTypeInternal = ({ messages, onText,
}
for await (const chunk of stream) {
if (didAbort) return;
const newText = chunk.message.content;
const newText = (mode === 'fim'
? (chunk as GenerateResponse).response
: (chunk as ChatResponse).message.content
)
fullText += newText;
onText(newText, fullText);
}
@ -357,26 +387,49 @@ const sendGreptileMsg: SendLLMMessageFnTypeInternal = ({ messages, onText, onFin
}
export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ messages, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
if (!voidConfig) return;
export const sendLLMMessage: SendLLMMessageFnTypeExternal = ({ mode, messages, fimInfo, onText, onFinalMessage, onError, voidConfig, abortRef }) => {
if (!voidConfig)
return onError('No config file found for LLM.');
// handle defaults
if (!mode) mode = 'chat'
if (!messages) messages = []
// build messages
if (mode === 'chat') {
// nothing needed
} else if (mode === 'fim') {
fimInfo = fimInfo!
const system = getFIMSystem({ voidConfig, fimInfo })
const prompt = getFIMPrompt({ voidConfig, fimInfo })
messages = ([
{ role: 'system', content: system },
{ role: 'user', content: prompt }
] as const)
.filter(m => m.content.trim() !== '')
}
// trim message content (Anthropic and other providers give an error if there is trailing whitespace)
messages = messages.map(m => ({ ...m, content: m.content.trim() }))
if (messages.length === 0)
return onError('No messages provided to LLM.');
switch (voidConfig.default.whichApi) {
case 'anthropic':
return sendAnthropicMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
return sendAnthropicMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'openAI':
case 'openRouter':
case 'openAICompatible':
return sendOpenAIMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
return sendOpenAIMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'gemini':
return sendGeminiMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
return sendGeminiMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'ollama':
return sendOllamaMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
return sendOllamaMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef });
case 'greptile':
return sendGreptileMsg({ messages, onText, onFinalMessage, onError, voidConfig, abortRef });
return sendGreptileMsg({ mode, messages, onText, onFinalMessage, onError, voidConfig, abortRef });
default:
onError(`Error: whichApi was ${voidConfig.default.whichApi}, which is not recognized!`)
}
}

View file

@ -1,4 +1,14 @@
// used for ctrl+l
const partialGenerationInstructions = ``
// used for ctrl+k, autocomplete
const fimInstructions = ``
const generateDiffInstructions = `
You are a coding assistant. You are given a list of relevant files \`files\`, a selection that the user is making \`selection\`, and instructions to follow \`instructions\`.
@ -397,12 +407,6 @@ COMPLETION
export default Sidebar;\`\`\`
`
// used for ctrl+l
const partialGenerationInstructions = ``
// used for ctrl+k, autocomplete
const fimInstructions = ``

View file

@ -1,6 +1,7 @@
import * as vscode from 'vscode';
import { AbortRef, LLMMessage, sendLLMMessage } from '../common/sendLLMMessage';
import { getVoidConfigFromPartial, VoidConfig } from '../webviews/common/contextForConfig';
import { result } from 'lodash';
type AutocompletionStatus = 'pending' | 'finished' | 'error';
type Autocompletion = {
@ -17,6 +18,30 @@ type Autocompletion = {
const DEBOUNCE_TIME = 500
const TIMEOUT_TIME = 60000
// postprocesses the result
const postprocessResult = (result: string) => {
// remove leading whitespace from result
return result.trimStart()
}
const extractCodeFromResult = (result: string) => {
// extract the code between triple backticks
const parts = result.split(/```/);
// if there is no ``` then return the raw result
if (parts.length === 1) {
return result;
}
// else return the code between the triple backticks
return parts[1]
}
// trims the end of the prefix to improve cache hit rate
const trimPrefix = (prefix: string) => {
const trimmedPrefix = prefix.trimEnd()
@ -31,7 +56,13 @@ const trimPrefix = (prefix: string) => {
return trimmedPrefix
}
// finds the text in the autocompletion to display
// finds the text in the autocompletion to display, assuming the prefix is already matched
// example:
// originalPrefix = abcd
// generatedMiddle = efgh
// originalSuffix = ijkl
// the user has typed "ef" so prefix = abcdef
// we want to return the rest of the generatedMiddle, which is "gh"
const toInlineCompletion = ({ prefix, autocompletion }: { prefix: string, autocompletion: Autocompletion }): vscode.InlineCompletionItem => {
const originalPrefix = autocompletion.prefix
const generatedMiddle = autocompletion.result
@ -44,18 +75,13 @@ const toInlineCompletion = ({ prefix, autocompletion }: { prefix: string, autoco
console.log('generatedMiddle ', generatedMiddle)
console.log('trimmedOriginalPrefix ', trimmedOriginalPrefix)
console.log('trimmedCurrentPrefix ', trimmedCurrentPrefix)
console.log('lastMatchupIndex ', lastMatchupIndex)
console.log('index: ', lastMatchupIndex)
if (lastMatchupIndex < 0) {
return new vscode.InlineCompletionItem('')
}
// example:
// originalPrefix = abcd
// generatedMiddle = efgh
// originalSuffix = ijkl
// the user has typed "ef" so prefix = abcdef
// we want to return the rest of the generatedMiddle, which is "gh"
const completionStr = generatedMiddle.substring(lastMatchupIndex)
console.log('completionStr: ', completionStr)
return new vscode.InlineCompletionItem(completionStr)
@ -131,8 +157,6 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider
console.log('AAA1')
const inlineCompletion = toInlineCompletion({ autocompletion: cachedAutocompletion, prefix, })
return [inlineCompletion]
} else if (cachedAutocompletion.status === 'pending') {
@ -189,30 +213,12 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider
result: '',
}
let messages: LLMMessage[] = []
switch (voidConfig.default.whichApi) {
case 'ollama':
messages = [
{ role: 'user', content: `[SUFFIX]${suffix}[PREFIX]${prefix} Fill in the middle between the prefix and suffix. Return only the middle. [MIDDLE]` }
]
break;
case 'anthropic':
case 'openAI':
messages = [
{ role: 'system', content: 'Fill in the prefix up to the suffix. Return only the result and be very concise.' },
{ role: 'user', content: `[SUFFIX]${suffix}[PREFIX]${prefix}` },
]
break;
default:
throw new Error(`We do not recommend using autocomplete with your selected provider (${voidConfig.default.whichApi}).`);
}
// set parameters of `newAutocompletion` appropriately
newAutocompletion.promise = new Promise((resolve, reject) => {
sendLLMMessage({
messages: messages,
mode: 'fim',
fimInfo: { prefix, suffix },
onText: async (tokenStr, completionStr) => {
// TODO filter out bad responses here
newAutocompletion.result = completionStr
@ -226,9 +232,10 @@ export class AutocompleteProvider implements vscode.InlineCompletionItemProvider
// newAutocompletion.abortRef = { current: () => { } }
newAutocompletion.status = 'finished'
// newAutocompletion.promise = undefined
newAutocompletion.result = finalMessage
newAutocompletion.result = postprocessResult(extractCodeFromResult(finalMessage))
resolve(newAutocompletion.result)
resolve(finalMessage)
},
onError: (e) => {
newAutocompletion.endTime = Date.now()

View file

@ -1,8 +1,6 @@
import * as vscode from 'vscode';
import { AbortRef, OnFinalMessage, OnText, sendLLMMessage } from "../common/sendLLMMessage"
import { VoidConfig } from '../webviews/common/contextForConfig';
import { searchDiffChunkInstructions, writeFileWithDiffInstructions } from '../common/systemPrompts';
import { throttle } from 'lodash';
import { readFileContentOfUri } from './extensionLib/readFileContentOfUri';
const applyCtrlK = async ({ fileUri, startLine, endLine, instructions, voidConfig, abortRef }: { fileUri: vscode.Uri, startLine: number, endLine: number, instructions: string, voidConfig: VoidConfig, abortRef: AbortRef }) => {
@ -22,14 +20,13 @@ const applyCtrlK = async ({ fileUri, startLine, endLine, instructions, voidConfi
The user wants to apply the following instructions to the selection:
${instructions}
Instructions:
1. Follow the user's instructions
2. You may ONLY CHANGE the selection, and nothing else in the file
3. Make sure all brackets in the new selection are balanced the same was as in the original selection
4. Be careful not to duplicate or remove variables, comments, or other syntax by mistake
Please rewrite the complete the following code, following the user's instructions.
Please rewrite the complete the following code, following the instructions.
\`\`\`
<PRE>${prefix}</PRE>
<SUF>${suffix}</SUF>