fix stream state

This commit is contained in:
Andrew Pareles 2025-03-12 02:02:25 -07:00
parent 11d376325e
commit 79ef6b5773
2 changed files with 97 additions and 83 deletions

View file

@ -310,7 +310,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}, true)
// re-add the message and stream it
this.addUserMessageAndStreamResponse({ userMessage, chatMode, chatSelections: { prevSelns, currSelns } })
this.addUserMessageAndStreamResponse({ userMessage, chatMode, _chatSelections: { prevSelns, currSelns } })
}
@ -328,14 +328,14 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
async addUserMessageAndStreamResponse({ userMessage, chatMode, chatSelections }: { userMessage: string, chatMode: ChatMode, chatSelections?: { prevSelns?: StagingSelectionItem[], currSelns?: StagingSelectionItem[] } }) {
async addUserMessageAndStreamResponse({ userMessage, chatMode, _chatSelections }: { userMessage: string, chatMode: ChatMode, _chatSelections?: { prevSelns?: StagingSelectionItem[], currSelns?: StagingSelectionItem[] } }) {
const thread = this.getCurrentThread()
const threadId = thread.id
// selections in all past chats, then in current chat (can have many duplicates here)
const prevSelns: StagingSelectionItem[] = chatSelections?.prevSelns ?? this._getAllSelections()
const currSelns: StagingSelectionItem[] = chatSelections?.currSelns ?? thread.state.stagingSelections
const prevSelns: StagingSelectionItem[] = _chatSelections?.prevSelns ?? this._getAllSelections()
const currSelns: StagingSelectionItem[] = _chatSelections?.currSelns ?? thread.state.stagingSelections
// add user's message to chat history
const instructions = userMessage
@ -347,10 +347,12 @@ class ChatThreadService extends Disposable implements IChatThreadService {
this._setStreamState(threadId, { error: undefined })
const tools: InternalToolInfo[] | undefined = (
chatMode === 'chat' ? undefined
: chatMode === 'agent' ? Object.keys(voidTools).map(toolName => voidTools[toolName as ToolName])
: undefined)
const toolNames: ToolName[] | undefined = chatMode === 'chat' ? undefined
: chatMode === 'gather' ? (Object.keys(voidTools) as ToolName[]).filter(toolName => !toolNamesThatRequireApproval.has(toolName))
: chatMode === 'agent' ? Object.keys(voidTools) as ToolName[]
: undefined
const tools: InternalToolInfo[] | undefined = toolNames?.map(toolName => voidTools[toolName])
// these settings should not change throughout the loop (eg anthropic breaks if you change its thinking mode and it's using tools)
const featureName: FeatureName = 'Chat'
@ -372,8 +374,8 @@ class ChatThreadService extends Disposable implements IChatThreadService {
shouldSendAnotherMessage = false // false by default
nMessagesSent += 1
let res_: () => void // resolves when user approves this tool use (or if tool doesn't require approval)
const awaitable = new Promise<void>((res, rej) => { res_ = res })
let resMessageIsDonePromise: () => void // resolves when user approves this tool use (or if tool doesn't require approval)
const messageIsDonePromise = new Promise<void>((res, rej) => { resMessageIsDonePromise = res })
// replace last userMessage with userMessageFullContent (which contains all the files too)
const messages_ = toLLMChatMessages(this.getCurrentThread().messages)
@ -403,81 +405,92 @@ class ChatThreadService extends Disposable implements IChatThreadService {
},
onFinalMessage: async ({ fullText, toolCalls, fullReasoning, anthropicReasoning }) => {
this._addMessageToThread(threadId, { role: 'assistant', content: fullText, reasoning: fullReasoning, anthropicReasoning })
// if no tools, finish
if ((toolCalls?.length ?? 0) === 0) {
this._addMessageToThread(threadId, { role: 'assistant', content: fullText, reasoning: fullReasoning, anthropicReasoning })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
else {
this._addMessageToThread(threadId, { role: 'assistant', content: fullText, reasoning: fullReasoning, anthropicReasoning })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined }) // clear streaming message
// deal with the tool
const tool: ToolCallType | undefined = toolCalls?.[0]
if (!tool) {
res_()
return
}
const toolName: ToolName = tool.name
shouldSendAnotherMessage = true
// if tools
// clear messageSoFar since we added it to the chat history (but don't clear streamingToken, we're still streaming)
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined })
// 1. validate tool params
let toolParams: ToolCallParams[ToolName]
try {
const params = await this._toolsService.validateParams[toolName](tool.paramsStr)
toolParams = params
} catch (error) {
const errorMessage = getErrorMessage(error)
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: undefined, value: errorMessage }, })
res_()
return
}
// 2. if tool requires approval, await the approval
if (toolNamesThatRequireApproval.has(toolName)) {
const voidToolId = generateUuid()
const toolApprovalPromise = new Promise<void>((res, rej) => { this.resRejOfToolAwaitingApproval[voidToolId] = { res, rej } })
this._addMessageToThread(threadId, { role: 'tool_request', name: toolName, params: toolParams, voidToolId: voidToolId })
try {
await toolApprovalPromise
// accepted tool
}
catch (e) {
// TODO!!! test rejection
// if (Math.random() > 0) throw new Error('TESTING')
const errorMessage = 'Tool call was rejected by the user.'
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'rejected', params: toolParams, value: errorMessage }, })
shouldSendAnotherMessage = false // interrupt flow by rejecting
res_()
return
}
}
// 3. call the tool
let toolResult: ToolResultType[typeof toolName]
try {
toolResult = await this._toolsService.callTool[toolName](toolParams as any) // typescript is so bad it doesn't even couple the type of ToolResult with the type of the function being called here
} catch (error) {
const errorMessage = getErrorMessage(error)
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: toolParams, value: errorMessage }, })
res_()
return
}
// 4. stringify the result to give the LLM
let toolResultStr: string
try {
toolResultStr = this._toolsService.stringOfResult[toolName](toolParams as any, toolResult as any)
} catch (error) {
const errorMessage = `Tool call succeeded, but there was an error stringifying the output.\n${getErrorMessage(error)}`
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: toolParams, value: errorMessage }, })
res_()
return
}
// 5. add to history
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: toolResultStr, result: { type: 'success', params: toolParams, value: toolResult }, })
res_()
// deal with the tool
const tool: ToolCallType | undefined = toolCalls?.[0]
if (!tool) {
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
const toolName: ToolName = tool.name
shouldSendAnotherMessage = true
// 1. validate tool params
let toolParams: ToolCallParams[ToolName]
try {
const params = await this._toolsService.validateParams[toolName](tool.paramsStr)
toolParams = params
} catch (error) {
const errorMessage = getErrorMessage(error)
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: undefined, value: errorMessage }, })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
// 2. if tool requires approval, await the approval
if (toolNamesThatRequireApproval.has(toolName)) {
const voidToolId = generateUuid()
const toolApprovalPromise = new Promise<void>((res, rej) => { this.resRejOfToolAwaitingApproval[voidToolId] = { res, rej } })
this._addMessageToThread(threadId, { role: 'tool_request', name: toolName, params: toolParams, voidToolId: voidToolId })
try {
await toolApprovalPromise
// accepted tool
}
catch (e) {
console.log('successfully rejected', voidToolId)
// TODO!!! test rejection
// if (Math.random() > 0) throw new Error('TESTING')
const errorMessage = 'Tool call was rejected by the user.'
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'rejected', params: toolParams, value: errorMessage }, })
shouldSendAnotherMessage = false // interrupt flow by rejecting
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
}
// 3. call the tool
let toolResult: ToolResultType[typeof toolName]
try {
toolResult = await this._toolsService.callTool[toolName](toolParams as any) // typescript is so bad it doesn't even couple the type of ToolResult with the type of the function being called here
} catch (error) {
const errorMessage = getErrorMessage(error)
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: toolParams, value: errorMessage }, })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
// 4. stringify the result to give the LLM
let toolResultStr: string
try {
toolResultStr = this._toolsService.stringOfResult[toolName](toolParams as any, toolResult as any)
} catch (error) {
const errorMessage = `Tool call succeeded, but there was an error stringifying the output.\n${getErrorMessage(error)}`
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: errorMessage, result: { type: 'error', params: toolParams, value: errorMessage }, })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
return
}
// 5. add to history
this._addMessageToThread(threadId, { role: 'tool', name: toolName, paramsStr: tool.paramsStr, id: tool.id, content: toolResultStr, result: { type: 'success', params: toolParams, value: toolResult }, })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined })
resMessageIsDonePromise()
},
onError: (error) => {
@ -486,13 +499,14 @@ class ChatThreadService extends Disposable implements IChatThreadService {
// add assistant's message to chat history, and clear selection
this._addMessageToThread(threadId, { role: 'assistant', content: messageSoFar, reasoning: reasoningSoFar, anthropicReasoning: null })
this._setStreamState(threadId, { messageSoFar: undefined, reasoningSoFar: undefined, streamingToken: undefined, error })
res_()
resMessageIsDonePromise()
},
})
if (llmCancelToken === null) break
this._setStreamState(threadId, { streamingToken: llmCancelToken })
await awaitable
await messageIsDonePromise
console.log('done awaiting...')
}
}

View file

@ -1620,13 +1620,13 @@ export const SidebarChat = () => {
scrollContainerRef.current?.scrollTo({ top: 0, left: 0 })
}, [isHistoryOpen, currentThread.id])
const numMessages = previousMessages.length + (isStreaming ? 1 : 0)
const numMessages = previousMessages.length
const previousMessagesHTML = useMemo(() => {
return previousMessages.map((message, i) =>
<ChatBubble key={getChatBubbleId(currentThread.id, i)} chatMessage={message} messageIdx={i} isLast={i === numMessages - 1} />
)
}, [previousMessages, currentThread])
}, [previousMessages, currentThread, numMessages])
const streamingChatIdx = previousMessagesHTML.length
const currStreamingMessageHTML = !!(reasoningSoFar || messageSoFar || isStreaming) ?