mirror of
https://github.com/voideditor/void
synced 2026-05-24 09:58:23 +00:00
add checkpoints
This commit is contained in:
parent
7c0ba71314
commit
7decc8e146
4 changed files with 311 additions and 260 deletions
|
|
@ -21,7 +21,7 @@ import { ToolName, ToolCallParams, ToolResultType, toolNamesThatRequireApproval,
|
|||
import { IToolsService } from './toolsService.js';
|
||||
import { CancellationToken } from '../../../../base/common/cancellation.js';
|
||||
import { ILanguageFeaturesService } from '../../../../editor/common/services/languageFeatures.js';
|
||||
import { ChatMessage, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js';
|
||||
import { ChatMessage, CheckpointEntry, CodespanLocationLink, StagingSelectionItem, ToolRequestApproval } from '../common/chatThreadServiceTypes.js';
|
||||
import { Position } from '../../../../editor/common/core/position.js';
|
||||
import { ITerminalToolService } from './terminalToolService.js';
|
||||
import { IMetricsService } from '../common/metricsService.js';
|
||||
|
|
@ -29,16 +29,35 @@ import { shorten } from '../../../../base/common/labels.js';
|
|||
import { IVoidModelService } from '../common/voidModelService.js';
|
||||
import { IEditorService } from '../../../services/editor/common/editorService.js';
|
||||
import { ICodeEditorService } from '../../../../editor/browser/services/codeEditorService.js';
|
||||
import { findLastIdx } from '../../../../base/common/arraysFind.js';
|
||||
|
||||
const findLastIndex = <T>(arr: T[], condition: (t: T) => boolean): number => {
|
||||
for (let i = arr.length - 1; i >= 0; i--) {
|
||||
if (condition(arr[i])) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
type LLMCheckpoint = CheckpointEntry & { type: 'after_tool_edits' }
|
||||
type UserCheckpoint = CheckpointEntry & { type: 'after_user_edits' }
|
||||
/*
|
||||
Checkpoints:
|
||||
pivots: user | tool (edit)
|
||||
if there are repeated pivots, a checkpoint goes directly after the last one
|
||||
checkpoint_modifications always go directly after a checkpoint
|
||||
|
||||
user
|
||||
-- checkpoint --------
|
||||
assistant
|
||||
tool (edit)
|
||||
-------- checkpoint - starts here <-- know exact change (file A after)
|
||||
assistant |
|
||||
tool (edit) v
|
||||
-- checkpoint --------
|
||||
assistant
|
||||
tool (not edit)
|
||||
assistant
|
||||
user
|
||||
-- checkpoint -------- user checkpoint (JIT) - compute change from all files to here when need to
|
||||
-- checkpoint_modifications --------- - these always come DIRECLY after a checkpoint, and reflect the user's modifications on this one checkpoint only.
|
||||
(only counts when reverting to/from this exact checkpoint, not past it).
|
||||
Added when user jumps to another checkpoint but made changes here.
|
||||
|
||||
*/
|
||||
|
||||
|
||||
const toLLMChatMessages = (chatMessages: ChatMessage[]): LLMChatMessage[] => {
|
||||
|
|
@ -75,11 +94,15 @@ type ThreadType = {
|
|||
id: string; // store the id here too
|
||||
createdAt: string; // ISO string
|
||||
lastModified: string; // ISO string
|
||||
|
||||
messages: ChatMessage[];
|
||||
currentHistoryIdx: number | null; // index in messages, ALWAYS points to a LLMHistoryEntry or UserHistoryEntry, or -1 if no changes. current code is inclusive of the current index's change
|
||||
firstStrOfURI: { [fsPath: string]: string | undefined }; // part of checkpointing
|
||||
|
||||
|
||||
// this doesn't need to go in a state object, but feels right
|
||||
state: {
|
||||
latestCheckpointIdx: number | null; // the latest checkpoint we're standing at or null
|
||||
|
||||
stagingSelections: StagingSelectionItem[];
|
||||
focusedMessageIdx: number | undefined; // index of the user message that is being edited (undefined if none)
|
||||
|
||||
|
|
@ -96,6 +119,7 @@ type ChatThreads = {
|
|||
}
|
||||
|
||||
export const defaultThreadState: ThreadType['state'] = {
|
||||
latestCheckpointIdx: null,
|
||||
stagingSelections: [],
|
||||
focusedMessageIdx: undefined,
|
||||
linksOfMessageIdx: {},
|
||||
|
|
@ -130,7 +154,7 @@ const newThreadObject = () => {
|
|||
lastModified: now,
|
||||
messages: [],
|
||||
state: defaultThreadState,
|
||||
currentHistoryIdx: null,
|
||||
firstStrOfURI: {},
|
||||
} satisfies ThreadType
|
||||
}
|
||||
|
||||
|
|
@ -141,6 +165,8 @@ const newThreadObject = () => {
|
|||
export const THREAD_STORAGE_KEY = 'void.chatThreadStorageI'
|
||||
|
||||
|
||||
|
||||
|
||||
export interface IChatThreadService {
|
||||
readonly _serviceBrand: undefined;
|
||||
|
||||
|
|
@ -184,8 +210,8 @@ export interface IChatThreadService {
|
|||
addUserMessageAndStreamResponse({ userMessage, threadId }: { userMessage: string, threadId: string }): Promise<void>;
|
||||
|
||||
// approve/reject
|
||||
approveTool(threadId: string): void;
|
||||
rejectTool(threadId: string): void;
|
||||
approveLatestToolRequest(threadId: string): void;
|
||||
rejectLatestToolRequest(threadId: string): void;
|
||||
}
|
||||
|
||||
export const IChatThreadService = createDecorator<IChatThreadService>('voidChatThreadService');
|
||||
|
|
@ -311,225 +337,6 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
const threads = this._convertThreadDataFromStorage(threadsStr);
|
||||
|
||||
// threads['abc'] = {
|
||||
// id: 'abc',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// lastModified: new Date().toISOString(),
|
||||
// messages: [
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'pathname_search',
|
||||
// id: 'tool-1',
|
||||
// paramsStr: '{"query": "hello", "pageNumber": 0}',
|
||||
// content: '/users/andrew/void/Desktop/etc/abc.txt',
|
||||
// result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [URI.file('/Users/username/Downloads/helloooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo.txt'), URI.file('/Users/username/Downloads/hello1.txt'), URI.file('/Users/username/Downloads/hello2.txt'), URI.file('/Users/username/Downloads/hello3.txt'), URI.file('/Users/username/hello.txt')], hasNextPage: true } },
|
||||
// } satisfies ToolMessage<'pathname_search'>,
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'pathname_search',
|
||||
// id: 'tool-1',
|
||||
// paramsStr: '{"query": "hello", "pageNumber": 0}',
|
||||
// content: '/users/andrew/void/Desktop/etc/abc.txt',
|
||||
// result: { type: 'success', params: { queryStr: 'hello', pageNumber: 0 }, value: { uris: [], hasNextPage: false } },
|
||||
// } satisfies ToolMessage<'pathname_search'>,
|
||||
|
||||
// // {
|
||||
// // role: 'tool_request',
|
||||
// // name: 'pathname_search',
|
||||
// // params: { queryStr: 'hello', pageNumber: 0 },
|
||||
// // paramsStr: '{"query": "hello", "pageNumber": 0}',
|
||||
// // id: 'request-1',
|
||||
// // } satisfies ToolRequestApproval<'pathname_search'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'list_dir',
|
||||
// id: 'tool-2',
|
||||
// paramsStr: '{"uri": "/Users/username/Documents"}',
|
||||
// content: 'Directory listing of /Users/username/Documents',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 1, },
|
||||
// value: {
|
||||
// children: [
|
||||
// { uri: URI.file('/Users/username/Documents/file1.txt'), name: 'file1.txt', isDirectory: false, isSymbolicLink: false },
|
||||
// { uri: URI.file('/Users/username/Documents/folder1'), name: 'folder1', isDirectory: true, isSymbolicLink: false }
|
||||
// ],
|
||||
// hasNextPage: true,
|
||||
// hasPrevPage: true,
|
||||
// itemsRemaining: 5,
|
||||
// }
|
||||
// },
|
||||
// } satisfies ToolMessage<'list_dir'>,
|
||||
|
||||
// // {
|
||||
// // role: 'tool_request',
|
||||
// // name: 'list_dir',
|
||||
// // params: { rootURI: URI.file('/Users/username/Documents'), pageNumber: 0 },
|
||||
// // paramsStr: '{"uri": "/Users/username/Documents"}',
|
||||
// // id: 'request-2',
|
||||
// // } satisfies ToolRequestApproval<'list_dir'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'read_file',
|
||||
// id: 'tool-3',
|
||||
// paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}',
|
||||
// content: 'Content of file1.txt\nThis is a sample file.\nHello world!',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { uri: URI.file('/src/vs/workbench/hi'), pageNumber: 0 },
|
||||
// value: { fileContents: 'Content of file1.txt\nThis is a sample file.\nHello world!', hasNextPage: false }
|
||||
// },
|
||||
// } satisfies ToolMessage<'read_file'>,
|
||||
|
||||
// // {
|
||||
// // role: 'tool_request',
|
||||
// // name: 'read_file',
|
||||
// // params: { uri: URI.file('/Users/username/Documents/file1.txt'), pageNumber: 0 },
|
||||
// // paramsStr: '{"uri": "/Users/username/Documents/file1.txt"}',
|
||||
// // id: 'request-3',
|
||||
// // } satisfies ToolRequestApproval<'read_file'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'grep_search',
|
||||
// id: 'tool-4',
|
||||
// paramsStr: '{"query": "function main"}',
|
||||
// content: 'Found matches in 3 files',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { queryStr: 'function main', pageNumber: 0 },
|
||||
// value: {
|
||||
// uris: [
|
||||
// URI.file('/Users/username/Project/main.js'),
|
||||
// URI.file('/Users/username/Project/src/app.js'),
|
||||
// URI.file('/Users/username/Project/test/test.js')
|
||||
// ],
|
||||
// hasNextPage: false
|
||||
// }
|
||||
// },
|
||||
// } satisfies ToolMessage<'grep_search'>,
|
||||
|
||||
// // {
|
||||
// // role: 'tool_request',
|
||||
// // name: 'grep_search',
|
||||
// // params: { queryStr: 'function main', pageNumber: 0 },
|
||||
// // paramsStr: '{"query": "function main"}',
|
||||
// // id: 'request-4',
|
||||
// // } satisfies ToolRequestApproval<'grep_search'>,
|
||||
|
||||
// // ---
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'edit',
|
||||
// id: 'tool-5',
|
||||
// paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "Add console.log statement"}',
|
||||
// content: 'Successfully edited the file at /Users/username/Project/main.js',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' },
|
||||
// value: Promise.resolve()
|
||||
// },
|
||||
// } satisfies ToolMessage<'edit'>,
|
||||
// {
|
||||
// role: 'tool_request',
|
||||
// name: 'edit',
|
||||
// params: { uri: URI.file('/Users/username/Project/main.js'), changeDescription: 'I think we should do this:\n```typescript\n//Add console.log statement\n for i in ...\n\t\tdo:\nabc\n```' },
|
||||
// paramsStr: '{"uri": "/Users/username/Project/main.js", "changeDescription": "I think we should do this:```Add console.log statement\n for i in ...\n\t\tdo:\nabc```"}',
|
||||
// id: 'request-5',
|
||||
// } satisfies ToolRequestApproval<'edit'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'create_uri',
|
||||
// id: 'tool-6',
|
||||
// paramsStr: '{"uri": "/Users/username/Project/new-file.js"}',
|
||||
// content: 'Successfully created file at /Users/username/Project/new-file/',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { uri: URI.file('Users/andrew/Desktop/void/src/vs/workbench/hi/'), isFolder: true },
|
||||
// value: {}
|
||||
// },
|
||||
// } satisfies ToolMessage<'create_uri'>,
|
||||
// {
|
||||
// role: 'tool_request',
|
||||
// name: 'create_uri',
|
||||
// params: { uri: URI.file('/Users/username/Project/new-file.js'), isFolder: false },
|
||||
// paramsStr: '{"uri": "/Users/username/Project/new-file.js"}',
|
||||
// id: 'request-6',
|
||||
// } satisfies ToolRequestApproval<'create_uri'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'delete_uri',
|
||||
// id: 'tool-7',
|
||||
// paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}',
|
||||
// content: 'Successfully deleted file at /Users/username/Project/old-file.js',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false },
|
||||
// value: {}
|
||||
// },
|
||||
// } satisfies ToolMessage<'delete_uri'>,
|
||||
// {
|
||||
// role: 'tool_request',
|
||||
// name: 'delete_uri',
|
||||
// params: { uri: URI.file('/Users/username/Project/old-file.js'), isRecursive: false, isFolder: false },
|
||||
// paramsStr: '{"uri": "/Users/username/Project/old-file.js", "params": ""}',
|
||||
// id: 'request-7',
|
||||
// } satisfies ToolRequestApproval<'delete_uri'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'terminal_command',
|
||||
// id: 'tool-8',
|
||||
// paramsStr: '{"command": "npm install", "waitForCompletion": "true"}',
|
||||
// content: 'Command executed: npm install\nAdded 123 packages in 3.5s',
|
||||
// result: {
|
||||
// type: 'success',
|
||||
// params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true },
|
||||
// value: {
|
||||
// terminalId: '1',
|
||||
// didCreateTerminal: false,
|
||||
// result: 'Added 123 packages in 3.5s',
|
||||
// resolveReason: { type: 'done', exitCode: 0 }
|
||||
// }
|
||||
// },
|
||||
// } satisfies ToolMessage<'terminal_command'>,
|
||||
// {
|
||||
// role: 'tool_request',
|
||||
// name: 'terminal_command',
|
||||
// params: { command: 'npm install', proposedTerminalId: '1', waitForCompletion: true },
|
||||
// paramsStr: '{"command": "npm install", "waitForCompletion": "true"}',
|
||||
// id: 'request-8',
|
||||
// } satisfies ToolRequestApproval<'terminal_command'>,
|
||||
|
||||
|
||||
|
||||
// // Examples of error and rejected states
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'pathname_search',
|
||||
// id: 'tool-error',
|
||||
// paramsStr: '{"query": "invalid**query"}',
|
||||
// content: 'Error: Invalid search pattern',
|
||||
// result: { type: 'error', params: { queryStr: 'invalid**query', pageNumber: 0 }, value: 'Error: Invalid search pattern' },
|
||||
// } satisfies ToolMessage<'pathname_search'>,
|
||||
|
||||
// {
|
||||
// role: 'tool',
|
||||
// name: 'pathname_search',
|
||||
// id: 'tool-rejected',
|
||||
// paramsStr: '{"query": "sensitive-data"}',
|
||||
// content: 'Tool call was rejected by the user.',
|
||||
// result: { type: 'rejected', params: { queryStr: 'sensitive-data', pageNumber: 0 } },
|
||||
// } satisfies ToolMessage<'pathname_search'>,
|
||||
// ],
|
||||
// state: defaultThreadState,
|
||||
// }
|
||||
|
||||
return threads
|
||||
}
|
||||
|
||||
|
|
@ -631,7 +438,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
|
||||
|
||||
approveTool(threadId: string) {
|
||||
approveLatestToolRequest(threadId: string) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return // should never happen
|
||||
|
||||
|
|
@ -639,7 +446,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
const lastMessage = thread.messages[thread.messages.length - 1]
|
||||
if (lastMessage.role !== 'tool_request') return // should never happen
|
||||
|
||||
const lastUserMsgIdx = findLastIndex(thread.messages, m => m.role === 'user')
|
||||
const lastUserMsgIdx = findLastIdx(thread.messages, m => m.role === 'user')
|
||||
const lastUserMessage = thread.messages[lastUserMsgIdx] as ChatMessage & { role: 'user' }
|
||||
if (lastUserMsgIdx === -1 || !lastUserMessage) return // should never happen
|
||||
|
||||
|
|
@ -651,7 +458,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
|
||||
this._runChatAgent({ callThisToolFirst, prevSelns, currSelns, threadId, userMessageContent: instructions, ...this._currentModelSelectionProps() })
|
||||
}
|
||||
rejectTool(threadId: string) {
|
||||
rejectLatestToolRequest(threadId: string) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return // should never happen
|
||||
|
||||
|
|
@ -670,7 +477,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
const isRunning = this.streamState[threadId]?.isRunning
|
||||
// reject the tool for the user
|
||||
if (isRunning === 'awaiting_user') {
|
||||
this.rejectTool(threadId)
|
||||
this.rejectLatestToolRequest(threadId)
|
||||
}
|
||||
// interrupt the tool
|
||||
else if (isRunning === 'tool') {
|
||||
|
|
@ -741,7 +548,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
const thread = this.state.allThreads[threadId]
|
||||
const latestMessages = thread?.messages ?? []
|
||||
const messages_ = toLLMChatMessages(latestMessages)
|
||||
const lastUserMsgIdx = findLastIndex(messages_, m => m.role === 'user')
|
||||
const lastUserMsgIdx = findLastIdx(messages_, m => m.role === 'user')
|
||||
if (lastUserMsgIdx === -1) return [] // should never happen (or how did they send the message?!)
|
||||
|
||||
// system message
|
||||
|
|
@ -944,11 +751,226 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
|
||||
|
||||
async callWhenJumpBackToIdx(toIdx: number) {
|
||||
// TODO!!!
|
||||
// merge any LLM checkpoint before this one (and after a user checkpoint if one exists), and add the checkpoint
|
||||
// call this right after LLM edits a file
|
||||
addOrUpdateToolEditCheckpoint({ threadId, uri, }: { threadId: string, uri: URI }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
const { model } = this._voidModelService.getModel(uri)
|
||||
if (!model) return // should never happen
|
||||
|
||||
const lastUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type === 'after_user_edits')
|
||||
const prevLLMCheckpointIdx = thread.messages.findIndex((m, i) => i > lastUserCheckpointIdx && m.role === 'checkpoint' && m.type === 'after_tool_edits')
|
||||
|
||||
const afterStr = model.getValue() // afterStr = the value of the file right after the edit
|
||||
|
||||
let prevLLMCheckpoint: LLMCheckpoint | undefined = undefined
|
||||
if (prevLLMCheckpointIdx !== -1) {
|
||||
prevLLMCheckpoint = thread.messages[prevLLMCheckpointIdx] as ChatMessage & { role: 'checkpoint', type: 'after_tool_edits' }
|
||||
this._removeMessageFromThread(threadId, prevLLMCheckpointIdx)
|
||||
}
|
||||
const newLLMCheckpoint: LLMCheckpoint = {
|
||||
role: 'checkpoint',
|
||||
type: 'after_tool_edits',
|
||||
afterStrOfURI: {
|
||||
...prevLLMCheckpoint?.afterStrOfURI,
|
||||
[uri.fsPath]: afterStr,
|
||||
},
|
||||
}
|
||||
console.log('NEW LLM CHECKPOINT', newLLMCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2))
|
||||
this._addMessageToThread(threadId, newLLMCheckpoint)
|
||||
}
|
||||
|
||||
|
||||
// user checkpoints are always computed JIT
|
||||
// we assume there are no messages after the checkpoint we're adding here
|
||||
// call this right before user sends message
|
||||
addOrUpdateUserMessageCheckpoint({ threadId, }: { threadId: string, }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
|
||||
const newUserCheckpoint: UserCheckpoint = {
|
||||
role: 'checkpoint',
|
||||
type: 'after_user_edits', // user backup
|
||||
afterStrOfURI: {},
|
||||
}
|
||||
|
||||
// first get the last user checkpoint
|
||||
const lastNonUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type !== 'after_user_edits')
|
||||
|
||||
// merge all recent user checkpoints
|
||||
const latestAfterStrOfURI: { [fsPath: string]: string } = {} // helps merge user edits
|
||||
for (let k = 0; k <= thread.messages.length; k += 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
for (const uri in message.afterStrOfURI)
|
||||
latestAfterStrOfURI[uri] = message.afterStrOfURI[uri]
|
||||
|
||||
// remove any user messages that come after the last LLM checkpoint (we're merging them into one big user message)
|
||||
if (k > lastNonUserCheckpointIdx)
|
||||
this._removeMessageFromThread(threadId, k)
|
||||
}
|
||||
|
||||
// compute afterStr of all files we detected, and if they're different, add them as a user edit
|
||||
for (const fsPath in latestAfterStrOfURI) {
|
||||
const uri = URI.file(fsPath)
|
||||
const { model } = this._voidModelService.getModel(uri)
|
||||
if (!model) continue
|
||||
const oldAfterStr = latestAfterStrOfURI[uri.fsPath]
|
||||
const currentAfterStr = model.getValue()
|
||||
if (oldAfterStr === currentAfterStr) continue
|
||||
// if there was a change, add it as a user edit
|
||||
newUserCheckpoint.afterStrOfURI = {
|
||||
...newUserCheckpoint.afterStrOfURI,
|
||||
[uri.fsPath]: currentAfterStr
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
this._addMessageToThread(threadId, newUserCheckpoint)
|
||||
|
||||
// update latest checkpoint idx to the one we just added
|
||||
const newThread = this.state.allThreads[threadId]
|
||||
if (!newThread) return // should never happen
|
||||
const latestCheckpointIdx = newThread.messages.length - 1
|
||||
this._setThreadState(threadId, { latestCheckpointIdx })
|
||||
|
||||
|
||||
console.log('NEW USER CHECKPOINT', latestCheckpointIdx, newUserCheckpoint, JSON.stringify(this.state.allThreads[threadId], null, 2))
|
||||
}
|
||||
|
||||
|
||||
private _getCheckpointAfter = ({ threadId, messageIdx: afterIdx }: { threadId: string, messageIdx: number }): [CheckpointEntry, number] | undefined => {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return undefined
|
||||
for (let i = afterIdx; i < thread.messages.length; i++) {
|
||||
const message = thread.messages[i]
|
||||
if (message.role === 'checkpoint') {
|
||||
return [message, i]
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
private _getAllChangedCheckpointURIs({ threadId, fromIdx, toIdx }: { threadId: string, fromIdx: number, toIdx: number }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return null // should never happen
|
||||
const fsPaths: Set<string> = new Set()
|
||||
for (let i = fromIdx; i <= toIdx; i += 1) {
|
||||
const message = thread.messages[i]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
for (const fsPath in message.afterStrOfURI) {
|
||||
fsPaths.add(fsPath)
|
||||
}
|
||||
}
|
||||
return fsPaths
|
||||
}
|
||||
|
||||
jumpToCheckpointAfterMessageIdx({ threadId, messageIdx }: { threadId: string, messageIdx: number }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
|
||||
const c = this._getCheckpointAfter({ threadId, messageIdx })
|
||||
if (c === undefined) return // should never happen
|
||||
|
||||
const fromIdx = thread.state.latestCheckpointIdx
|
||||
if (fromIdx === null) return // should never happen
|
||||
|
||||
// TODO!!! change toIdx if there's a checkpointModification on the To, and add a checkpoint modification on the from
|
||||
const [_, toIdx_] = c
|
||||
const toIdx = toIdx_
|
||||
if (toIdx === fromIdx) return
|
||||
|
||||
const writeFullFile = ({ fsPath, text }: { fsPath: string, text: string }) => {
|
||||
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
if (!model) return // should never happen
|
||||
model.applyEdits([{
|
||||
range: { startLineNumber: 1, startColumn: 1, endLineNumber: model.getLineCount(), endColumn: Number.MAX_SAFE_INTEGER }, // whole file
|
||||
text
|
||||
}])
|
||||
}
|
||||
|
||||
/*
|
||||
if undoing
|
||||
|
||||
A,B,C are all files.
|
||||
x means a checkpoint where the file changed.
|
||||
|
||||
A B C D E F G H I
|
||||
x x x x x x x x x
|
||||
| | | | | | | | |
|
||||
x | | | | | | | x
|
||||
---x-|-|-|-x-|-x-|----- <-- to
|
||||
x | | | | | x
|
||||
| | x x |
|
||||
| | | |
|
||||
-------x-|---x-x------- <-- from
|
||||
x
|
||||
|
||||
We need to revert anything that happened between to+1 and from.
|
||||
**We do this by finding the last x from 0...`to` for each file and applying those contents.**
|
||||
We only need to do it for files that were edited since `to`, ie files between to+1...from.
|
||||
*/
|
||||
if (toIdx < fromIdx) {
|
||||
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, toIdx: toIdx + 1, fromIdx })
|
||||
for (const fsPath of checkpointURIs ?? []) {
|
||||
let found = false
|
||||
|
||||
// apply lowest down content for each uri (or original if not found)
|
||||
|
||||
for (let k = toIdx; k >= 0; k -= 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
if (fsPath in message.afterStrOfURI) {
|
||||
found = true
|
||||
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
|
||||
break
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
const originalStr = thread.firstStrOfURI[fsPath]
|
||||
if (originalStr === undefined) continue
|
||||
writeFullFile({ fsPath, text: originalStr })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
if redoing
|
||||
|
||||
A B C D E F G H I
|
||||
x x x x x x x x x
|
||||
| | | | | | | | |
|
||||
x | | | | | | | x
|
||||
---x-|-|-|-x-|-x-|----- <-- from
|
||||
x | | | | | x
|
||||
| | x x |
|
||||
| | | |
|
||||
-------x-|---x-x------- <-- to
|
||||
x
|
||||
|
||||
We need to apply latest change for anything that happened between from+1 and to.
|
||||
We only need to do it for files that were edited since `from`, ie files between from+1...to.
|
||||
*/
|
||||
if (toIdx > fromIdx) {
|
||||
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, fromIdx: fromIdx + 1, toIdx })
|
||||
for (const fsPath of checkpointURIs ?? []) {
|
||||
// apply lowest down content for each uri
|
||||
// (do not need to apply original since we're only applying to files that changed)
|
||||
for (let k = toIdx; k >= fromIdx + 1; k -= 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
if (fsPath in message.afterStrOfURI) {
|
||||
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this._setThreadState(threadId, { latestCheckpointIdx: toIdx })
|
||||
// TODO!!! add/merge a checkpoint modification if relevant
|
||||
}
|
||||
|
||||
|
||||
async addUserMessageAndStreamResponse({ userMessage, _chatSelections, threadId }: { userMessage: string, _chatSelections?: { prevSelns?: StagingSelectionItem[], currSelns?: StagingSelectionItem[], }, threadId: string }) {
|
||||
|
|
@ -970,6 +992,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
const userHistoryElt: ChatMessage = { role: 'user', content: userMessageContent, displayContent: instructions, selections: currSelns, state: defaultMessageState }
|
||||
this._addMessageToThread(threadId, userHistoryElt)
|
||||
|
||||
this.addOrUpdateUserMessageCheckpoint({ threadId })
|
||||
this._runChatAgent({ prevSelns, currSelns, threadId, userMessageContent, ...this._currentModelSelectionProps(), })
|
||||
}
|
||||
|
||||
|
|
@ -1255,7 +1278,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
// add the current file as a staging selection
|
||||
const model = this._codeEditorService.getActiveCodeEditor()?.getModel()
|
||||
if (model) {
|
||||
this._setCurrentThreadState({
|
||||
this._setThreadState(this.state.currentThreadId, {
|
||||
...defaultThreadState,
|
||||
stagingSelections: [{
|
||||
type: 'File',
|
||||
|
|
@ -1286,25 +1309,48 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
|
||||
|
||||
_addMessageToThread(threadId: string, message: ChatMessage) {
|
||||
private _addMessageToThread(threadId: string, message: ChatMessage) {
|
||||
const { allThreads } = this.state
|
||||
|
||||
const oldThread = allThreads[threadId]
|
||||
if (!oldThread) return // should never happen
|
||||
|
||||
// update state and store it
|
||||
const newThreads = {
|
||||
...allThreads,
|
||||
[oldThread.id]: {
|
||||
...oldThread,
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [...oldThread.messages, message],
|
||||
messages: [
|
||||
...oldThread.messages,
|
||||
message
|
||||
],
|
||||
}
|
||||
}
|
||||
this._storeAllThreads(newThreads)
|
||||
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
|
||||
}
|
||||
|
||||
|
||||
private _removeMessageFromThread(threadId: string, messageIdx: number) {
|
||||
const { allThreads } = this.state
|
||||
const oldThread = allThreads[threadId]
|
||||
if (!oldThread) return // should never happen
|
||||
// update state and store it
|
||||
const newThreads = {
|
||||
...allThreads,
|
||||
[oldThread.id]: {
|
||||
...oldThread,
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [
|
||||
...oldThread.messages.slice(0, messageIdx),
|
||||
...oldThread.messages.slice(messageIdx + 1, Infinity),
|
||||
],
|
||||
}
|
||||
}
|
||||
this._storeAllThreads(newThreads)
|
||||
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
|
||||
}
|
||||
|
||||
|
||||
// sets the currently selected message (must be undefined if no message is selected)
|
||||
setCurrentlyFocusedMessageIdx(messageIdx: number | undefined) {
|
||||
|
||||
|
|
@ -1354,9 +1400,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
|
||||
// set thread.state
|
||||
private _setCurrentThreadState(state: Partial<ThreadType['state']>): void {
|
||||
|
||||
const threadId = this.state.currentThreadId
|
||||
private _setThreadState(threadId: string, state: Partial<ThreadType['state']>): void {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
|
||||
|
|
@ -1409,7 +1453,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
return currentThread.state
|
||||
}
|
||||
setCurrentThreadState = (newState: Partial<ThreadType['state']>) => {
|
||||
this._setCurrentThreadState(newState)
|
||||
this._setThreadState(this.state.currentThreadId, newState)
|
||||
}
|
||||
|
||||
// gets `staging` and `setStaging` of the currently focused element, given the index of the currently selected message (or undefined if no message is selected)
|
||||
|
|
|
|||
|
|
@ -1248,7 +1248,7 @@ const ToolRequestAcceptRejectButtons = () => {
|
|||
const onAccept = useCallback(() => {
|
||||
try { // this doesn't need to be wrapped in try/catch anymore
|
||||
const threadId = chatThreadsService.state.currentThreadId
|
||||
chatThreadsService.approveTool(threadId)
|
||||
chatThreadsService.approveLatestToolRequest(threadId)
|
||||
metricsService.capture('Tool Request Accepted', {})
|
||||
} catch (e) { console.error('Error while approving message in chat:', e) }
|
||||
}, [chatThreadsService, metricsService])
|
||||
|
|
@ -1256,7 +1256,7 @@ const ToolRequestAcceptRejectButtons = () => {
|
|||
const onReject = useCallback(() => {
|
||||
try {
|
||||
const threadId = chatThreadsService.state.currentThreadId
|
||||
chatThreadsService.rejectTool(threadId)
|
||||
chatThreadsService.rejectLatestToolRequest(threadId)
|
||||
} catch (e) { console.error('Error while approving message in chat:', e) }
|
||||
metricsService.capture('Tool Request Rejected', {})
|
||||
}, [chatThreadsService, metricsService])
|
||||
|
|
|
|||
|
|
@ -26,16 +26,18 @@ export type ToolRequestApproval<T extends ToolName> = {
|
|||
|
||||
|
||||
// checkpoints
|
||||
export type LLMHistoryEntry = { // ALWAYS comes right after a {role:'tool', name:'edit'} message
|
||||
role: 'LLM_changes';
|
||||
afterStrOfURI: { [fsPath: string]: string };
|
||||
}
|
||||
export type UserHistoryEntry = { // ALWAYS comes right before a {role:'user'} message, or if it's the last message (w/o a user message yet)
|
||||
role: 'user_changes';
|
||||
export type CheckpointEntry = {
|
||||
role: 'checkpoint';
|
||||
type: 'after_user_edits' | 'after_tool_edits';
|
||||
afterStrOfURI: { [fsPath: string]: string };
|
||||
} | { // modifications that only count when undoing/redoing
|
||||
role: 'checkpoint_modification';
|
||||
type: 'user_modifications';
|
||||
afterStrOfURI: { [fsPath: string]: string };
|
||||
}
|
||||
|
||||
|
||||
|
||||
// WARNING: changing this format is a big deal!!!!!! need to migrate old format to new format on users' computers so people don't get errors.
|
||||
export type ChatMessage =
|
||||
| {
|
||||
|
|
@ -56,8 +58,7 @@ export type ChatMessage =
|
|||
}
|
||||
| ToolMessage<ToolName>
|
||||
| ToolRequestApproval<ToolName>
|
||||
| LLMHistoryEntry // invisible
|
||||
| UserHistoryEntry // invisible
|
||||
| CheckpointEntry
|
||||
|
||||
|
||||
// one of the square items that indicates a selection in a chat bubble (NOT a file, a Selection of text)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ export interface IVoidModelService {
|
|||
readonly _serviceBrand: undefined;
|
||||
initializeModel(uri: URI): Promise<void>;
|
||||
getModel(uri: URI): VoidModelType;
|
||||
getModelFromFsPath(fsPath: string): VoidModelType;
|
||||
getModelSafe(uri: URI): Promise<VoidModelType>;
|
||||
}
|
||||
|
||||
|
|
@ -37,8 +38,8 @@ class VoidModelService extends Disposable implements IVoidModelService {
|
|||
this._modelRefOfURI[uri.fsPath] = editorModelRef;
|
||||
};
|
||||
|
||||
getModel = (uri: URI): VoidModelType => {
|
||||
const editorModelRef = this._modelRefOfURI[uri.fsPath];
|
||||
getModelFromFsPath = (fsPath: string): VoidModelType => {
|
||||
const editorModelRef = this._modelRefOfURI[fsPath];
|
||||
if (!editorModelRef) {
|
||||
return { model: null, editorModel: null };
|
||||
}
|
||||
|
|
@ -52,6 +53,11 @@ class VoidModelService extends Disposable implements IVoidModelService {
|
|||
return { model, editorModel: editorModelRef.object };
|
||||
};
|
||||
|
||||
getModel = (uri: URI) => {
|
||||
return this.getModelFromFsPath(uri.fsPath)
|
||||
}
|
||||
|
||||
|
||||
getModelSafe = async (uri: URI): Promise<VoidModelType> => {
|
||||
if (!(uri.fsPath in this._modelRefOfURI)) await this.initializeModel(uri);
|
||||
return this.getModel(uri);
|
||||
|
|
|
|||
Loading…
Reference in a new issue