From 7decc8e14652455390db4cfb2d39e951825e17ea Mon Sep 17 00:00:00 2001 From: Andrew Pareles Date: Mon, 31 Mar 2025 20:35:40 -0700 Subject: [PATCH] add checkpoints --- .../contrib/void/browser/chatThreadService.ts | 540 ++++++++++-------- .../react/src/sidebar-tsx/SidebarChat.tsx | 4 +- .../void/common/chatThreadServiceTypes.ts | 17 +- .../contrib/void/common/voidModelService.ts | 10 +- 4 files changed, 311 insertions(+), 260 deletions(-) diff --git a/src/vs/workbench/contrib/void/browser/chatThreadService.ts b/src/vs/workbench/contrib/void/browser/chatThreadService.ts index dddd7d18..5553b32d 100644 --- a/src/vs/workbench/contrib/void/browser/chatThreadService.ts +++ b/src/vs/workbench/contrib/void/browser/chatThreadService.ts @@ -21,7 +21,7 @@ import { ToolName, ToolCallParams, ToolResultType, toolNamesThatRequireApproval, import { IToolsService } from './toolsService.js'; import { CancellationToken } from '../../../../base/common/cancellation.js'; import { ILanguageFeaturesService } from '../../../../editor/common/services/languageFeatures.js'; -import { ChatMessage, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js'; +import { ChatMessage, CheckpointEntry, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js'; import { Position } from '../../../../editor/common/core/position.js'; import { ITerminalToolService } from './terminalToolService.js'; import { IMetricsService } from '../common/metricsService.js'; @@ -29,16 +29,35 @@ import { shorten } from '../../../../base/common/labels.js'; import { IVoidModelService } from '../common/voidModelService.js'; import { IEditorService } from '../../../services/editor/common/editorService.js'; import { ICodeEditorService } from '../../../../editor/browser/services/codeEditorService.js'; +import { findLastIdx } from '../../../../base/common/arraysFind.js'; -const findLastIndex = (arr: T[], condition: (t: T) => boolean): number => { - for (let i = arr.length - 1; i >= 0; i--) { - if (condition(arr[i])) { - return i; - } - } - return -1; -} +type LLMCheckpoint = CheckpointEntry & { type: 'after_tool_edits' } +type UserCheckpoint = CheckpointEntry & { type: 'after_user_edits' } +/* +Checkpoints: +pivots: user | tool (edit) +if there are repeated pivots, a checkpoint goes directly after the last one +checkpoint_modifications always go directly after a checkpoint + +user +-- checkpoint -------- +assistant +tool (edit) + -------- checkpoint - starts here <-- know exact change (file A after) +assistant | +tool (edit) v +-- checkpoint -------- +assistant +tool (not edit) +assistant +user +-- checkpoint -------- user checkpoint (JIT) - compute change from all files to here when need to +-- checkpoint_modifications --------- - these always come DIRECLY after a checkpoint, and reflect the user's modifications on this one checkpoint only. + (only counts when reverting to/from this exact checkpoint, not past it). + Added when user jumps to another checkpoint but made changes here. + +*/ const toLLMChatMessages = (chatMessages: ChatMessage[]): LLMChatMessage[] => { @@ -75,11 +94,15 @@ type ThreadType = { id: string; // store the id here too createdAt: string; // ISO string lastModified: string; // ISO string + messages: ChatMessage[]; - currentHistoryIdx: number | null; // index in messages, ALWAYS points to a LLMHistoryEntry or UserHistoryEntry, or -1 if no changes. current code is inclusive of the current index's change + firstStrOfURI: { [fsPath: string]: string | undefined }; // part of checkpointing + // this doesn't need to go in a state object, but feels right state: { + latestCheckpointIdx: number | null; // the latest checkpoint we're standing at or null + stagingSelections: StagingSelectionItem[]; focusedMessageIdx: number | undefined; // index of the user message that is being edited (undefined if none) @@ -96,6 +119,7 @@ type ChatThreads = { } export const defaultThreadState: ThreadType['state'] = { + latestCheckpointIdx: null, stagingSelections: [], focusedMessageIdx: undefined, linksOfMessageIdx: {}, @@ -130,7 +154,7 @@ const newThreadObject = () => { lastModified: now, messages: [], state: defaultThreadState, - currentHistoryIdx: null, + firstStrOfURI: {}, } satisfies ThreadType } @@ -141,6 +165,8 @@ const newThreadObject = () => { export const THREAD_STORAGE_KEY = 'void.chatThreadStorageI' + + export interface IChatThreadService { readonly _serviceBrand: undefined; @@ -184,8 +210,8 @@ export interface IChatThreadService { addUserMessageAndStreamResponse({ userMessage, threadId }: { userMessage: string, threadId: string }): Promise; // approve/reject - approveTool(threadId: string): void; - rejectTool(threadId: string): void; + approveLatestToolRequest(threadId: string): void; + rejectLatestToolRequest(threadId: string): void; } export const IChatThreadService = createDecorator('voidChatThreadService'); @@ -311,225 +337,6 @@ class ChatThreadService extends Disposable implements IChatThreadService { } const threads = this._convertThreadDataFromStorage(threadsStr); - // threads['abc'] = { - // id: 'abc', - // createdAt: new Date().toISOString(), - // lastModified: new Date().toISOString(), - // messages: [ - // { - // role: 'tool', - // name: 'pathname_search', - // id: 'tool-1', - // paramsStr: '{"query": "hello", "pageNumber": 0}', - // content: '/users/andrew/void/Desktop/etc/abc.txt', - // result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [URI.file('/Users/username/Downloads/helloooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo.txt'), URI.file('/Users/username/Downloads/hello1.txt'), URI.file('/Users/username/Downloads/hello2.txt'), URI.file('/Users/username/Downloads/hello3.txt'), URI.file('/Users/username/hello.txt')], hasNextPage: true } }, - // } satisfies ToolMessage<'pathname_search'>, - // { - // role: 'tool', - // name: 'pathname_search', - // id: 'tool-1', - // paramsStr: '{"query": "hello", "pageNumber": 0}', - // content: '/users/andrew/void/Desktop/etc/abc.txt', - // result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [], hasNextPage: false } }, - // } satisfies ToolMessage<'pathname_search'>, - - // // { - // // role: 'tool_request', - // // name: 'pathname_search', - // // params: { queryStr: 'hello', pageNumber: 0 }, - // // paramsStr: '{"query": "hello", "pageNumber": 0}', - // // id: 'request-1', - // // } satisfies ToolRequestApproval<'pathname_search'>, - - // { - // role: 'tool', - // name: 'list_dir', - // id: 'tool-2', - // paramsStr: '{"uri": "/Users/username/Documents"}', - // content: 'Directory listing of /Users/username/Documents', - // result: { - // type: 'success', - // params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 1, }, - // value: { - // children: [ - // { uri: URI.file('/Users/username/Documents/file1.txt'), name: 'file1.txt', isDirectory: false, isSymbolicLink: false }, - // { uri: URI.file('/Users/username/Documents/folder1'), name: 'folder1', isDirectory: true, isSymbolicLink: false } - // ], - // hasNextPage: true, - // hasPrevPage: true, - // itemsRemaining: 5, - // } - // }, - // } satisfies ToolMessage<'list_dir'>, - - // // { - // // role: 'tool_request', - // // name: 'list_dir', - // // params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 0 }, - // // paramsStr: '{"uri": "/Users/username/Documents"}', - // // id: 'request-2', - // // } satisfies ToolRequestApproval<'list_dir'>, - - // { - // role: 'tool', - // name: 'read_file', - // id: 'tool-3', - // paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}', - // content: 'Content of file1.txt\nThis is a sample file.\nHello world!', - // result: { - // type: 'success', - // params: { uri: URI.file('/src/vs/workbench/hi'), pageNumber: 0 }, - // value: { fileContents: 'Content of file1.txt\nThis is a sample file.\nHello world!', hasNextPage: false } - // }, - // } satisfies ToolMessage<'read_file'>, - - // // { - // // role: 'tool_request', - // // name: 'read_file', - // // params: { uri: URI.file('/Users/username/Documents/file1.txt'), pageNumber: 0 }, - // // paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}', - // // id: 'request-3', - // // } satisfies ToolRequestApproval<'read_file'>, - - // { - // role: 'tool', - // name: 'grep_search', - // id: 'tool-4', - // paramsStr: '{"query": "function main"}', - // content: 'Found matches in 3 files', - // result: { - // type: 'success', - // params: { queryStr: 'function main', pageNumber: 0 }, - // value: { - // uris: [ - // URI.file('/Users/username/Project/main.js'), - // URI.file('/Users/username/Project/src/app.js'), - // URI.file('/Users/username/Project/test/test.js') - // ], - // hasNextPage: false - // } - // }, - // } satisfies ToolMessage<'grep_search'>, - - // // { - // // role: 'tool_request', - // // name: 'grep_search', - // // params: { queryStr: 'function main', pageNumber: 0 }, - // // paramsStr: '{"query": "function main"}', - // // id: 'request-4', - // // } satisfies ToolRequestApproval<'grep_search'>, - - // // --- - - // { - // role: 'tool', - // name: 'edit', - // id: 'tool-5', - // paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "Add console.log statement"}', - // content: 'Successfully edited the file at /Users/username/Project/main.js', - // result: { - // type: 'success', - // params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' }, - // value: Promise.resolve() - // }, - // } satisfies ToolMessage<'edit'>, - // { - // role: 'tool_request', - // name: 'edit', - // params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' }, - // paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "I think we should do this:```Add console.log statement\n for i in ...\n\t\tdo:\nabc```"}', - // id: 'request-5', - // } satisfies ToolRequestApproval<'edit'>, - - // { - // role: 'tool', - // name: 'create_uri', - // id: 'tool-6', - // paramsStr: '{"uri": "/Users/username/Project/new-file.js"}', - // content: 'Successfully created file at /Users/username/Project/new-file/', - // result: { - // type: 'success', - // params: { uri: URI.file('Users/andrew/Desktop/void/src/vs/workbench/hi/'), isFolder: true }, - // value: {} - // }, - // } satisfies ToolMessage<'create_uri'>, - // { - // role: 'tool_request', - // name: 'create_uri', - // params: { uri: URI.file('/Users/username/Project/new-file.js'), isFolder: false }, - // paramsStr: '{"uri": "/Users/username/Project/new-file.js"}', - // id: 'request-6', - // } satisfies ToolRequestApproval<'create_uri'>, - - // { - // role: 'tool', - // name: 'delete_uri', - // id: 'tool-7', - // paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}', - // content: 'Successfully deleted file at /Users/username/Project/old-file.js', - // result: { - // type: 'success', - // params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false }, - // value: {} - // }, - // } satisfies ToolMessage<'delete_uri'>, - // { - // role: 'tool_request', - // name: 'delete_uri', - // params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false }, - // paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}', - // id: 'request-7', - // } satisfies ToolRequestApproval<'delete_uri'>, - - // { - // role: 'tool', - // name: 'terminal_command', - // id: 'tool-8', - // paramsStr: '{"command": "npm install", "waitForCompletion": "true"}', - // content: 'Command executed: npm install\nAdded 123 packages in 3.5s', - // result: { - // type: 'success', - // params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true }, - // value: { - // terminalId: '1', - // didCreateTerminal: false, - // result: 'Added 123 packages in 3.5s', - // resolveReason: { type: 'done', exitCode: 0 } - // } - // }, - // } satisfies ToolMessage<'terminal_command'>, - // { - // role: 'tool_request', - // name: 'terminal_command', - // params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true }, - // paramsStr: '{"command": "npm install", "waitForCompletion": "true"}', - // id: 'request-8', - // } satisfies ToolRequestApproval<'terminal_command'>, - - - - // // Examples of error and rejected states - // { - // role: 'tool', - // name: 'pathname_search', - // id: 'tool-error', - // paramsStr: '{"query": "invalid**query"}', - // content: 'Error: Invalid search pattern', - // result: { type: 'error', params: { queryStr: 'invalid**query', pageNumber: 0 }, value: 'Error: Invalid search pattern' }, - // } satisfies ToolMessage<'pathname_search'>, - - // { - // role: 'tool', - // name: 'pathname_search', - // id: 'tool-rejected', - // paramsStr: '{"query": "sensitive-data"}', - // content: 'Tool call was rejected by the user.', - // result: { type: 'rejected', params: { queryStr: 'sensitive-data', pageNumber: 0 } }, - // } satisfies ToolMessage<'pathname_search'>, - // ], - // state: defaultThreadState, - // } - return threads } @@ -631,7 +438,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { } - approveTool(threadId: string) { + approveLatestToolRequest(threadId: string) { const thread = this.state.allThreads[threadId] if (!thread) return // should never happen @@ -639,7 +446,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { const lastMessage = thread.messages[thread.messages.length - 1] if (lastMessage.role !== 'tool_request') return // should never happen - const lastUserMsgIdx = findLastIndex(thread.messages, m => m.role === 'user') + const lastUserMsgIdx = findLastIdx(thread.messages, m => m.role === 'user') const lastUserMessage = thread.messages[lastUserMsgIdx] as ChatMessage & { role: 'user' } if (lastUserMsgIdx === -1 || !lastUserMessage) return // should never happen @@ -651,7 +458,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { this._runChatAgent({ callThisToolFirst, prevSelns, currSelns, threadId, userMessageContent: instructions, ...this._currentModelSelectionProps() }) } - rejectTool(threadId: string) { + rejectLatestToolRequest(threadId: string) { const thread = this.state.allThreads[threadId] if (!thread) return // should never happen @@ -670,7 +477,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { const isRunning = this.streamState[threadId]?.isRunning // reject the tool for the user if (isRunning === 'awaiting_user') { - this.rejectTool(threadId) + this.rejectLatestToolRequest(threadId) } // interrupt the tool else if (isRunning === 'tool') { @@ -741,7 +548,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { const thread = this.state.allThreads[threadId] const latestMessages = thread?.messages ?? [] const messages_ = toLLMChatMessages(latestMessages) - const lastUserMsgIdx = findLastIndex(messages_, m => m.role === 'user') + const lastUserMsgIdx = findLastIdx(messages_, m => m.role === 'user') if (lastUserMsgIdx === -1) return [] // should never happen (or how did they send the message?!) // system message @@ -944,11 +751,226 @@ class ChatThreadService extends Disposable implements IChatThreadService { } - async callWhenJumpBackToIdx(toIdx: number) { - // TODO!!! + // merge any LLM checkpoint before this one (and after a user checkpoint if one exists), and add the checkpoint + // call this right after LLM edits a file + addOrUpdateToolEditCheckpoint({ threadId, uri, }: { threadId: string, uri: URI }) { + const thread = this.state.allThreads[threadId] + if (!thread) return + const { model } = this._voidModelService.getModel(uri) + if (!model) return // should never happen + + const lastUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type === 'after_user_edits') + const prevLLMCheckpointIdx = thread.messages.findIndex((m, i) => i > lastUserCheckpointIdx && m.role === 'checkpoint' && m.type === 'after_tool_edits') + + const afterStr = model.getValue() // afterStr = the value of the file right after the edit + + let prevLLMCheckpoint: LLMCheckpoint | undefined = undefined + if (prevLLMCheckpointIdx !== -1) { + prevLLMCheckpoint = thread.messages[prevLLMCheckpointIdx] as ChatMessage & { role: 'checkpoint', type: 'after_tool_edits' } + this._removeMessageFromThread(threadId, prevLLMCheckpointIdx) + } + const newLLMCheckpoint: LLMCheckpoint = { + role: 'checkpoint', + type: 'after_tool_edits', + afterStrOfURI: { + ...prevLLMCheckpoint?.afterStrOfURI, + [uri.fsPath]: afterStr, + }, + } + console.log('NEW LLM CHECKPOINT', newLLMCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2)) + this._addMessageToThread(threadId, newLLMCheckpoint) } + // user checkpoints are always computed JIT + // we assume there are no messages after the checkpoint we're adding here + // call this right before user sends message + addOrUpdateUserMessageCheckpoint({ threadId, }: { threadId: string, }) { + const thread = this.state.allThreads[threadId] + if (!thread) return + + const newUserCheckpoint: UserCheckpoint = { + role: 'checkpoint', + type: 'after_user_edits', // user backup + afterStrOfURI: {}, + } + + // first get the last user checkpoint + const lastNonUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type !== 'after_user_edits') + + // merge all recent user checkpoints + const latestAfterStrOfURI: { [fsPath: string]: string } = {} // helps merge user edits + for (let k = 0; k <= thread.messages.length; k += 1) { + const message = thread.messages[k] + if (message.role !== 'checkpoint') continue + for (const uri in message.afterStrOfURI) + latestAfterStrOfURI[uri] = message.afterStrOfURI[uri] + + // remove any user messages that come after the last LLM checkpoint (we're merging them into one big user message) + if (k > lastNonUserCheckpointIdx) + this._removeMessageFromThread(threadId, k) + } + + // compute afterStr of all files we detected, and if they're different, add them as a user edit + for (const fsPath in latestAfterStrOfURI) { + const uri = URI.file(fsPath) + const { model } = this._voidModelService.getModel(uri) + if (!model) continue + const oldAfterStr = latestAfterStrOfURI[uri.fsPath] + const currentAfterStr = model.getValue() + if (oldAfterStr === currentAfterStr) continue + // if there was a change, add it as a user edit + newUserCheckpoint.afterStrOfURI = { + ...newUserCheckpoint.afterStrOfURI, + [uri.fsPath]: currentAfterStr + } + } + + + this._addMessageToThread(threadId, newUserCheckpoint) + + // update latest checkpoint idx to the one we just added + const newThread = this.state.allThreads[threadId] + if (!newThread) return // should never happen + const latestCheckpointIdx = newThread.messages.length - 1 + this._setThreadState(threadId, { latestCheckpointIdx }) + + + console.log('NEW USER CHECKPOINT', latestCheckpointIdx, newUserCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2)) + } + + + private _getCheckpointAfter = ({ threadId, messageIdx: afterIdx }: { threadId: string, messageIdx: number }): [CheckpointEntry, number] | undefined => { + const thread = this.state.allThreads[threadId] + if (!thread) return undefined + for (let i = afterIdx; i < thread.messages.length; i++) { + const message = thread.messages[i] + if (message.role === 'checkpoint') { + return [message, i] + } + } + return undefined + } + + private _getAllChangedCheckpointURIs({ threadId, fromIdx, toIdx }: { threadId: string, fromIdx: number, toIdx: number }) { + const thread = this.state.allThreads[threadId] + if (!thread) return null // should never happen + const fsPaths: Set = new Set() + for (let i = fromIdx; i <= toIdx; i += 1) { + const message = thread.messages[i] + if (message.role !== 'checkpoint') continue + for (const fsPath in message.afterStrOfURI) { + fsPaths.add(fsPath) + } + } + return fsPaths + } + + jumpToCheckpointAfterMessageIdx({ threadId, messageIdx }: { threadId: string, messageIdx: number }) { + const thread = this.state.allThreads[threadId] + if (!thread) return + + const c = this._getCheckpointAfter({ threadId, messageIdx }) + if (c === undefined) return // should never happen + + const fromIdx = thread.state.latestCheckpointIdx + if (fromIdx === null) return // should never happen + + // TODO!!! change toIdx if there's a checkpointModification on the To, and add a checkpoint modification on the from + const [_, toIdx_] = c + const toIdx = toIdx_ + if (toIdx === fromIdx) return + + const writeFullFile = ({ fsPath, text }: { fsPath: string, text: string }) => { + const { model } = this._voidModelService.getModelFromFsPath(fsPath) + if (!model) return // should never happen + model.applyEdits([{ + range: { startLineNumber: 1, startColumn: 1, endLineNumber: model.getLineCount(), endColumn: Number.MAX_SAFE_INTEGER }, // whole file + text + }]) + } + + /* +if undoing + +A,B,C are all files. +x means a checkpoint where the file changed. + +A B C D E F G H I +x x x x x x x x x +| | | | | | | | | +x | | | | | | | x +---x-|-|-|-x-|-x-|----- <-- to + x | | | | | x + | | x x | + | | | | +-------x-|---x-x------- <-- from + x + +We need to revert anything that happened between to+1 and from. +**We do this by finding the last x from 0...`to` for each file and applying those contents.** +We only need to do it for files that were edited since `to`, ie files between to+1...from. +*/ + if (toIdx < fromIdx) { + const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, toIdx: toIdx + 1, fromIdx }) + for (const fsPath of checkpointURIs ?? []) { + let found = false + + // apply lowest down content for each uri (or original if not found) + + for (let k = toIdx; k >= 0; k -= 1) { + const message = thread.messages[k] + if (message.role !== 'checkpoint') continue + if (fsPath in message.afterStrOfURI) { + found = true + writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] }) + break + } + } + if (!found) { + const originalStr = thread.firstStrOfURI[fsPath] + if (originalStr === undefined) continue + writeFullFile({ fsPath, text: originalStr }) + } + } + } + + /* +if redoing + +A B C D E F G H I +x x x x x x x x x +| | | | | | | | | +x | | | | | | | x +---x-|-|-|-x-|-x-|----- <-- from + x | | | | | x + | | x x | + | | | | +-------x-|---x-x------- <-- to + x + +We need to apply latest change for anything that happened between from+1 and to. +We only need to do it for files that were edited since `from`, ie files between from+1...to. +*/ + if (toIdx > fromIdx) { + const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, fromIdx: fromIdx + 1, toIdx }) + for (const fsPath of checkpointURIs ?? []) { + // apply lowest down content for each uri + // (do not need to apply original since we're only applying to files that changed) + for (let k = toIdx; k >= fromIdx + 1; k -= 1) { + const message = thread.messages[k] + if (message.role !== 'checkpoint') continue + if (fsPath in message.afterStrOfURI) { + writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] }) + break + } + } + } + } + + this._setThreadState(threadId, { latestCheckpointIdx: toIdx }) + // TODO!!! add/merge a checkpoint modification if relevant + } async addUserMessageAndStreamResponse({ userMessage, _chatSelections, threadId }: { userMessage: string, _chatSelections?: { prevSelns?: StagingSelectionItem[], currSelns?: StagingSelectionItem[], }, threadId: string }) { @@ -970,6 +992,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { const userHistoryElt: ChatMessage = { role: 'user', content: userMessageContent, displayContent: instructions, selections: currSelns, state: defaultMessageState } this._addMessageToThread(threadId, userHistoryElt) + this.addOrUpdateUserMessageCheckpoint({ threadId }) this._runChatAgent({ prevSelns, currSelns, threadId, userMessageContent, ...this._currentModelSelectionProps(), }) } @@ -1255,7 +1278,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { // add the current file as a staging selection const model = this._codeEditorService.getActiveCodeEditor()?.getModel() if (model) { - this._setCurrentThreadState({ + this._setThreadState(this.state.currentThreadId, { ...defaultThreadState, stagingSelections: [{ type: 'File', @@ -1286,25 +1309,48 @@ class ChatThreadService extends Disposable implements IChatThreadService { } - _addMessageToThread(threadId: string, message: ChatMessage) { + private _addMessageToThread(threadId: string, message: ChatMessage) { const { allThreads } = this.state - const oldThread = allThreads[threadId] if (!oldThread) return // should never happen - // update state and store it const newThreads = { ...allThreads, [oldThread.id]: { ...oldThread, lastModified: new Date().toISOString(), - messages: [...oldThread.messages, message], + messages: [ + ...oldThread.messages, + message + ], } } this._storeAllThreads(newThreads) this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it) } + + private _removeMessageFromThread(threadId: string, messageIdx: number) { + const { allThreads } = this.state + const oldThread = allThreads[threadId] + if (!oldThread) return // should never happen + // update state and store it + const newThreads = { + ...allThreads, + [oldThread.id]: { + ...oldThread, + lastModified: new Date().toISOString(), + messages: [ + ...oldThread.messages.slice(0, messageIdx), + ...oldThread.messages.slice(messageIdx + 1, Infinity), + ], + } + } + this._storeAllThreads(newThreads) + this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it) + } + + // sets the currently selected message (must be undefined if no message is selected) setCurrentlyFocusedMessageIdx(messageIdx: number | undefined) { @@ -1354,9 +1400,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { } // set thread.state - private _setCurrentThreadState(state: Partial): void { - - const threadId = this.state.currentThreadId + private _setThreadState(threadId: string, state: Partial): void { const thread = this.state.allThreads[threadId] if (!thread) return @@ -1409,7 +1453,7 @@ class ChatThreadService extends Disposable implements IChatThreadService { return currentThread.state } setCurrentThreadState = (newState: Partial) => { - this._setCurrentThreadState(newState) + this._setThreadState(this.state.currentThreadId, newState) } // gets `staging` and `setStaging` of the currently focused element, given the index of the currently selected message (or undefined if no message is selected) diff --git a/src/vs/workbench/contrib/void/browser/react/src/sidebar-tsx/SidebarChat.tsx b/src/vs/workbench/contrib/void/browser/react/src/sidebar-tsx/SidebarChat.tsx index 00b26d15..8eb3eebb 100644 --- a/src/vs/workbench/contrib/void/browser/react/src/sidebar-tsx/SidebarChat.tsx +++ b/src/vs/workbench/contrib/void/browser/react/src/sidebar-tsx/SidebarChat.tsx @@ -1248,7 +1248,7 @@ const ToolRequestAcceptRejectButtons = () => { const onAccept = useCallback(() => { try { // this doesn't need to be wrapped in try/catch anymore const threadId = chatThreadsService.state.currentThreadId - chatThreadsService.approveTool(threadId) + chatThreadsService.approveLatestToolRequest(threadId) metricsService.capture('Tool Request Accepted', {}) } catch (e) { console.error('Error while approving message in chat:', e) } }, [chatThreadsService, metricsService]) @@ -1256,7 +1256,7 @@ const ToolRequestAcceptRejectButtons = () => { const onReject = useCallback(() => { try { const threadId = chatThreadsService.state.currentThreadId - chatThreadsService.rejectTool(threadId) + chatThreadsService.rejectLatestToolRequest(threadId) } catch (e) { console.error('Error while approving message in chat:', e) } metricsService.capture('Tool Request Rejected', {}) }, [chatThreadsService, metricsService]) diff --git a/src/vs/workbench/contrib/void/common/chatThreadServiceTypes.ts b/src/vs/workbench/contrib/void/common/chatThreadServiceTypes.ts index 7dfe39e5..b4af2b89 100644 --- a/src/vs/workbench/contrib/void/common/chatThreadServiceTypes.ts +++ b/src/vs/workbench/contrib/void/common/chatThreadServiceTypes.ts @@ -26,16 +26,18 @@ export type ToolRequestApproval = { // checkpoints -export type LLMHistoryEntry = { // ALWAYS comes right after a {role:'tool', name:'edit'} message - role: 'LLM_changes'; - afterStrOfURI: { [fsPath: string]: string }; -} -export type UserHistoryEntry = { // ALWAYS comes right before a {role:'user'} message, or if it's the last message (w/o a user message yet) - role: 'user_changes'; +export type CheckpointEntry = { + role: 'checkpoint'; + type: 'after_user_edits' | 'after_tool_edits'; + afterStrOfURI: { [fsPath: string]: string }; +} | { // modifications that only count when undoing/redoing + role: 'checkpoint_modification'; + type: 'user_modifications'; afterStrOfURI: { [fsPath: string]: string }; } + // WARNING: changing this format is a big deal!!!!!! need to migrate old format to new format on users' computers so people don't get errors. export type ChatMessage = | { @@ -56,8 +58,7 @@ export type ChatMessage = } | ToolMessage | ToolRequestApproval - | LLMHistoryEntry // invisible - | UserHistoryEntry // invisible + | CheckpointEntry // one of the square items that indicates a selection in a chat bubble (NOT a file, a Selection of text) diff --git a/src/vs/workbench/contrib/void/common/voidModelService.ts b/src/vs/workbench/contrib/void/common/voidModelService.ts index 8cbf4ac9..a463b50a 100644 --- a/src/vs/workbench/contrib/void/common/voidModelService.ts +++ b/src/vs/workbench/contrib/void/common/voidModelService.ts @@ -14,6 +14,7 @@ export interface IVoidModelService { readonly _serviceBrand: undefined; initializeModel(uri: URI): Promise; getModel(uri: URI): VoidModelType; + getModelFromFsPath(fsPath: string): VoidModelType; getModelSafe(uri: URI): Promise; } @@ -37,8 +38,8 @@ class VoidModelService extends Disposable implements IVoidModelService { this._modelRefOfURI[uri.fsPath] = editorModelRef; }; - getModel = (uri: URI): VoidModelType => { - const editorModelRef = this._modelRefOfURI[uri.fsPath]; + getModelFromFsPath = (fsPath: string): VoidModelType => { + const editorModelRef = this._modelRefOfURI[fsPath]; if (!editorModelRef) { return { model: null, editorModel: null }; } @@ -52,6 +53,11 @@ class VoidModelService extends Disposable implements IVoidModelService { return { model, editorModel: editorModelRef.object }; }; + getModel = (uri: URI) => { + return this.getModelFromFsPath(uri.fsPath) + } + + getModelSafe = async (uri: URI): Promise => { if (!(uri.fsPath in this._modelRefOfURI)) await this.initializeModel(uri); return this.getModel(uri);