autocomplete type context

This commit is contained in:
Mathew Pareles 2025-01-29 01:19:05 -08:00
parent 528b8a6b9b
commit 41057c20fc
2 changed files with 418 additions and 83 deletions

View file

@ -54,10 +54,10 @@ export const sendOllamaFIM: _InternalOllamaFIMMessageFnType = ({ messages, onTex
suffix: messages.suffix,
options: {
stop: messages.stopTokens,
num_predict: 300 // max tokens
},
raw: true,
stream: true,
// options: { num_predict: parseMaxTokensStr(thisConfig.maxTokens) } // this is max_tokens
})
.then(async stream => {
_setAborter(() => stream.abort())

View file

@ -8,7 +8,7 @@ import { ILanguageFeaturesService } from '../../../../editor/common/services/lan
import { createDecorator } from '../../../../platform/instantiation/common/instantiation.js';
import { ITextModel } from '../../../../editor/common/model.js';
import { Position } from '../../../../editor/common/core/position.js';
import { InlineCompletion, InlineCompletionContext, LocationLink } from '../../../../editor/common/languages.js';
import { DocumentSymbol, InlineCompletion, InlineCompletionContext, Location, } from '../../../../editor/common/languages.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { Range } from '../../../../editor/common/core/range.js';
import { ILLMMessageService } from '../../../../platform/void/common/llmMessageService.js';
@ -155,6 +155,7 @@ type Autocompletion = {
llmPromise: Promise<string> | undefined,
insertText: string,
requestId: string | null,
_newlineCount: number,
}
const DEBOUNCE_TIME = 500
@ -163,7 +164,7 @@ const MAX_CACHE_SIZE = 20
const MAX_PENDING_REQUESTS = 2
// postprocesses the result
const joinSpaces = (result: string) => {
const processStartAndEndSpaces = (result: string) => {
// trim all whitespace except for a single leading/trailing space
// return result.trim()
@ -196,22 +197,26 @@ const removeLeftTabsAndTrimEnds = (s: string): string => {
const removeAllWhitespace = (str: string): string => str.replace(/\s+/g, '');
function isSubsequence({ of, subsequence }: { of: string, subsequence: string }): boolean {
if (subsequence.length === 0) return true;
if (of.length === 0) return false;
function getIsSubsequence({ of, subsequence }: { of: string, subsequence: string }): [boolean, string] {
if (subsequence.length === 0) return [true, ''];
if (of.length === 0) return [false, ''];
let subsequenceIndex = 0;
let lastMatchChar = '';
for (let i = 0; i < of.length; i++) {
if (of[i] === subsequence[subsequenceIndex]) {
lastMatchChar = of[i];
subsequenceIndex++;
}
if (subsequenceIndex === subsequence.length) {
return true;
return [true, lastMatchChar];
}
}
return false;
return [false, lastMatchChar];
}
@ -251,7 +256,6 @@ function getStringUpToUnbalancedClosingParenthesis(s: string, prefix: string): s
}
// further trim the autocompletion
const postprocessAutocompletion = ({ autocompletionMatchup, autocompletion, prefixAndSuffix }: { autocompletionMatchup: AutocompletionMatchupBounds, autocompletion: Autocompletion, prefixAndSuffix: PrefixAndSuffixInfo }) => {
@ -357,15 +361,24 @@ const toInlineCompletions = ({ autocompletionMatchup, autocompletion, prefixAndS
// if we redid the suffix, replace the suffix
if (autocompletion.type === 'single-line-redo-suffix') {
if (isSubsequence({ // check that the old text contains the same brackets + symbols as the new text
subsequence: removeAllWhitespace(prefixAndSuffix.suffixToTheRightOfCursor), // old suffix
of: removeAllWhitespace(autocompletion.insertText), // new suffix (note that this should not be `trimmedInsertText`)
})) {
const oldSuffix = prefixAndSuffix.suffixToTheRightOfCursor
const newSuffix = autocompletion.insertText
const [isSubsequence, lastMatchingChar] = getIsSubsequence({ // check that the old text contains the same brackets + symbols as the new text
subsequence: removeAllWhitespace(oldSuffix), // old suffix
of: removeAllWhitespace(newSuffix), // new suffix
})
if (isSubsequence) {
rangeToReplace = new Range(position.lineNumber, position.column, position.lineNumber, Number.MAX_SAFE_INTEGER)
}
else {
// TODO redo the autocompletion
trimmedInsertText = '' // for now set the mismatched text to ''
const lastMatchupIdx = trimmedInsertText.lastIndexOf(lastMatchingChar)
trimmedInsertText = trimmedInsertText.slice(0, lastMatchupIdx + 1)
const numCharsToReplace = oldSuffix.lastIndexOf(lastMatchingChar) + 1
rangeToReplace = new Range(position.lineNumber, position.column, position.lineNumber, position.column + numCharsToReplace)
console.log('show____', trimmedInsertText, rangeToReplace)
}
}
@ -733,12 +746,9 @@ export class AutocompleteService extends Disposable implements IAutocompleteServ
// gather relevant context from the code around the user's selection and definitions
const relevantContext = await this._gatherRelevantContextForPosition(
model,
position,
3, //recursion depth
1 // number of lines to view in each recursion
);
const relevantContext = await this._gatherRelevantContextForPosition(model, position);
console.log('@@---------------------\n' + relevantContext)
const { shouldGenerate, predictionType, llmPrefix, llmSuffix, stopTokens } = getCompletionOptions(prefixAndSuffix, relevantContext, justAcceptedAutocompletion)
@ -766,6 +776,7 @@ export class AutocompleteService extends Disposable implements IAutocompleteServ
llmPromise: undefined,
insertText: '',
requestId: null,
_newlineCount: 0,
}
console.log('BB')
@ -782,28 +793,34 @@ export class AutocompleteService extends Disposable implements IAutocompleteServ
stopTokens: stopTokens,
},
logging: { loggingName: 'Autocomplete' },
onText: async ({ fullText }) => {
onText: async ({ fullText, newText }) => {
newAutocompletion.insertText = fullText
// if generation doesn't match the prefix for the first few tokens generated, reject it
// count newlines in newText
const numNewlines = newText.match(/\n|\r\n/g)?.length || 0
newAutocompletion._newlineCount += numNewlines
// if too many newlines, resolve up to last newline
if (newAutocompletion._newlineCount > 10) {
const lastNewlinePos = fullText.lastIndexOf('\n')
newAutocompletion.insertText = fullText.substring(0, lastNewlinePos)
resolve(newAutocompletion.insertText)
return
}
// if (!getAutocompletionMatchup({ prefix: this._lastPrefix, autocompletion: newAutocompletion })) {
// reject('LLM response did not match user\'s text.')
// reject('LLM response did not match user\'s text.')
// }
},
onFinalMessage: ({ fullText }) => {
console.log('____res: ', JSON.stringify(newAutocompletion.insertText))
// newAutocompletion.prefix = prefix
// newAutocompletion.suffix = suffix
// newAutocompletion.startTime = Date.now()
newAutocompletion.endTime = Date.now()
// newAutocompletion.abortRef = { current: () => { } }
newAutocompletion.status = 'finished'
// newAutocompletion.promise = undefined
const [text, _] = extractCodeFromRegular({ text: fullText, recentlyAddedTextLen: 0 })
newAutocompletion.insertText = joinSpaces(text)
newAutocompletion.insertText = processStartAndEndSpaces(text)
// handle special case for predicting starting on the next line, add a newline character
if (newAutocompletion.type === 'multi-line-start-on-next-line') {
@ -853,84 +870,400 @@ export class AutocompleteService extends Disposable implements IAutocompleteServ
}
// helper method to gather ~N lines above and below the user's current line,
// and recursively gather lines around any symbol definitions encountered.
// TODO! Given a user's cursor position, get relevant context.
// algorithm pseudocode:
// 1. get all relevant symbols (functions, variables, and types)
// 1a. get all symbols that are `numNearbyLines` lines above and below the current position
// eg. if the context is this:
// ```
// ...
// const addVectors = (a: Vector, b: Vector) => {
//
// ... 100+ LINES OF CODE
// return addVectorsElementWise(a,b, Math.min(a.length, b.length) as NumberType) [[CURSOR]]
// }
// ...
// ```
// then these are all of the symbols it should consider that are above and below the position: ['addVectorsElementWise', 'Math.min', 'a.length', 'b.length', 'NumberType']
// 1b. look at where the parent function is defined and get its nearby symbols `numParentLines`
// ex.
// ```
// ...
// const addVectors = (a: Vector, b: Vector) => { [[THIS IS THE PARENT FUNCTION]]
// ... 100+ LINES OF CODE
// return addVectorsElementWise(a ,b, Math.min(a.length, b.length)) [[CURSOR IS HERE]]
// }
// ...
// ```
// the symbols of the parent function are ['const', 'addVectors', 'a', 'Vector', 'b', 'Vector']
// 2. Cmd+Click on each symbol in step 1. (view instances and definitions)
// check that you don't visit the same place twice
// if this location is new, get `` lines above and below this new location and save that string to an array
// 3. for each of the new positions found in step 2., use step 1 to find all their symbols again. This is the recursive step.
// use `maxRecursionDepth` to prevent slowness
// set `numNearbyLines` and `numParentLines` to 2 after the first step to increase performance
// 4. when finished, return snippets.join('\n----------------\n')
private _docSymbolsCache: {
[docUri: string]: {
version: number;
symbols: DocumentSymbol[];
};
} = Object.create(null);
// For each file, store per-symbol lookups we've done.
// e.g. _symbolLookupCache[docUri][fileVersion]["root"] => Location[] results
private _symbolLookupCache: {
[docUri: string]: {
[version: number]: {
[symbolName: string]: Location[];
};
};
} = Object.create(null);
private async _gatherRelevantContextForPosition(
model: ITextModel,
position: Position,
recursionDepth: number,
linesAround: number
maxRecursionDepth: number = 3,
numNearbyLines: number = 5,
numParentLines: number = 5,
numSaveLines: number = 10
): Promise<string> {
// We'll do a BFS-like approach: for each position or definition, gather lines around it,
// then attempt to find the definition of any symbols in that range, up to 'recursionDepth' times.
/****************************************************************************
* A. Quick Helpers & caches
****************************************************************************/
type EditorLocation = import('vs/editor/common/languages').Location;
// A set of "key" strings to avoid repeating the same location or line chunk
const visitedRanges = new Set<string>();
const collectedSnippets: string[] = [];
const docUri = model.uri.toString();
const fileVersion = model.getAlternativeVersionId();
// If you prefer, do a text-based hash or use model.getVersionId() instead.
// A queue of tasks, each being a tuple of: (model, position, depth)
const tasks: Array<{ model: ITextModel, position: Position, depth: number }> = [];
tasks.push({ model, position, depth: recursionDepth });
// 1) Ensure docSymbols cache
let docSymCache = this._docSymbolsCache[docUri];
if (!docSymCache || docSymCache.version !== fileVersion) {
docSymCache = {
version: fileVersion,
symbols: await this._getDocumentSymbolsOnce(model) // see helper below
};
this._docSymbolsCache[docUri] = docSymCache;
}
const allDocumentSymbols = docSymCache.symbols;
const getSnippetAroundLine = (model: ITextModel, lineNumber: number, linesAround: number): string => {
const startLine = Math.max(1, lineNumber - linesAround);
const endLine = Math.min(model.getLineCount(), lineNumber + linesAround);
// 2) Ensure symbol lookup cache
if (!this._symbolLookupCache[docUri]) {
this._symbolLookupCache[docUri] = {};
}
if (!this._symbolLookupCache[docUri][fileVersion]) {
this._symbolLookupCache[docUri][fileVersion] = {};
}
const symbolLookupForFile = this._symbolLookupCache[docUri][fileVersion];
// Basic numeric clamps
const clampLine = (line: number): number => {
const maxLine = model.getLineCount();
return Math.max(1, Math.min(line, maxLine));
};
// Return a snippet of lines [start..end] in the document
const snippetForRange = (startLine: number, endLine: number): string => {
const lines: string[] = [];
for (let i = startLine; i <= endLine; i++) {
lines.push(model.getLineContent(i));
for (let ln = startLine; ln <= endLine; ln++) {
lines.push(model.getLineContent(ln));
}
return lines.join('\n');
};
while (tasks.length > 0) {
const { model: currentModel, position: currentPos, depth } = tasks.shift()!;
if (depth < 0) {
continue;
/****************************************************************************
* B. Interval-based BFS to gather code blocks without duplication
****************************************************************************/
interface Interval { start: number; end: number; }
function addInterval(intervals: Interval[], start: number, end: number) {
// Merge new [start..end] with existing intervals if they overlap or touch
for (let i = 0; i < intervals.length; i++) {
const iv = intervals[i];
if (!(end < iv.start - 1 || start > iv.end + 1)) {
// Overlaps (or touches); merge
const mergedStart = Math.min(iv.start, start);
const mergedEnd = Math.max(iv.end, end);
intervals.splice(i, 1); // remove old
addInterval(intervals, mergedStart, mergedEnd); // re-run
return;
}
}
intervals.push({ start, end });
}
// Gather snippet around the current line
const snippet = getSnippetAroundLine(currentModel, currentPos.lineNumber, linesAround);
const snippetKey = `${currentModel.uri.toString()}:${currentPos.lineNumber}`;
if (!visitedRanges.has(snippetKey)) {
visitedRanges.add(snippetKey);
collectedSnippets.push(`-- Snippet around line ${currentPos.lineNumber} --\n${snippet}\n`);
}
function intervalsToString(intervals: Interval[]): string {
intervals.sort((a, b) => a.start - b.start);
return intervals
.map(iv => snippetForRange(iv.start, iv.end))
.join('\n------------------------------\n');
}
// Attempt to gather definitions for the symbol at this position
// We just pick all definition providers and see if any has a definition
const providers = this._langFeatureService.definitionProvider.ordered(currentModel);
for (const provider of providers) {
try {
const definitions = await provider.provideDefinition(currentModel, currentPos, CancellationToken.None);
if (!definitions) continue;
const intervals: Interval[] = [];
const visitedRanges = new Set<string>();
// definitions can be a single LocationLink or an array
const defArray: LocationLink[] = Array.isArray(definitions) ? definitions : [definitions];
for (const def of defArray) {
if (!def.uri) continue;
if (typeof def.range === 'undefined') continue;
const definitionModel = this._modelService.getModel(def.uri);
if (!definitionModel) continue;
function markVisited(s: number, e: number) { visitedRanges.add(`${s}-${e}`); }
function isVisited(s: number, e: number) { return visitedRanges.has(`${s}-${e}`); }
// We'll queue up a new task for that definition range
const defPos = new Position(def.range.startLineNumber, def.range.startColumn);
const defKey = `${def.uri.toString()}:${defPos.lineNumber}`;
if (!visitedRanges.has(defKey)) {
tasks.push({ model: definitionModel, position: defPos, depth: depth - 1 });
/****************************************************************************
* C. Compute initial intervals (cursor region, parent symbol region)
****************************************************************************/
const lineNumber = position.lineNumber;
const localStart = clampLine(lineNumber - numNearbyLines);
const localEnd = clampLine(lineNumber + numNearbyLines);
addInterval(intervals, clampLine(localStart - numSaveLines), clampLine(localEnd + numSaveLines));
markVisited(localStart, localEnd);
// get parent symbol, add interval for it
const parent = this._findEnclosingSymbol(allDocumentSymbols, lineNumber);
if (parent) {
const pStart = clampLine(parent.range.startLineNumber - numParentLines);
const pEnd = clampLine(parent.range.endLineNumber + numParentLines);
addInterval(intervals, pStart, pEnd);
markVisited(pStart, pEnd);
}
/****************************************************************************
* D. BFS data structures
****************************************************************************/
interface QItem { start: number; end: number; depth: number; }
const queue: QItem[] = [];
queue.push({ start: localStart, end: localEnd, depth: 1 });
if (parent) {
const pStart = clampLine(parent.range.startLineNumber - numParentLines);
const pEnd = clampLine(parent.range.endLineNumber + numParentLines);
queue.push({ start: pStart, end: pEnd, depth: 1 });
}
// We'll keep a set of symbols we've done "references + definitions" for:
const visitedSymbolNames = new Set<string>();
// Providers
const definitionProviders = this._langFeatureService.definitionProvider.ordered(model);
const referenceProviders = this._langFeatureService.referenceProvider.ordered(model);
/****************************************************************************
* E. BFS Loop
****************************************************************************/
while (queue.length) {
const { start, end, depth } = queue.shift()!;
if (depth >= maxRecursionDepth) continue;
// Step 1: Gather all symbols in [start..end]
const regionSyms = this._gatherSymbolsInLineRange(allDocumentSymbols, start, end);
// For each symbol, do references/defs once per symbol name
for (const sym of regionSyms) {
// If we already resolved that symbolName, skip
const symName = sym.name || '';
if (!symName) continue;
if (visitedSymbolNames.has(symName)) continue;
visitedSymbolNames.add(symName);
// If symbol was cached before, skip re-resolving references
if (symbolLookupForFile[symName]) {
// We already have references/definitions => merge them into intervals
const existingLocs = symbolLookupForFile[symName];
for (const loc of existingLocs) {
const rng = loc.range;
const locStart = clampLine(rng.startLineNumber - numSaveLines);
const locEnd = clampLine(rng.endLineNumber + numSaveLines);
if (!isVisited(locStart, locEnd)) {
markVisited(locStart, locEnd);
addInterval(intervals, locStart, locEnd);
queue.push({ start: locStart, end: locEnd, depth: depth + 1 });
}
}
continue;
}
// Not cached => actually ask definitionProviders / referenceProviders
const symPos = this._symbolPosition(sym); // see helper below
let foundLocs: EditorLocation[] = [];
for (const dp of definitionProviders) {
try {
const defs = await dp.provideDefinition(model, symPos, CancellationToken.None);
if (defs) foundLocs.push(...(Array.isArray(defs) ? defs : [defs]));
} catch {/* ignore */ }
}
for (const rp of referenceProviders) {
try {
const refs = await rp.provideReferences(
model, symPos, { includeDeclaration: true }, CancellationToken.None
);
if (refs) foundLocs.push(...refs);
} catch {/* ignore */ }
}
// Filter same-file only
foundLocs = foundLocs.filter(loc => loc.uri.toString() === docUri);
// Cache them
symbolLookupForFile[symName] = foundLocs;
// Enqueue each discovered reference/definition
for (const loc of foundLocs) {
const rng = loc.range;
const locStart = clampLine(rng.startLineNumber - numSaveLines);
const locEnd = clampLine(rng.endLineNumber + numSaveLines);
if (!isVisited(locStart, locEnd)) {
markVisited(locStart, locEnd);
addInterval(intervals, locStart, locEnd);
queue.push({ start: locStart, end: locEnd, depth: depth + 1 });
}
}
}
// Step 2: Also do naive token-scan for lines in [start..end],
// so e.g. 'root()' calls get recognized if not in docSymbols.
// We can do basically the same "cache symbol name" logic, if you want:
for (let ln = start; ln <= end; ln++) {
const text = model.getLineContent(ln);
const tokens = text.match(/[a-zA-Z_][a-zA-Z0-9_]*/g) || [];
for (const token of tokens) {
if (visitedSymbolNames.has(token)) continue;
visitedSymbolNames.add(token);
// If cached, merge intervals from cache
if (symbolLookupForFile[token]) {
for (const loc of symbolLookupForFile[token]) {
const rng = loc.range;
const locStart = clampLine(rng.startLineNumber - numSaveLines);
const locEnd = clampLine(rng.endLineNumber + numSaveLines);
if (!isVisited(locStart, locEnd)) {
markVisited(locStart, locEnd);
addInterval(intervals, locStart, locEnd);
queue.push({ start: locStart, end: locEnd, depth: depth + 1 });
}
}
continue;
}
// Actually compute definitions/references
const colIdx = text.indexOf(token);
if (colIdx < 0) continue; // should not happen, but just in case
const tokenPos = new Position(ln, colIdx + 1);
let foundLocs: EditorLocation[] = [];
for (const dp of definitionProviders) {
try {
const defs = await dp.provideDefinition(model, tokenPos, CancellationToken.None);
if (defs) foundLocs.push(...(Array.isArray(defs) ? defs : [defs]));
} catch {/* ignore */ }
}
for (const rp of referenceProviders) {
try {
const refs = await rp.provideReferences(
model, tokenPos, { includeDeclaration: true }, CancellationToken.None
);
if (refs) foundLocs.push(...refs);
} catch {/* ignore */ }
}
foundLocs = foundLocs.filter(loc => loc.uri.toString() === docUri);
// Cache them
symbolLookupForFile[token] = foundLocs;
// Add intervals
for (const loc of foundLocs) {
const rng = loc.range;
const locStart = clampLine(rng.startLineNumber - numSaveLines);
const locEnd = clampLine(rng.endLineNumber + numSaveLines);
if (!isVisited(locStart, locEnd)) {
markVisited(locStart, locEnd);
addInterval(intervals, locStart, locEnd);
queue.push({ start: locStart, end: locEnd, depth: depth + 1 });
}
}
} catch (err) {
// If a provider fails, ignore
}
}
}
// Return the joined context
return collectedSnippets.join('\n');
/****************************************************************************
* F. Finally, merge intervals and produce final snippet
****************************************************************************/
return intervalsToString(intervals);
}
/******************************************************************************
* Additional Helpers
******************************************************************************/
private async _getDocumentSymbolsOnce(model: ITextModel): Promise<DocumentSymbol[]> {
const providers = this._langFeatureService.documentSymbolProvider.ordered(model);
let result: DocumentSymbol[] = [];
for (const p of providers) {
try {
const syms = await p.provideDocumentSymbols(model, CancellationToken.None);
if (syms) {
result.push(...syms);
}
} catch {/* ignore */ }
}
return result;
}
private _findEnclosingSymbol(symbols: DocumentSymbol[], line: number): DocumentSymbol | undefined {
for (const s of symbols) {
if (s.range.startLineNumber <= line && s.range.endLineNumber >= line) {
// Recurse deeper
const child = this._findEnclosingSymbol(s.children || [], line);
return child || s;
}
}
return undefined;
}
private _symbolPosition(ds: DocumentSymbol): Position {
return new Position(ds.selectionRange.startLineNumber, ds.selectionRange.startColumn);
}
private _gatherSymbolsInLineRange(
symbols: DocumentSymbol[],
startLine: number,
endLine: number
): DocumentSymbol[] {
const out: DocumentSymbol[] = [];
for (const ds of symbols) {
if (ds.range.endLineNumber >= startLine && ds.range.startLineNumber <= endLine) {
out.push(ds);
}
if (ds.children?.length) {
out.push(...this._gatherSymbolsInLineRange(ds.children, startLine, endLine));
}
}
return out;
}
constructor(
@ILanguageFeaturesService private _langFeatureService: ILanguageFeaturesService,
@ILLMMessageService private readonly _llmMessageService: ILLMMessageService,
@ -966,8 +1299,10 @@ export class AutocompleteService extends Disposable implements IAutocompleteServ
// go through cached items and remove matching ones
// autocompletion.prefix + autocompletion.insertedText ~== insertedText
this._autocompletionsOfDocument[docUriStr].items.forEach((autocompletion: Autocompletion) => {
// const matchup = getAutocompletionMatchup({ prefix, autocompletion })
// we can do this more efficiently, I just didn't want to deal with all of the edge cases
const matchup = removeAllWhitespace(prefix) === removeAllWhitespace(autocompletion.prefix + autocompletion.insertText)
if (matchup) {
console.log('ACCEPT', autocompletion.id)
this._lastCompletionAccept = Date.now()