mirror of
https://github.com/voideditor/void
synced 2026-05-24 09:58:23 +00:00
checkpoints
This commit is contained in:
parent
3ac9dcf0c0
commit
bde51106a1
3 changed files with 197 additions and 193 deletions
|
|
@ -32,30 +32,21 @@ import { ICodeEditorService } from '../../../../editor/browser/services/codeEdit
|
|||
import { findLastIdx } from '../../../../base/common/arraysFind.js';
|
||||
|
||||
|
||||
type LLMCheckpoint = CheckpointEntry & { type: 'after_tool_edits' }
|
||||
type UserCheckpoint = CheckpointEntry & { type: 'after_user_edits' }
|
||||
/*
|
||||
Checkpoints:
|
||||
pivots: user | tool (edit)
|
||||
if there are repeated pivots, a checkpoint goes directly after the last one
|
||||
checkpoint_modifications always go directly after a checkpoint
|
||||
|
||||
user
|
||||
-- checkpoint --------
|
||||
assistant
|
||||
tool (edit)
|
||||
-------- checkpoint - starts here <-- know exact change (file A after)
|
||||
assistant |
|
||||
tool (edit) v
|
||||
-- checkpoint --------
|
||||
assistant
|
||||
tool (not edit)
|
||||
assistant
|
||||
user
|
||||
-- checkpoint -------- user checkpoint (JIT) - compute change from all files to here when need to
|
||||
-- checkpoint_modifications --------- - these always come DIRECLY after a checkpoint, and reflect the user's modifications on this one checkpoint only.
|
||||
(only counts when reverting to/from this exact checkpoint, not past it).
|
||||
Added when user jumps to another checkpoint but made changes here.
|
||||
Store a checkpoint of all "before" files on each x.
|
||||
x's show up before user messages and LLM edit tool calls.
|
||||
|
||||
x A (edited A -> A')
|
||||
(... user modified changes ...)
|
||||
User message
|
||||
|
||||
x A' B C (edited A'->A'', B->B', C->C')
|
||||
LLM Edit
|
||||
x
|
||||
LLM Edit
|
||||
x
|
||||
LLM Edit
|
||||
|
||||
*/
|
||||
|
||||
|
|
@ -74,8 +65,6 @@ const toLLMChatMessages = (chatMessages: ChatMessage[]): LLMChatMessage[] => {
|
|||
}
|
||||
else if (c.role === 'checkpoint') { // pass
|
||||
}
|
||||
else if (c.role === 'checkpoint_modification') { // pass
|
||||
}
|
||||
else {
|
||||
throw new Error(`Role ${(c as any).role} not recognized.`)
|
||||
}
|
||||
|
|
@ -99,12 +88,11 @@ type ThreadType = {
|
|||
lastModified: string; // ISO string
|
||||
|
||||
messages: ChatMessage[];
|
||||
firstStrOfURI: { [fsPath: string]: string | undefined }; // part of checkpointing
|
||||
|
||||
filesWithUserChanges: Set<string>;
|
||||
|
||||
// this doesn't need to go in a state object, but feels right
|
||||
state: {
|
||||
currCheckpointIdx: number | null; // the latest checkpoint we're standing at or null
|
||||
currCheckpointIdx: number | null; // the latest checkpoint we're at (always defined unless chat is empty so there are no checkpts)
|
||||
|
||||
stagingSelections: StagingSelectionItem[];
|
||||
focusedMessageIdx: number | undefined; // index of the user message that is being edited (undefined if none)
|
||||
|
|
@ -121,12 +109,6 @@ type ChatThreads = {
|
|||
[id: string]: undefined | ThreadType;
|
||||
}
|
||||
|
||||
export const defaultThreadState: ThreadType['state'] = {
|
||||
currCheckpointIdx: null,
|
||||
stagingSelections: [],
|
||||
focusedMessageIdx: undefined,
|
||||
linksOfMessageIdx: {},
|
||||
}
|
||||
|
||||
export type ThreadsState = {
|
||||
allThreads: ChatThreads;
|
||||
|
|
@ -156,8 +138,13 @@ const newThreadObject = () => {
|
|||
createdAt: now,
|
||||
lastModified: now,
|
||||
messages: [],
|
||||
state: defaultThreadState,
|
||||
firstStrOfURI: {},
|
||||
state: {
|
||||
currCheckpointIdx: null,
|
||||
stagingSelections: [],
|
||||
focusedMessageIdx: undefined,
|
||||
linksOfMessageIdx: {},
|
||||
},
|
||||
filesWithUserChanges: new Set()
|
||||
} satisfies ThreadType
|
||||
}
|
||||
|
||||
|
|
@ -217,7 +204,7 @@ export interface IChatThreadService {
|
|||
rejectLatestToolRequest(threadId: string): void;
|
||||
|
||||
// jump to history
|
||||
jumpToCheckpointAfterMessageIdx(opts: { threadId: string, messageIdx: number }): void;
|
||||
jumpToCheckpointBeforeMessageIdx(opts: { threadId: string, messageIdx: number, jumpToUserModified: boolean }): void;
|
||||
}
|
||||
|
||||
export const IChatThreadService = createDecorator<IChatThreadService>('voidChatThreadService');
|
||||
|
|
@ -234,6 +221,11 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
readonly streamState: ThreadStreamState = {}
|
||||
state: ThreadsState // allThreads is persisted, currentThread is not
|
||||
|
||||
// used in checkpointing
|
||||
// private readonly _userModifiedFilesToCheckInCheckpoints = new LRUCache<string, null>(50)
|
||||
|
||||
|
||||
|
||||
constructor(
|
||||
@IStorageService private readonly _storageService: IStorageService,
|
||||
@IVoidModelService private readonly _voidModelService: IVoidModelService,
|
||||
|
|
@ -246,6 +238,8 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
@IMetricsService private readonly _metricsService: IMetricsService,
|
||||
@IEditorService private readonly _editorService: IEditorService,
|
||||
@ICodeEditorService private readonly _codeEditorService: ICodeEditorService,
|
||||
// @IModelService private readonly _modelService: IModelService,
|
||||
|
||||
) {
|
||||
super()
|
||||
this.state = { allThreads: {}, currentThreadId: null as unknown as string } // default state
|
||||
|
|
@ -264,6 +258,22 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
// when the user changes files, automatically add the new file as a stagingSelection
|
||||
this._register(this._editorService.onDidActiveEditorChange(() => this._addCurrentFileAsStagingSelectionDuringFileChange()));
|
||||
|
||||
|
||||
// keep track of user-modified files
|
||||
// const disposablesOfModelId: { [modelId: string]: IDisposable[] } = {}
|
||||
// this._register(
|
||||
// this._modelService.onModelAdded(e => {
|
||||
// if (!(e.id in disposablesOfModelId)) disposablesOfModelId[e.id] = []
|
||||
// disposablesOfModelId[e.id].push(
|
||||
// e.onDidChangeContent(() => { this._userModifiedFilesToCheckInCheckpoints.set(e.uri.fsPath, null) })
|
||||
// )
|
||||
// })
|
||||
// )
|
||||
// this._register(this._modelService.onModelRemoved(e => {
|
||||
// if (!(e.id in disposablesOfModelId)) return
|
||||
// disposablesOfModelId[e.id].forEach(d => d.dispose())
|
||||
// }))
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -619,8 +629,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
this._setStreamState(threadId, { isRunning: 'tool' }, 'merge')
|
||||
let interrupted = false
|
||||
try {
|
||||
// add the original file if it wasn't seen before in this thread
|
||||
if (toolName === 'edit') { this._trackOriginalFileInURI({ threadId, uri: (toolParams as ToolCallParams['edit']).uri }) }
|
||||
if (toolName === 'edit') { this._addToolEditCheckpoint({ threadId, uri: (toolParams as ToolCallParams['edit']).uri }) }
|
||||
|
||||
const { result, interruptTool } = await this._toolsService.callTool[toolName](toolParams as any)
|
||||
this._currentlyRunningToolInterruptor[threadId] = () => {
|
||||
|
|
@ -653,8 +662,6 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
// 5. add to history and keep going
|
||||
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: toolParamsStr, id: toolId, content: toolResultStr, result: { type: 'success', params: toolParams, value: toolResult }, })
|
||||
|
||||
// 6. add a checkpoint
|
||||
if (toolName === 'edit') { this._addToolEditCheckpoint({ threadId, uri: (toolParams as ToolCallParams['edit']).uri }) }
|
||||
return {}
|
||||
};
|
||||
|
||||
|
|
@ -763,23 +770,13 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
}
|
||||
|
||||
|
||||
private _trackOriginalFileInURI({ threadId, uri }: { threadId: string, uri: URI }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
const { model } = this._voidModelService.getModel(uri)
|
||||
if (!model) return
|
||||
if (!(uri.fsPath in thread.firstStrOfURI)) {
|
||||
thread.firstStrOfURI[uri.fsPath] = model.getValue()
|
||||
}
|
||||
}
|
||||
|
||||
private _addCheckpoint(threadId: string, checkpoint: CheckpointEntry) {
|
||||
this._addMessageToThread(threadId, checkpoint)
|
||||
// update latest checkpoint idx to the one we just added
|
||||
const newThread = this.state.allThreads[threadId]
|
||||
if (!newThread) return // should never happen
|
||||
const latestCheckpointIdx = newThread.messages.length - 1
|
||||
this._setThreadState(threadId, { currCheckpointIdx: latestCheckpointIdx })
|
||||
const currCheckpointIdx = newThread.messages.length - 1
|
||||
this._setThreadState(threadId, { currCheckpointIdx })
|
||||
}
|
||||
|
||||
// merge any LLM checkpoint before this one (and after a user checkpoint if one exists), and add the checkpoint
|
||||
|
|
@ -790,88 +787,112 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
const { model } = this._voidModelService.getModel(uri)
|
||||
if (!model) return // should never happen
|
||||
|
||||
const currValue = model.getValue() // afterStr = the value of the file right after the edit
|
||||
|
||||
const afterStr = model.getValue() // afterStr = the value of the file right after the edit
|
||||
|
||||
const newLLMCheckpoint: LLMCheckpoint = {
|
||||
this._addCheckpoint(threadId, {
|
||||
role: 'checkpoint',
|
||||
type: 'after_tool_edits',
|
||||
afterStrOfURI: {
|
||||
[uri.fsPath]: afterStr,
|
||||
},
|
||||
}
|
||||
|
||||
// remove and merge
|
||||
// const lastUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type === 'after_user_edits')
|
||||
// const prevLLMCheckpointIdx = thread.messages.findIndex((m, i) => i > lastUserCheckpointIdx && m.role === 'checkpoint' && m.type === 'after_tool_edits')
|
||||
// let prevLLMCheckpoint: LLMCheckpoint | undefined = undefined
|
||||
// if (prevLLMCheckpointIdx !== -1) {
|
||||
// prevLLMCheckpoint = thread.messages[prevLLMCheckpointIdx] as ChatMessage & { role: 'checkpoint', type: 'after_tool_edits' }
|
||||
// this._removeMessageFromThread(threadId, prevLLMCheckpointIdx)
|
||||
// newLLMCheckpoint.afterStrOfURI = {
|
||||
// ...newLLMCheckpoint.afterStrOfURI,
|
||||
// ...prevLLMCheckpoint?.afterStrOfURI,
|
||||
// }
|
||||
// }
|
||||
|
||||
this._addCheckpoint(threadId, newLLMCheckpoint)
|
||||
type: 'tool_edit',
|
||||
beforeStrOfURI: { [uri.fsPath]: currValue, },
|
||||
userModifications: { beforeStrOfURI: {} },
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
private _editMessageInThread(threadId: string, messageIdx: number, newMessage: ChatMessage,) {
|
||||
const { allThreads } = this.state
|
||||
const oldThread = allThreads[threadId]
|
||||
if (!oldThread) return // should never happen
|
||||
// update state and store it
|
||||
const newThreads = {
|
||||
...allThreads,
|
||||
[oldThread.id]: {
|
||||
...oldThread,
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [
|
||||
...oldThread.messages.slice(0, messageIdx),
|
||||
newMessage,
|
||||
...oldThread.messages.slice(messageIdx + 1, Infinity),
|
||||
],
|
||||
}
|
||||
}
|
||||
this._storeAllThreads(newThreads)
|
||||
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
|
||||
}
|
||||
|
||||
// user checkpoints are always computed JIT
|
||||
// we assume there are no messages after the checkpoint we're adding here
|
||||
// call this right before user sends message
|
||||
private _addOrUpdateUserMessageCheckpoint({ threadId, }: { threadId: string, }) {
|
||||
|
||||
|
||||
private _computeNeededCheckpointChanges({ threadId }: { threadId: string }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
const { currCheckpointIdx } = thread.state
|
||||
if (currCheckpointIdx === null) return
|
||||
|
||||
const currStrOfFsPath: { [fsPath: string]: string | undefined } = {}
|
||||
|
||||
// add a change for all the URIs in the checkpoint history
|
||||
const { lastIdxOfURI } = this._getCheckpointsBetween({ threadId, loIdx: 0, hiIdx: currCheckpointIdx, }) ?? {}
|
||||
for (const fsPath in lastIdxOfURI ?? {}) {
|
||||
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
if (!model) continue
|
||||
const checkpoint2 = thread.messages[lastIdxOfURI[fsPath]] || null
|
||||
if (!checkpoint2) continue
|
||||
if (checkpoint2.role !== 'checkpoint') continue
|
||||
const oldStr = this._getBeforeStrAtCheckpoint(checkpoint2, fsPath, { includeUserModifiedChanges: false })
|
||||
const currStr = model.getValue()
|
||||
if (oldStr === currStr) continue
|
||||
currStrOfFsPath[fsPath] = currStr
|
||||
}
|
||||
|
||||
// // add a change for all user-edited files (that aren't in the history)
|
||||
// for (const fsPath of this._userModifiedFilesToCheckInCheckpoints.keys()) {
|
||||
// if (fsPath in lastIdxOfURI) continue // if already visisted, don't visit again
|
||||
// const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
// if (!model) continue
|
||||
// currStrOfFsPath[fsPath] = model.getValue()
|
||||
// }
|
||||
|
||||
return currStrOfFsPath
|
||||
}
|
||||
|
||||
// call this right before user sends message or reverts
|
||||
private _addUserCheckpoint({ threadId }: { threadId: string }) {
|
||||
const changes = this._computeNeededCheckpointChanges({ threadId })
|
||||
this._addCheckpoint(threadId, {
|
||||
role: 'checkpoint',
|
||||
type: 'user_edit',
|
||||
beforeStrOfURI: changes ?? {},
|
||||
userModifications: { beforeStrOfURI: {} },
|
||||
})
|
||||
}
|
||||
private _addUserModificationsToCurrCheckpoint({ threadId }: { threadId: string }) {
|
||||
const changes = this._computeNeededCheckpointChanges({ threadId })
|
||||
const res = this._getCurrentCheckpoint(threadId)
|
||||
if (!res) return
|
||||
const [checkpoint, checkpointIdx] = res
|
||||
this._editMessageInThread(threadId, checkpointIdx, {
|
||||
...checkpoint,
|
||||
userModifications: { beforeStrOfURI: changes ?? {} },
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
private _getCurrentCheckpoint(threadId: string): [CheckpointEntry, number] | undefined {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
|
||||
const newUserCheckpoint: UserCheckpoint = {
|
||||
role: 'checkpoint',
|
||||
type: 'after_user_edits', // user backup
|
||||
afterStrOfURI: {},
|
||||
}
|
||||
const { currCheckpointIdx } = thread.state
|
||||
if (currCheckpointIdx === null) return
|
||||
|
||||
// first get the last user checkpoint
|
||||
const lastNonUserCheckpointIdx = findLastIdx(thread.messages, (m) => m.role === 'checkpoint' && m.type !== 'after_user_edits')
|
||||
|
||||
// merge all recent user checkpoints and delete them
|
||||
const latestAfterStrOfURI: { [fsPath: string]: string } = {} // helps merge user edits
|
||||
for (let k = 0; k < thread.messages.length; k += 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
for (const uri in message.afterStrOfURI)
|
||||
latestAfterStrOfURI[uri] = message.afterStrOfURI[uri]
|
||||
|
||||
// remove any user messages that come after the last LLM checkpoint (we're merging them into one big user message)
|
||||
if (k > lastNonUserCheckpointIdx)
|
||||
this._removeMessageFromThread(threadId, k)
|
||||
}
|
||||
|
||||
|
||||
// add a change for all the files where we detect a user change
|
||||
const allURIs = this._getAllChangedCheckpointURIs({ threadId, loIdx: 0, hiIdx: thread.messages.length - 1, })
|
||||
for (const fsPath of allURIs ?? []) {
|
||||
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
if (!model) continue
|
||||
const oldAfterStr = latestAfterStrOfURI[fsPath]
|
||||
const currentAfterStr = model.getValue()
|
||||
if (oldAfterStr === currentAfterStr) continue
|
||||
// if there was a change, add it as a user edit
|
||||
newUserCheckpoint.afterStrOfURI = {
|
||||
...newUserCheckpoint.afterStrOfURI,
|
||||
[fsPath]: currentAfterStr
|
||||
}
|
||||
}
|
||||
|
||||
this._addCheckpoint(threadId, newUserCheckpoint)
|
||||
const checkpoint = thread.messages[currCheckpointIdx]
|
||||
if (!checkpoint) return
|
||||
if (checkpoint.role !== 'checkpoint') return
|
||||
return [checkpoint, currCheckpointIdx]
|
||||
}
|
||||
|
||||
|
||||
private _getCheckpointAfter = ({ threadId, messageIdx: afterIdx }: { threadId: string, messageIdx: number }): [CheckpointEntry, number] | undefined => {
|
||||
private _getCheckpointBeforeMessage = ({ threadId, messageIdx }: { threadId: string, messageIdx: number }): [CheckpointEntry, number] | undefined => {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return undefined
|
||||
for (let i = afterIdx; i < thread.messages.length; i++) {
|
||||
for (let i = messageIdx; i >= 0; i--) {
|
||||
const message = thread.messages[i]
|
||||
if (message.role === 'checkpoint') {
|
||||
return [message, i]
|
||||
|
|
@ -880,45 +901,55 @@ class ChatThreadService extends Disposable implements IChatThreadService {
|
|||
return undefined
|
||||
}
|
||||
|
||||
private _getAllChangedCheckpointURIs({ threadId, loIdx, hiIdx }: { threadId: string, loIdx: number, hiIdx: number }) {
|
||||
private _getCheckpointsBetween({ threadId, loIdx, hiIdx }: { threadId: string, loIdx: number, hiIdx: number }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return null // should never happen
|
||||
const fsPaths: Set<string> = new Set()
|
||||
if (!thread) return { lastIdxOfURI: {} } // should never happen
|
||||
const lastIdxOfURI: { [fsPath: string]: number } = {}
|
||||
for (let i = loIdx; i <= hiIdx; i += 1) {
|
||||
const message = thread.messages[i]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
for (const fsPath in message.afterStrOfURI) {
|
||||
fsPaths.add(fsPath)
|
||||
for (const fsPath in message.beforeStrOfURI) { // do not include userModified.beforeStrOfURI here, jumping should not include those changes
|
||||
lastIdxOfURI[fsPath] = i
|
||||
}
|
||||
}
|
||||
return fsPaths
|
||||
return { lastIdxOfURI }
|
||||
}
|
||||
|
||||
jumpToCheckpointAfterMessageIdx({ threadId, messageIdx }: { threadId: string, messageIdx: number }) {
|
||||
private _getBeforeStrAtCheckpoint = (checkpointMessage: ChatMessage & { role: 'checkpoint' }, fsPath: string, opts: { includeUserModifiedChanges: boolean }) => {
|
||||
const beforeStr = fsPath in checkpointMessage.beforeStrOfURI ? checkpointMessage.beforeStrOfURI[fsPath] ?? null : null
|
||||
if (!opts.includeUserModifiedChanges) return beforeStr
|
||||
const userModifiedBeforeStr = fsPath in checkpointMessage.userModifications.beforeStrOfURI ? checkpointMessage.userModifications.beforeStrOfURI[fsPath] ?? null : null
|
||||
return userModifiedBeforeStr ?? beforeStr
|
||||
}
|
||||
|
||||
|
||||
private _writeFullFile = ({ fsPath, text }: { fsPath: string, text: string }) => {
|
||||
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
if (!model) return // should never happen
|
||||
model.applyEdits([{
|
||||
range: { startLineNumber: 1, startColumn: 1, endLineNumber: model.getLineCount(), endColumn: Number.MAX_SAFE_INTEGER }, // whole file
|
||||
text
|
||||
}])
|
||||
}
|
||||
|
||||
jumpToCheckpointBeforeMessageIdx({ threadId, messageIdx, jumpToUserModified }: { threadId: string, messageIdx: number, jumpToUserModified: boolean }) {
|
||||
const thread = this.state.allThreads[threadId]
|
||||
if (!thread) return
|
||||
|
||||
const c = this._getCheckpointAfter({ threadId, messageIdx })
|
||||
const c = this._getCheckpointBeforeMessage({ threadId, messageIdx })
|
||||
if (c === undefined) return // should never happen
|
||||
|
||||
const fromIdx = thread.state.currCheckpointIdx
|
||||
if (fromIdx === null) return // should never happen
|
||||
|
||||
// TODO!!! change toIdx if there's a checkpointModification on the To, and add a checkpoint modification on the from
|
||||
const [_, toIdx_] = c
|
||||
const toIdx = toIdx_
|
||||
const [_, toIdx] = c
|
||||
if (toIdx === fromIdx) return
|
||||
|
||||
const writeFullFile = ({ fsPath, text }: { fsPath: string, text: string }) => {
|
||||
const { model } = this._voidModelService.getModelFromFsPath(fsPath)
|
||||
if (!model) return // should never happen
|
||||
model.applyEdits([{
|
||||
range: { startLineNumber: 1, startColumn: 1, endLineNumber: model.getLineCount(), endColumn: Number.MAX_SAFE_INTEGER }, // whole file
|
||||
text
|
||||
}])
|
||||
}
|
||||
console.log(`going from ${fromIdx} to ${toIdx}`)
|
||||
|
||||
// update the user's checkpoint
|
||||
this._addUserModificationsToCurrCheckpoint({ threadId })
|
||||
|
||||
/*
|
||||
if undoing
|
||||
|
||||
|
|
@ -941,26 +972,18 @@ We need to revert anything that happened between to+1 and from.
|
|||
We only need to do it for files that were edited since `to`, ie files between to+1...from.
|
||||
*/
|
||||
if (toIdx < fromIdx) {
|
||||
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, loIdx: toIdx + 1, hiIdx: fromIdx })
|
||||
for (const fsPath of checkpointURIs ?? []) {
|
||||
let found = false
|
||||
|
||||
const { lastIdxOfURI } = this._getCheckpointsBetween({ threadId, loIdx: toIdx + 1, hiIdx: fromIdx })
|
||||
for (const fsPath in lastIdxOfURI) {
|
||||
// apply lowest down content for each uri (or original if not found)
|
||||
|
||||
for (let k = toIdx; k >= 0; k -= 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
if (fsPath in message.afterStrOfURI) {
|
||||
found = true
|
||||
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
|
||||
const beforeStr = this._getBeforeStrAtCheckpoint(message, fsPath, { includeUserModifiedChanges: jumpToUserModified })
|
||||
if (beforeStr !== null) {
|
||||
this._writeFullFile({ fsPath, text: beforeStr })
|
||||
break
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
const originalStr = thread.firstStrOfURI[fsPath]
|
||||
if (originalStr === undefined) continue
|
||||
writeFullFile({ fsPath, text: originalStr })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -982,15 +1005,15 @@ We need to apply latest change for anything that happened between from+1 and to.
|
|||
We only need to do it for files that were edited since `from`, ie files between from+1...to.
|
||||
*/
|
||||
if (toIdx > fromIdx) {
|
||||
const checkpointURIs = this._getAllChangedCheckpointURIs({ threadId, loIdx: fromIdx + 1, hiIdx: toIdx })
|
||||
for (const fsPath of checkpointURIs ?? []) {
|
||||
const { lastIdxOfURI } = this._getCheckpointsBetween({ threadId, loIdx: fromIdx + 1, hiIdx: toIdx })
|
||||
for (const fsPath in lastIdxOfURI) {
|
||||
// apply lowest down content for each uri
|
||||
// (do not need to apply original since we're only applying to files that changed)
|
||||
for (let k = toIdx; k >= fromIdx + 1; k -= 1) {
|
||||
const message = thread.messages[k]
|
||||
if (message.role !== 'checkpoint') continue
|
||||
if (fsPath in message.afterStrOfURI) {
|
||||
writeFullFile({ fsPath, text: message.afterStrOfURI[fsPath] })
|
||||
const beforeStr = this._getBeforeStrAtCheckpoint(message, fsPath, { includeUserModifiedChanges: jumpToUserModified })
|
||||
if (beforeStr !== null) {
|
||||
this._writeFullFile({ fsPath, text: beforeStr })
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -1020,9 +1043,11 @@ We only need to do it for files that were edited since `from`, ie files between
|
|||
const userMessageContent = await chat_userMessageContent(instructions, currSelns) // user message + names of files (NOT content)
|
||||
const userHistoryElt: ChatMessage = { role: 'user', content: userMessageContent, displayContent: instructions, selections: currSelns, state: defaultMessageState }
|
||||
this._addMessageToThread(threadId, userHistoryElt)
|
||||
this._addOrUpdateUserMessageCheckpoint({ threadId })
|
||||
|
||||
this._runChatAgent({ prevSelns, currSelns, threadId, userMessageContent, ...this._currentModelSelectionProps(), })
|
||||
.then(() => {
|
||||
this._addUserCheckpoint({ threadId })
|
||||
})
|
||||
}
|
||||
|
||||
dismissStreamError(threadId: string): void {
|
||||
|
|
@ -1308,7 +1333,6 @@ We only need to do it for files that were edited since `from`, ie files between
|
|||
const model = this._codeEditorService.getActiveCodeEditor()?.getModel()
|
||||
if (model) {
|
||||
this._setThreadState(this.state.currentThreadId, {
|
||||
...defaultThreadState,
|
||||
stagingSelections: [{
|
||||
type: 'File',
|
||||
fileURI: model.uri,
|
||||
|
|
@ -1358,28 +1382,6 @@ We only need to do it for files that were edited since `from`, ie files between
|
|||
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
|
||||
}
|
||||
|
||||
|
||||
private _removeMessageFromThread(threadId: string, messageIdx: number) {
|
||||
const { allThreads } = this.state
|
||||
const oldThread = allThreads[threadId]
|
||||
if (!oldThread) return // should never happen
|
||||
// update state and store it
|
||||
const newThreads = {
|
||||
...allThreads,
|
||||
[oldThread.id]: {
|
||||
...oldThread,
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [
|
||||
...oldThread.messages.slice(0, messageIdx),
|
||||
...oldThread.messages.slice(messageIdx + 1, Infinity),
|
||||
],
|
||||
}
|
||||
}
|
||||
this._storeAllThreads(newThreads)
|
||||
this._setState({ allThreads: newThreads }, true) // the current thread just changed (it had a message added to it)
|
||||
}
|
||||
|
||||
|
||||
// sets the currently selected message (must be undefined if no message is selected)
|
||||
setCurrentlyFocusedMessageIdx(messageIdx: number | undefined) {
|
||||
|
||||
|
|
|
|||
|
|
@ -1826,8 +1826,8 @@ const Checkpoint = ({ threadId, messageIdx }: { threadId: string; messageIdx: nu
|
|||
className='pointer-events-auto cursor-pointer select-none hover:brightness-125 flex items-center justify-center'
|
||||
onClick={() => {
|
||||
// reject all current changes and then jump back
|
||||
commandBarService.acceptOrRejectAllFiles({ behavior: 'reject' })
|
||||
chatThreadService.jumpToCheckpointAfterMessageIdx({ threadId, messageIdx })
|
||||
commandBarService.acceptOrRejectAllFiles({ behavior: 'accept' })
|
||||
chatThreadService.jumpToCheckpointBeforeMessageIdx({ threadId, messageIdx, jumpToUserModified: true })
|
||||
}}>
|
||||
<div className='bg-void-border-1 h-[1px] flex-grow'></div>
|
||||
<div className='px-2'>Checkpoint</div>
|
||||
|
|
@ -2024,15 +2024,18 @@ export const SidebarChat = () => {
|
|||
}, [isHistoryOpen, currentThread.id])
|
||||
|
||||
const numMessages = previousMessages.length
|
||||
const lastMessageIdx = previousMessages.findLastIndex(v => v.role !== 'checkpoint')
|
||||
|
||||
const previousMessagesHTML = useMemo(() => {
|
||||
const threadId = currentThread.id
|
||||
const currCheckpointIdx = chatThreadsState.allThreads[threadId]?.state?.currCheckpointIdx ?? Infinity // if not exist, treat like checkpoint is last message (infinity)
|
||||
|
||||
return previousMessages.map((message, i) => {
|
||||
const isLast = i === numMessages - 1 && (isRunning === 'tool' || isRunning === 'awaiting_user')
|
||||
return <div className={`${currCheckpointIdx < i ? 'opacity-50 pointer-events-none select-none' : ''}`}>
|
||||
<ChatBubble key={getChatBubbleId(currentThread.id, i)}
|
||||
const isLast = i === lastMessageIdx && (isRunning === 'tool' || isRunning === 'awaiting_user')
|
||||
return <div
|
||||
key={getChatBubbleId(currentThread.id, i)}
|
||||
className={`${currCheckpointIdx < i ? 'opacity-50 pointer-events-none select-none' : ''}`}>
|
||||
<ChatBubble
|
||||
chatMessage={message}
|
||||
messageIdx={i}
|
||||
isCommitted={true}
|
||||
|
|
|
|||
|
|
@ -28,16 +28,15 @@ export type ToolRequestApproval<T extends ToolName> = {
|
|||
// checkpoints
|
||||
export type CheckpointEntry = {
|
||||
role: 'checkpoint';
|
||||
type: 'after_user_edits' | 'after_tool_edits';
|
||||
afterStrOfURI: { [fsPath: string]: string };
|
||||
} | { // modifications that only count when undoing/redoing
|
||||
role: 'checkpoint_modification';
|
||||
type: 'user_modifications';
|
||||
afterStrOfURI: { [fsPath: string]: string };
|
||||
type: 'user_edit' | 'tool_edit';
|
||||
beforeStrOfURI: { [fsPath: string]: string | undefined };
|
||||
userModifications: {
|
||||
beforeStrOfURI: { [fsPath: string]: string | undefined };
|
||||
};
|
||||
// diffAreas: null;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// WARNING: changing this format is a big deal!!!!!! need to migrate old format to new format on users' computers so people don't get errors.
|
||||
export type ChatMessage =
|
||||
| {
|
||||
|
|
|
|||
Loading…
Reference in a new issue