add checkpoints

This commit is contained in:
Andrew Pareles 2025-03-31 20:35:40 -07:00
parent 7c0ba71314
commit 7decc8e146
4 changed files with 311 additions and 260 deletions

View file

@ -21,7 +21,7 @@ import { ToolName, ToolCallParams, ToolResultType, toolNamesThatRequireApproval,
import { IToolsService } from './toolsService.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { ILanguageFeaturesService } from '../../../../editor/common/services/languageFeatures.js';
import { ChatMessage, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js';
import { ChatMessage, CheckpointEntry, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js';
import { Position } from '../../../../editor/common/core/position.js';
import { ITerminalToolService } from './terminalToolService.js';
import { IMetricsService } from '../common/metricsService.js';
@ -29,16 +29,35 @@ import { shorten } from '../../../../base/common/labels.js';
import { IVoidModelService } from '../common/voidModelService.js';
import { IEditorService } from '../../../services/editor/common/editorService.js';
import { ICodeEditorService } from '../../../../editor/browser/services/codeEditorService.js';
import { findLastIdx } from '../../../../base/common/arraysFind.js';
const findLastIndex = <T>(arr: T[], condition: (t: T) => boolean): number => {
for (let i = arr.length - 1; i >= 0; i--) {
if (condition(arr[i])) {
return i;
}
}
return -1;
}
type LLMCheckpoint = CheckpointEntry & { type: 'after_tool_edits' }
type UserCheckpoint = CheckpointEntry & { type: 'after_user_edits' }
/*
Checkpoints:
pivots: user | tool (edit)
if there are repeated pivots, a checkpoint goes directly after the last one
checkpoint_modifications always go directly after a checkpoint
user
-- checkpoint --------
assistant
tool (edit)
-------- checkpoint - starts here <-- know exact change (file A after)
assistant |
tool (edit) v
-- checkpoint --------
assistant
tool (not edit)
assistant
user
-- checkpoint -------- user checkpoint (JIT) - compute change from all files to here when need to
-- checkpoint_modifications --------- - these always come DIRECLY after a checkpoint, and reflect the user's modifications on this one checkpoint only.
(only counts when reverting to/from this exact checkpoint, not past it).
Added when user jumps to another checkpoint but made changes here.
*/
const toLLMChatMessages = (chatMessages: ChatMessage[]): LLMChatMessage[] => {
@ -75,11 +94,15 @@ type ThreadType = {
id: string; // store the id here too
createdAt: string; // ISO string
lastModified: string; // ISO string
messages: ChatMessage[];
currentHistoryIdx: number | null; // index in messages, ALWAYS points to a LLMHistoryEntry or UserHistoryEntry, or -1 if no changes. current code is inclusive of the current index's change
firstStrOfURI: { [fsPath: string]: string | undefined }; // part of checkpointing
// this doesn't need to go in a state object, but feels right
state: {
latestCheckpointIdx: number | null; // the latest checkpoint we're standing at or null
stagingSelections: StagingSelectionItem[];
focusedMessageIdx: number | undefined; // index of the user message that is being edited (undefined if none)
@ -96,6 +119,7 @@ type ChatThreads = {
}
export const defaultThreadState: ThreadType['state'] = {
latestCheckpointIdx: null,
stagingSelections: [],
focusedMessageIdx: undefined,
linksOfMessageIdx: {},
@ -130,7 +154,7 @@ const newThreadObject = () => {
lastModified: now,
messages: [],
state: defaultThreadState,
currentHistoryIdx: null,
firstStrOfURI: {},
} satisfies ThreadType
}
@ -141,6 +165,8 @@ const newThreadObject = () => {
export const THREAD_STORAGE_KEY = 'void.chatThreadStorageI'
export interface IChatThreadService {
readonly _serviceBrand: undefined;
@ -184,8 +210,8 @@ export interface IChatThreadService {
addUserMessageAndStreamResponse({ userMessage, threadId }: { userMessage: string, threadId: string }): Promise<void>;
// approve/reject
approveTool(threadId: string): void;
rejectTool(threadId: string): void;
approveLatestToolRequest(threadId: string): void;
rejectLatestToolRequest(threadId: string): void;
}
export const IChatThreadService = createDecorator<IChatThreadService>('voidChatThreadService');
@ -311,225 +337,6 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
const threads = this._convertThreadDataFromStorage(threadsStr);
// threads['abc'] = {
// id: 'abc',
// createdAt: new Date().toISOString(),
// lastModified: new Date().toISOString(),
// messages: [
// {
// role: 'tool',
// name: 'pathname_search',
// id: 'tool-1',
// paramsStr: '{"query": "hello", "pageNumber": 0}',
// content: '/users/andrew/void/Desktop/etc/abc.txt',
// result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [URI.file('/Users/username/Downloads/helloooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo.txt'), URI.file('/Users/username/Downloads/hello1.txt'), URI.file('/Users/username/Downloads/hello2.txt'), URI.file('/Users/username/Downloads/hello3.txt'), URI.file('/Users/username/hello.txt')], hasNextPage: true } },
// } satisfies ToolMessage<'pathname_search'>,
// {
// role: 'tool',
// name: 'pathname_search',
// id: 'tool-1',
// paramsStr: '{"query": "hello", "pageNumber": 0}',
// content: '/users/andrew/void/Desktop/etc/abc.txt',
// result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [], hasNextPage: false } },
// } satisfies ToolMessage<'pathname_search'>,
// // {
// // role: 'tool_request',
// // name: 'pathname_search',
// // params: { queryStr: 'hello', pageNumber: 0 },
// // paramsStr: '{"query": "hello", "pageNumber": 0}',
// // id: 'request-1',
// // } satisfies ToolRequestApproval<'pathname_search'>,
// {
// role: 'tool',
// name: 'list_dir',
// id: 'tool-2',
// paramsStr: '{"uri": "/Users/username/Documents"}',
// content: 'Directory listing of /Users/username/Documents',
// result: {
// type: 'success',
// params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 1, },
// value: {
// children: [
// { uri: URI.file('/Users/username/Documents/file1.txt'), name: 'file1.txt', isDirectory: false, isSymbolicLink: false },
// { uri: URI.file('/Users/username/Documents/folder1'), name: 'folder1', isDirectory: true, isSymbolicLink: false }
// ],
// hasNextPage: true,
// hasPrevPage: true,
// itemsRemaining: 5,
// }
// },
// } satisfies ToolMessage<'list_dir'>,
// // {
// // role: 'tool_request',
// // name: 'list_dir',
// // params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 0 },
// // paramsStr: '{"uri": "/Users/username/Documents"}',
// // id: 'request-2',
// // } satisfies ToolRequestApproval<'list_dir'>,
// {
// role: 'tool',
// name: 'read_file',
// id: 'tool-3',
// paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}',
// content: 'Content of file1.txt\nThis is a sample file.\nHello world!',
// result: {
// type: 'success',
// params: { uri: URI.file('/src/vs/workbench/hi'), pageNumber: 0 },
// value: { fileContents: 'Content of file1.txt\nThis is a sample file.\nHello world!', hasNextPage: false }
// },
// } satisfies ToolMessage<'read_file'>,
// // {
// // role: 'tool_request',
// // name: 'read_file',
// // params: { uri: URI.file('/Users/username/Documents/file1.txt'), pageNumber: 0 },
// // paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}',
// // id: 'request-3',
// // } satisfies ToolRequestApproval<'read_file'>,
// {
// role: 'tool',
// name: 'grep_search',
// id: 'tool-4',
// paramsStr: '{"query": "function main"}',
// content: 'Found matches in 3 files',
// result: {
// type: 'success',
// params: { queryStr: 'function main', pageNumber: 0 },
// value: {
// uris: [
// URI.file('/Users/username/Project/main.js'),
// URI.file('/Users/username/Project/src/app.js'),
// URI.file('/Users/username/Project/test/test.js')
// ],
// hasNextPage: false
// }
// },
// } satisfies ToolMessage<'grep_search'>,
// // {
// // role: 'tool_request',
// // name: 'grep_search',
// // params: { queryStr: 'function main', pageNumber: 0 },
// // paramsStr: '{"query": "function main"}',
// // id: 'request-4',
// // } satisfies ToolRequestApproval<'grep_search'>,
// // ---
// {
// role: 'tool',
// name: 'edit',
// id: 'tool-5',
// paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "Add console.log statement"}',
// content: 'Successfully edited the file at /Users/username/Project/main.js',
// result: {
// type: 'success',
// params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' },
// value: Promise.resolve()
// },
// } satisfies ToolMessage<'edit'>,
// {
// role: 'tool_request',
// name: 'edit',
// params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' },
// paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "I think we should do this:```Add console.log statement\n for i in ...\n\t\tdo:\nabc```"}',
// id: 'request-5',
// } satisfies ToolRequestApproval<'edit'>,
// {
// role: 'tool',
// name: 'create_uri',
// id: 'tool-6',
// paramsStr: '{"uri": "/Users/username/Project/new-file.js"}',
// content: 'Successfully created file at /Users/username/Project/new-file/',
// result: {
// type: 'success',
// params: { uri: URI.file('Users/andrew/Desktop/void/src/vs/workbench/hi/'), isFolder: true },
// value: {}
// },
// } satisfies ToolMessage<'create_uri'>,
// {
// role: 'tool_request',
// name: 'create_uri',
// params: { uri: URI.file('/Users/username/Project/new-file.js'), isFolder: false },
// paramsStr: '{"uri": "/Users/username/Project/new-file.js"}',
// id: 'request-6',
// } satisfies ToolRequestApproval<'create_uri'>,
// {
// role: 'tool',
// name: 'delete_uri',
// id: 'tool-7',
// paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}',
// content: 'Successfully deleted file at /Users/username/Project/old-file.js',
// result: {
// type: 'success',
// params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false },
// value: {}
// },
// } satisfies ToolMessage<'delete_uri'>,
// {
// role: 'tool_request',
// name: 'delete_uri',
// params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false },
// paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}',
// id: 'request-7',
// } satisfies ToolRequestApproval<'delete_uri'>,
// {
// role: 'tool',
// name: 'terminal_command',
// id: 'tool-8',
// paramsStr: '{"command": "npm install", "waitForCompletion": "true"}',
// content: 'Command executed: npm install\nAdded 123 packages in 3.5s',
// result: {
// type: 'success',
// params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true },
// value: {
// terminalId: '1',
// didCreateTerminal: false,
// result: 'Added 123 packages in 3.5s',
// resolveReason: { type: 'done', exitCode: 0 }
// }
// },
// } satisfies ToolMessage<'terminal_command'>,
// {
// role: 'tool_request',
// name: 'terminal_command',
// params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true },
// paramsStr: '{"command": "npm install", "waitForCompletion": "true"}',
// id: 'request-8',
// } satisfies ToolRequestApproval<'terminal_command'>,
// // Examples of error and rejected states
// {
// role: 'tool',
// name: 'pathname_search',
// id: 'tool-error',
// paramsStr: '{"query": "invalid**query"}',
// content: 'Error: Invalid search pattern',
// result: { type: 'error', params: { queryStr: 'invalid**query', pageNumber: 0 }, value: 'Error: Invalid search pattern' },
// } satisfies ToolMessage<'pathname_search'>,
// {
// role: 'tool',
// name: 'pathname_search',
// id: 'tool-rejected',
// paramsStr: '{"query": "sensitive-data"}',
// content: 'Tool call was rejected by the user.',
// result: { type: 'rejected', params: { queryStr: 'sensitive-data', pageNumber: 0 } },
// } satisfies ToolMessage<'pathname_search'>,
// ],
// state: defaultThreadState,
// }
return threads
}
@ -631,7 +438,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
approveTool(threadId: string) {
approveLatestToolRequest(threadId: string) {
const thread = this.state.allThreads[threadId]
if (!thread) return // should never happen
@ -639,7 +446,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
const lastMessage = thread.messages[thread.messages.length - 1]
if (lastMessage.role !== 'tool_request') return // should never happen
const lastUserMsgIdx = findLastIndex(thread.messages, m => m.role === 'user')
const lastUserMsgIdx = findLastIdx(thread.messages, m => m.role === 'user')
const lastUserMessage = thread.messages[lastUserMsgIdx] as ChatMessage & { role: 'user' }
if (lastUserMsgIdx === -1 || !lastUserMessage) return // should never happen
@ -651,7 +458,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
this._runChatAgent({ callThisToolFirst, prevSelns, currSelns, threadId, userMessageContent: instructions, ...this._currentModelSelectionProps() })
}
rejectTool(threadId: string) {
rejectLatestToolRequest(threadId: string) {
const thread = this.state.allThreads[threadId]
if (!thread) return // should never happen
@ -670,7 +477,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
const isRunning = this.streamState[threadId]?.isRunning
// reject the tool for the user
if (isRunning === 'awaiting_user') {
this.rejectTool(threadId)
this.rejectLatestToolRequest(threadId)
}
// interrupt the tool
else if (isRunning === 'tool') {
@ -741,7 +548,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
const thread = this.state.allThreads[threadId]
const latestMessages = thread?.messages ?? []
const messages_ = toLLMChatMessages(latestMessages)
const lastUserMsgIdx = findLastIndex(messages_, m => m.role === 'user')
const lastUserMsgIdx = findLastIdx(messages_, m => m.role === 'user')
if (lastUserMsgIdx === -1) return [] // should never happen (or how did they send the message?!)
// system message
@ -944,11 +751,226 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
async callWhenJumpBackToIdx(toIdx: number) {
// TODO!!!
// merge any LLM checkpoint before this one (and after a user checkpoint if one exists), and add the checkpoint
// call this right after LLM edits a file
addOrUpdateToolEditCheckpoint({ threadId, uri, }: { threadId: string, uri: URI }) {
const thread = this.state.allThreads[threadId]
if (!thread) return
const { model } = this._voidModelService.getModel(uri)
if (!model) return // should never happen
const lastUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type === 'after_user_edits')
const prevLLMCheckpointIdx = thread.messages.findIndex((m, i) => i > lastUserCheckpointIdx && m.role === 'checkpoint' && m.type === 'after_tool_edits')
const afterStr = model.getValue() // afterStr = the value of the file right after the edit
let prevLLMCheckpoint: LLMCheckpoint | undefined = undefined
if (prevLLMCheckpointIdx !== -1) {
prevLLMCheckpoint = thread.messages[prevLLMCheckpointIdx] as ChatMessage & { role: 'checkpoint', type: 'after_tool_edits' }
this._removeMessageFromThread(threadId, prevLLMCheckpointIdx)
}
const newLLMCheckpoint: LLMCheckpoint = {
role: 'checkpoint',
type: 'after_tool_edits',
afterStrOfURI: {
...prevLLMCheckpoint?.afterStrOfURI,
[uri.fsPath]: afterStr,
},
}
console.log('NEW LLM CHECKPOINT', newLLMCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2))
this._addMessageToThread(threadId, newLLMCheckpoint)
}
// user checkpoints are always computed JIT
// we assume there are no messages after the checkpoint we're adding here
// call this right before user sends message
addOrUpdateUserMessageCheckpoint({ threadId, }: { threadId: string, }) {
const thread = this.state.allThreads[threadId]
if (!thread) return
const newUserCheckpoint: UserCheckpoint = {
role: 'checkpoint',
type: 'after_user_edits', // user backup
afterStrOfURI: {},
}
// first get the last user checkpoint
const lastNonUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type !== 'after_user_edits')
// merge all recent user checkpoints
const latestAfterStrOfURI: { [fsPath: string]: string } = {} // helps merge user edits
for (let k = 0; k <= thread.messages.length; k += 1) {
const message = thread.messages[k]
if (message.role !== 'checkpoint') continue
for (const uri in message.afterStrOfURI)
latestAfterStrOfURI[uri] = message.afterStrOfURI[uri]
// remove any user messages that come after the last LLM checkpoint (we're merging them into one big user message)
if (k > lastNonUserCheckpointIdx)
this._removeMessageFromThread(threadId, k)
}
// compute afterStr of all files we detected, and if they're different, add them as a user edit
for (const fsPath in latestAfterStrOfURI) {
const uri = URI.file(fsPath)
const { model } = this._voidModelService.getModel(uri)
if (!model) continue
const oldAfterStr = latestAfterStrOfURI[uri.fsPath]
const currentAfterStr = model.getValue()
if (oldAfterStr === currentAfterStr) continue
// if there was a change, add it as a user edit
newUserCheckpoint.afterStrOfURI = {
...newUserCheckpoint.afterStrOfURI,
[uri.fsPath]: currentAfterStr
}
}
this._addMessageToThread(threadId, newUserCheckpoint)
// update latest checkpoint idx to the one we just added
const newThread = this.state.allThreads[threadId]
if (!newThread) return // should never happen
const latestCheckpointIdx = newThread.messages.length - 1
this._setThreadState(threadId, { latestCheckpointIdx })
console.log('NEW USER CHECKPOINT', latestCheckpointIdx, newUserCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2))
}
private _getCheckpointAfter = ({ threadId, messageIdx: afterIdx }: { threadId: string, messageIdx: number }): [CheckpointEntry, number] | undefined => {
const thread = this.state.allThreads[threadId]
if (!thread) return undefined
for (let i = afterIdx; i < thread.messages.length; i++) {
const message = thread.messages[i]
if (message.role === 'checkpoint') {
return [message, i]
}
}
return undefined
}
private _getAllChangedCheckpointURIs({ threadId, fromIdx, toIdx }: { threadId: string, fromIdx: number, toIdx: number }) {
const thread = this.state.allThreads[threadId]
if (!thread) return null // should never happen
const fsPaths: Set<string> = new Set()
for (let i = fromIdx; i <= toIdx; i += 1) {
const message = thread.messages[i]
if (message.role !== 'checkpoint') continue
for (const fsPath in message.afterStrOfURI) {
fsPaths.add(fsPath)
}
}
return fsPaths
}
jumpToCheckpointAfterMessageIdx({ threadId, messageIdx }: { threadId: string, messageIdx: number }) {
const thread = this.state.allThreads[threadId]
if (!thread) return
const c = this._getCheckpointAfter({ threadId, messageIdx })
if (c === undefined) return // should never happen
const fromIdx = thread.state.latestCheckpointIdx
if (fromIdx === null) return // should never happen
// TODO!!! change toIdx if there's a checkpointModification on the To, and add a checkpoint modification on the from
const [_, toIdx_] = c
const toIdx = toIdx_
if (toIdx === fromIdx) return
const writeFullFile = ({ fsPath, text }: { fsPath: string, text: string }) => {
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
if (!model) return // should never happen
model.applyEdits([{
range: { startLineNumber: 1, startColumn: 1, endLineNumber: model.getLineCount(), endColumn: Number.MAX_SAFE_INTEGER }, // whole file
text
}])
}
/*
if undoing
A,B,C are all files.
x means a checkpoint where the file changed.
A B C D E F G H I
x x x x x x x x x
| | | | | | | | |
x | | | | | | | x
---x-|-|-|-x-|-x-|----- <-- to
x | | | | | x
| | x x |
| | | |
-------x-|---x-x------- <-- from
x
We need to revert anything that happened between to+1 and from.
**We do this by finding the last x from 0...`to` for each file and applying those contents.**
We only need to do it for files that were edited since `to`, ie files between to+1...from.
*/
if (toIdx < fromIdx) {
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, toIdx: toIdx + 1, fromIdx })
for (const fsPath of checkpointURIs ?? []) {
let found = false
// apply lowest down content for each uri (or original if not found)
for (let k = toIdx; k >= 0; k -= 1) {
const message = thread.messages[k]
if (message.role !== 'checkpoint') continue
if (fsPath in message.afterStrOfURI) {
found = true
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
break
}
}
if (!found) {
const originalStr = thread.firstStrOfURI[fsPath]
if (originalStr === undefined) continue
writeFullFile({ fsPath, text: originalStr })
}
}
}
/*
if redoing
A B C D E F G H I
x x x x x x x x x
| | | | | | | | |
x | | | | | | | x
---x-|-|-|-x-|-x-|----- <-- from
x | | | | | x
| | x x |
| | | |
-------x-|---x-x------- <-- to
x
We need to apply latest change for anything that happened between from+1 and to.
We only need to do it for files that were edited since `from`, ie files between from+1...to.
*/
if (toIdx > fromIdx) {
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, fromIdx: fromIdx + 1, toIdx })
for (const fsPath of checkpointURIs ?? []) {
// apply lowest down content for each uri
// (do not need to apply original since we're only applying to files that changed)
for (let k = toIdx; k >= fromIdx + 1; k -= 1) {
const message = thread.messages[k]
if (message.role !== 'checkpoint') continue
if (fsPath in message.afterStrOfURI) {
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
break
}
}
}
}
this._setThreadState(threadId, { latestCheckpointIdx: toIdx })
// TODO!!! add/merge a checkpoint modification if relevant
}
async addUserMessageAndStreamResponse({ userMessage, _chatSelections, threadId }: { userMessage: string, _chatSelections?: { prevSelns?: StagingSelectionItem[], currSelns?: StagingSelectionItem[], }, threadId: string }) {
@ -970,6 +992,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
const userHistoryElt: ChatMessage = { role: 'user', content: userMessageContent, displayContent: instructions, selections: currSelns, state: defaultMessageState }
this._addMessageToThread(threadId, userHistoryElt)
this.addOrUpdateUserMessageCheckpoint({ threadId })
this._runChatAgent({ prevSelns, currSelns, threadId, userMessageContent, ...this._currentModelSelectionProps(), })
}
@ -1255,7 +1278,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
// add the current file as a staging selection
const model = this._codeEditorService.getActiveCodeEditor()?.getModel()
if (model) {
this._setCurrentThreadState({
this._setThreadState(this.state.currentThreadId, {
...defaultThreadState,
stagingSelections: [{
type: 'File',
@ -1286,25 +1309,48 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
_addMessageToThread(threadId: string, message: ChatMessage) {
private _addMessageToThread(threadId: string, message: ChatMessage) {
const { allThreads } = this.state
const oldThread = allThreads[threadId]
if (!oldThread) return // should never happen
// update state and store it
const newThreads = {
...allThreads,
[oldThread.id]: {
...oldThread,
lastModified: new Date().toISOString(),
messages: [...oldThread.messages, message],
messages: [
...oldThread.messages,
message
],
}
}
this._storeAllThreads(newThreads)
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
}
private _removeMessageFromThread(threadId: string, messageIdx: number) {
const { allThreads } = this.state
const oldThread = allThreads[threadId]
if (!oldThread) return // should never happen
// update state and store it
const newThreads = {
...allThreads,
[oldThread.id]: {
...oldThread,
lastModified: new Date().toISOString(),
messages: [
...oldThread.messages.slice(0, messageIdx),
...oldThread.messages.slice(messageIdx + 1, Infinity),
],
}
}
this._storeAllThreads(newThreads)
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
}
// sets the currently selected message (must be undefined if no message is selected)
setCurrentlyFocusedMessageIdx(messageIdx: number | undefined) {
@ -1354,9 +1400,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
// set thread.state
private _setCurrentThreadState(state: Partial<ThreadType['state']>): void {
const threadId = this.state.currentThreadId
private _setThreadState(threadId: string, state: Partial<ThreadType['state']>): void {
const thread = this.state.allThreads[threadId]
if (!thread) return
@ -1409,7 +1453,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
return currentThread.state
}
setCurrentThreadState = (newState: Partial<ThreadType['state']>) => {
this._setCurrentThreadState(newState)
this._setThreadState(this.state.currentThreadId, newState)
}
// gets `staging` and `setStaging` of the currently focused element, given the index of the currently selected message (or undefined if no message is selected)

View file

@ -1248,7 +1248,7 @@ const ToolRequestAcceptRejectButtons = () => {
const onAccept = useCallback(() => {
try { // this doesn't need to be wrapped in try/catch anymore
const threadId = chatThreadsService.state.currentThreadId
chatThreadsService.approveTool(threadId)
chatThreadsService.approveLatestToolRequest(threadId)
metricsService.capture('Tool Request Accepted', {})
} catch (e) { console.error('Error while approving message in chat:', e) }
}, [chatThreadsService, metricsService])
@ -1256,7 +1256,7 @@ const ToolRequestAcceptRejectButtons = () => {
const onReject = useCallback(() => {
try {
const threadId = chatThreadsService.state.currentThreadId
chatThreadsService.rejectTool(threadId)
chatThreadsService.rejectLatestToolRequest(threadId)
} catch (e) { console.error('Error while approving message in chat:', e) }
metricsService.capture('Tool Request Rejected', {})
}, [chatThreadsService, metricsService])

View file

@ -26,16 +26,18 @@ export type ToolRequestApproval<T extends ToolName> = {
// checkpoints
export type LLMHistoryEntry = { // ALWAYS comes right after a {role:'tool', name:'edit'} message
role: 'LLM_changes';
afterStrOfURI: { [fsPath: string]: string };
}
export type UserHistoryEntry = { // ALWAYS comes right before a {role:'user'} message, or if it's the last message (w/o a user message yet)
role: 'user_changes';
export type CheckpointEntry = {
role: 'checkpoint';
type: 'after_user_edits' | 'after_tool_edits';
afterStrOfURI: { [fsPath: string]: string };
} | { // modifications that only count when undoing/redoing
role: 'checkpoint_modification';
type: 'user_modifications';
afterStrOfURI: { [fsPath: string]: string };
}
// WARNING: changing this format is a big deal!!!!!! need to migrate old format to new format on users' computers so people don't get errors.
export type ChatMessage =
| {
@ -56,8 +58,7 @@ export type ChatMessage =
}
| ToolMessage<ToolName>
| ToolRequestApproval<ToolName>
| LLMHistoryEntry // invisible
| UserHistoryEntry // invisible
| CheckpointEntry
// one of the square items that indicates a selection in a chat bubble (NOT a file, a Selection of text)

View file

@ -14,6 +14,7 @@ export interface IVoidModelService {
readonly _serviceBrand: undefined;
initializeModel(uri: URI): Promise<void>;
getModel(uri: URI): VoidModelType;
getModelFromFsPath(fsPath: string): VoidModelType;
getModelSafe(uri: URI): Promise<VoidModelType>;
}
@ -37,8 +38,8 @@ class VoidModelService extends Disposable implements IVoidModelService {
this._modelRefOfURI[uri.fsPath] = editorModelRef;
};
getModel = (uri: URI): VoidModelType => {
const editorModelRef = this._modelRefOfURI[uri.fsPath];
getModelFromFsPath = (fsPath: string): VoidModelType => {
const editorModelRef = this._modelRefOfURI[fsPath];
if (!editorModelRef) {
return { model: null, editorModel: null };
}
@ -52,6 +53,11 @@ class VoidModelService extends Disposable implements IVoidModelService {
return { model, editorModel: editorModelRef.object };
};
getModel = (uri: URI) => {
return this.getModelFromFsPath(uri.fsPath)
}
getModelSafe = async (uri: URI): Promise<VoidModelType> => {
if (!(uri.fsPath in this._modelRefOfURI)) await this.initializeModel(uri);
return this.getModel(uri);