tool progress!

This commit is contained in:
Andrew Pareles 2025-02-15 02:05:37 -08:00
parent bc6150aeac
commit 152e605856
8 changed files with 215 additions and 113 deletions

View file

@ -14,6 +14,7 @@ import { IRange } from '../../../../editor/common/core/range.js';
import { ILLMMessageService } from '../common/llmMessageService.js';
import { IModelService } from '../../../../editor/common/services/model.js';
import { chat_userMessage, chat_systemMessage } from './prompt/prompts.js';
import { IToolsService, ToolName, voidTools } from '../common/toolsService.js';
// one of the square items that indicates a selection in a chat bubble (NOT a file, a Selection of text)
export type CodeSelection = {
@ -60,6 +61,13 @@ export type ChatMessage =
content: string;
displayContent?: undefined;
}
| {
role: 'tool';
name: string; // internal use
params: string | null; // internal use
content: string | null; // summary of the tool to the LLM
displayContent: string | null; // text message of result
}
// a 'thread' means a chat message history
export type ChatThreads = {
@ -124,7 +132,7 @@ export interface IChatThreadService {
isFocusingMessage(): boolean;
setFocusedMessageIdx(messageIdx: number | undefined): void;
_useFocusedStagingState(messageIdx?: number | undefined): readonly [StagingInfo, (stagingInfo: StagingInfo) => void];
useFocusedStagingState(messageIdx?: number | undefined): readonly [StagingInfo, (stagingInfo: StagingInfo) => void];
editUserMessageAndStreamResponse(userMessage: string, messageIdx: number): Promise<void>;
addUserMessageAndStreamResponse(userMessage: string): Promise<void>;
@ -151,6 +159,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
@IStorageService private readonly _storageService: IStorageService,
@IModelService private readonly _modelService: IModelService,
@ILLMMessageService private readonly _llmMessageService: ILLMMessageService,
@IToolsService private readonly _toolsService: IToolsService,
) {
super()
@ -254,14 +263,120 @@ class ChatThreadService extends Disposable implements IChatThreadService {
// ---------- streaming ----------
finishStreaming = (threadId: string, content: string, error?: { message: string, fullError: Error | null }) => {
private _finishStreamingTextMessage = (threadId: string, content: string, error?: { message: string, fullError: Error | null }) => {
// add assistant's message to chat history, and clear selection
const assistantHistoryElt: ChatMessage = { role: 'assistant', content, displayContent: content || null }
this._addMessageToThread(threadId, assistantHistoryElt)
this._addMessageToThread(threadId, { role: 'assistant', content, displayContent: content || null })
this._setStreamState(threadId, { messageSoFar: undefined, streamingToken: undefined, error })
}
async addUserMessageAndStreamResponse(userMessage: string, stagingOverride?: StagingInfo | null) {
const thread = this.getCurrentThread()
const threadId = thread.id
let threadStaging = thread.staging
const currStaging = stagingOverride ?? threadStaging ?? defaultStaging // don't use _useFocusedStagingState to avoid race conditions with focusing
const { selections: currSelns, } = currStaging
// add user's message to chat history
const instructions = userMessage
const content = await chat_userMessage(instructions, currSelns, this._modelService)
const userHistoryElt: ChatMessage = { role: 'user', content: content, displayContent: instructions, selections: currSelns, staging: null, }
this._addMessageToThread(threadId, userHistoryElt)
this._setStreamState(threadId, { error: undefined })
// agent loop
let shouldContinue = false
do {
shouldContinue = false
console.log('Q')
let res_: () => void
const awaitable = new Promise<void>((res, rej) => { res_ = res })
const llmCancelToken = this._llmMessageService.sendLLMMessage({
messagesType: 'chatMessages',
useProviderFor: 'Ctrl+L',
logging: { loggingName: `Agent` },
messages: [
{ role: 'system', content: chat_systemMessage },
...this.getCurrentThread().messages.map(m => ({ role: m.role, content: m.content || '(empty model output)' })),
],
tools: [voidTools['read_file']],
onText: ({ fullText }) => {
this._setStreamState(threadId, { messageSoFar: fullText })
},
onFinalMessage: async ({ fullText, tools }) => {
if (tools.length === 0) {
this._finishStreamingTextMessage(threadId, fullText)
}
else {
for (const tool of tools) {
if (!(tool.name in this._toolsService.toolFns)) {
this._addMessageToThread(threadId, { role: 'tool', name: tool.name, params: tool.args, content: `Error: This tool was not recognized, so it was not called.`, displayContent: `Error: tool not recognized.`, })
}
else {
const toolName = tool.name as ToolName
const toolResult = await this._toolsService.toolFns[toolName](JSON.parse(tool.args))
const string = this._toolsService.toolResultToString[toolName](toolResult as any)
this._addMessageToThread(threadId, { role: 'tool', name: tool.name, params: tool.args, content: string, displayContent: string, })
shouldContinue = true
}
}
}
res_()
},
onError: (error) => {
this._finishStreamingTextMessage(threadId, this.streamState[threadId]?.messageSoFar ?? '', error)
res_()
},
})
if (llmCancelToken === null) return
this._setStreamState(threadId, { streamingToken: llmCancelToken })
await awaitable
}
while (shouldContinue);
// const llmCancelToken = this._llmMessageService.sendLLMMessage({
// messagesType: 'chatMessages',
// logging: { loggingName: 'Chat' },
// useProviderFor: 'Ctrl+L',
// messages: [
// { role: 'system', content: chat_systemMessage },
// ...this.getCurrentThread().messages.map(m => ({ role: m.role, content: m.content || '(empty model output)' })),
// ],
// onText: ({ newText, fullText }) => {
// this._setStreamState(threadId, { messageSoFar: fullText })
// },
// onFinalMessage: ({ fullText: content }) => {
// this._finishStreaming(threadId, content)
// },
// onError: (error) => {
// this._finishStreaming(threadId, this.streamState[threadId]?.messageSoFar ?? '', error)
// },
// })
// if (llmCancelToken === null) return
// this._setStreamState(threadId, { streamingToken: llmCancelToken })
}
async editUserMessageAndStreamResponse(userMessage: string, messageIdx: number) {
const thread = this.getCurrentThread()
@ -284,58 +399,18 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
}, true)
// stream the edit
// re-add the message and stream it
this.addUserMessageAndStreamResponse(userMessage, messageToReplace.staging)
}
async addUserMessageAndStreamResponse(userMessage: string, stagingOverride?: StagingInfo | null) {
const thread = this.getCurrentThread()
const threadId = thread.id
let threadStaging = thread.staging
const currStaging = stagingOverride ?? threadStaging ?? defaultStaging // don't use _useFocusedStagingState to avoid race conditions with focusing
const { selections: currSelns, } = currStaging
// add user's message to chat history
const instructions = userMessage
const content = await chat_userMessage(instructions, currSelns, this._modelService)
const userHistoryElt: ChatMessage = { role: 'user', content: content, displayContent: instructions, selections: currSelns, staging: null, }
this._addMessageToThread(threadId, userHistoryElt)
this._setStreamState(threadId, { error: undefined })
const llmCancelToken = this._llmMessageService.sendLLMMessage({
messagesType: 'chatMessages',
logging: { loggingName: 'Chat' },
useProviderFor: 'Ctrl+L',
messages: [
{ role: 'system', content: chat_systemMessage },
...this.getCurrentThread().messages.map(m => ({ role: m.role, content: m.content || '(empty model output)' })),
],
onText: ({ newText, fullText }) => {
this._setStreamState(threadId, { messageSoFar: fullText })
},
onFinalMessage: ({ fullText: content }) => {
this.finishStreaming(threadId, content)
},
onError: (error) => {
this.finishStreaming(threadId, this.streamState[threadId]?.messageSoFar ?? '', error)
},
})
if (llmCancelToken === null) return
this._setStreamState(threadId, { streamingToken: llmCancelToken })
}
cancelStreaming(threadId: string) {
const llmCancelToken = this.streamState[threadId]?.streamingToken
if (llmCancelToken !== undefined) this._llmMessageService.abort(llmCancelToken)
this.finishStreaming(threadId, this.streamState[threadId]?.messageSoFar ?? '')
this._finishStreamingTextMessage(threadId, this.streamState[threadId]?.messageSoFar ?? '')
}
dismissStreamError(threadId: string): void {
@ -475,7 +550,7 @@ class ChatThreadService extends Disposable implements IChatThreadService {
}
// gets `staging` and `setStaging` of the currently focused element, given the index of the currently selected message (or undefined if no message is selected)
_useFocusedStagingState(messageIdx?: number | undefined) {
useFocusedStagingState(messageIdx?: number | undefined) {
const defaultStaging = { isBeingEdited: false, selections: [], text: '' }

View file

@ -42,7 +42,6 @@ import { ILLMMessageService } from '../common/llmMessageService.js';
import { LLMChatMessage, _InternalLLMChatMessage, errorDetails } from '../common/llmMessageTypes.js';
import { IMetricsService } from '../common/metricsService.js';
import { VSReadFile } from './helpers/readFile.js';
import { voidTools } from '../common/toolsService.js';
const configOfBG = (color: Color) => {
return { dark: color, light: color, hcDark: color, hcLight: color, }
@ -1140,40 +1139,6 @@ class EditCodeService extends Disposable implements IEditCodeService {
async startAgent(queryStr: string) {
// agent loop
const messages: LLMChatMessage[] = []
while (true) {
await new Promise((res, rej) => {
this._llmMessageService.sendLLMMessage({
messagesType: 'chatMessages',
tools: [voidTools['read_file']],
useProviderFor: 'Apply',
logging: { loggingName: `Agent` },
messages,
onText: ({ fullText }) => {
},
onFinalMessage: async ({ fullText, tools }) => {
res(tools)
},
onError: (e) => {
},
})
})
}
}
stopAgent() {
}
public startApplying(opts: StartApplyingOpts) {

View file

@ -11,7 +11,11 @@ export const VSReadFile = async (modelService: IModelService, uri: URI): Promise
}
export const VSReadFileRaw = async (fileService: IFileService, uri: URI) => {
const res = await fileService.readFile(uri)
const str = res.value.toString()
return str
try {
const res = await fileService.readFile(uri)
const str = res.value.toString()
return str
} catch (e) {
return null
}
}

View file

@ -551,7 +551,7 @@ const ChatBubble = ({ chatMessage, isLoading, messageIdx }: { chatMessage: ChatM
const chatThreadsService = accessor.get('IChatThreadService')
// edit mode state
const [staging, setStaging] = chatThreadsService._useFocusedStagingState(messageIdx)
const [staging, setStaging] = chatThreadsService.useFocusedStagingState(messageIdx)
const mode: ChatBubbleMode = staging.isBeingEdited ? 'edit' : 'display'
const [isFocused, setIsFocused] = useState(false)
const [isHovered, setIsHovered] = useState(false)
@ -682,6 +682,9 @@ const ChatBubble = ({ chatMessage, isLoading, messageIdx }: { chatMessage: ChatM
chatbubbleContents = <ChatMarkdownRender string={chatMessage.displayContent ?? ''} chatMessageLocation={chatMessageLocation} />
}
else if (role === 'tool'){
chatbubbleContents = chatMessage.name
}
return <div
// align chatbubble accoridng to role
@ -765,7 +768,7 @@ export const SidebarChat = () => {
const currentThread = chatThreadsService.getCurrentThread()
const previousMessages = currentThread?.messages ?? []
const [staging, setStaging] = chatThreadsService._useFocusedStagingState()
const [staging, setStaging] = chatThreadsService.useFocusedStagingState()
// stream state
const currThreadStreamState = useChatThreadsStreamState(chatThreadsState.currentThreadId)
@ -822,7 +825,7 @@ export const SidebarChat = () => {
const prevMessagesHTML = useMemo(() => {
return previousMessages.map((message, i) =>
<ChatBubble key={`${message.displayContent}-${i}`} chatMessage={message} messageIdx={i} />
<ChatBubble key={i} chatMessage={message} messageIdx={i} />
)
}, [previousMessages])
@ -836,6 +839,7 @@ export const SidebarChat = () => {
const messagesHTML = <ScrollToBottomContainer
key={currentThread.id} // force rerender on all children if id changes
scrollContainerRef={scrollContainerRef}
className={`
w-full h-auto

View file

@ -135,7 +135,7 @@ registerAction2(class extends Action2 {
const chatThreadService = accessor.get(IChatThreadService)
const focusedMessageIdx = chatThreadService.getFocusedMessageIdx()
const [staging, setStaging] = chatThreadService._useFocusedStagingState(focusedMessageIdx)
const [staging, setStaging] = chatThreadService.useFocusedStagingState(focusedMessageIdx)
const selections = staging.selections || []
const setSelections = (s: StagingSelectionItem[]) => setStaging({ ...staging, selections: s })

View file

@ -27,7 +27,7 @@ export type OnError = (p: { message: string, fullError: Error | null }) => void
export type AbortRef = { current: (() => void) | null }
export type LLMChatMessage = {
role: 'system' | 'user' | 'assistant';
role: 'system' | 'user' | 'assistant' | 'tool';
content: string;
}

View file

@ -29,7 +29,7 @@ const paginationHelper = {
param: { pageNumber: { type: 'number', description: 'The page number (optional, default is 1).' }, }
} as const
export const voidTools: { [name: string]: InternalToolInfo } = {
export const voidTools = {
read_file: {
name: 'read_file',
description: 'Returns file contents of a given URI.',
@ -73,16 +73,22 @@ export const voidTools: { [name: string]: InternalToolInfo } = {
// description: 'Searches files semantically for the given string query.',
// // RAG
// },
}
} satisfies { [name: string]: InternalToolInfo }
export type ToolName = keyof typeof voidTools
type ToolParamNames<T extends ToolName> = keyof typeof voidTools[T]['params']
type ToolParamsObj<T extends ToolName> = { [paramName in ToolParamNames<T>]: unknown }
export type ToolParamNames<T extends ToolName> = keyof typeof voidTools[T]['params']
export type ToolParamsObj<T extends ToolName> = { [paramName in ToolParamNames<T>]: unknown }
export type ToolCallReturnType<T extends ToolName>
= T extends 'read_file' ? Promise<string>
: T extends 'list_dir' ? Promise<string>
: T extends 'pathname_search' ? Promise<string | URI[]>
: T extends 'search' ? Promise<string | URI[]>
: never
export type ToolFns = { [T in ToolName]: (p: string) => ToolCallReturnType<T> }
export type ToolResultToString = { [T in ToolName]: (result: Awaited<ToolCallReturnType<T>>) => string }
async function generateDirectoryTreeMd(fileService: IFileService, rootURI: URI): Promise<string> {
@ -110,17 +116,21 @@ const validateURI = (uriStr: unknown) => {
return uri
}
export interface IToolService {
export interface IToolsService {
readonly _serviceBrand: undefined;
toolFns: ToolFns;
toolResultToString: ToolResultToString;
}
export const IToolService = createDecorator<IToolService>('ToolService');
export const IToolsService = createDecorator<IToolsService>('ToolsService');
export class ToolService implements IToolService {
export class ToolsService implements IToolsService {
readonly _serviceBrand: undefined;
public toolFns
public toolFns: ToolFns
public toolResultToString: ToolResultToString
constructor(
@IFileService fileService: IFileService,
@ -132,29 +142,56 @@ export class ToolService implements IToolService {
const queryBuilder = instantiationService.createInstance(QueryBuilder);
const parseObj = <T extends ToolName,>(s: string): { [s: string]: unknown } | null => {
try {
const o = JSON.parse(s)
return o
}
catch (e) {
return null
}
}
const invalidToolParamMsg = '(LLM parameter format was invalid for this tool)'
this.toolFns = {
read_file: async ({ uri: uriStr }: ToolParamsObj<'read_file'>) => {
read_file: async (s: string) => {
const o = parseObj(s)
if (!o) return invalidToolParamMsg
const { uri: uriStr } = o
const uri = validateURI(uriStr)
const fileContents = await VSReadFileRaw(fileService, uri)
return fileContents ?? '(could not read file)'
return fileContents ?? invalidToolParamMsg
},
list_dir: async ({ uri: uriStr }: ToolParamsObj<'list_dir'>) => {
list_dir: async (s: string) => {
const o = parseObj(s)
if (!o) return invalidToolParamMsg
const { uri: uriStr } = o
const uri = validateURI(uriStr)
// TODO!!!! check to make sure in workspace
// TODO check to make sure is not gitignored
const treeStr = await generateDirectoryTreeMd(fileService, uri)
return treeStr
},
pathname_search: async ({ query: queryStr }: ToolParamsObj<'pathname_search'>) => {
if (typeof queryStr !== 'string') return '(Error: query was not a string)'
pathname_search: async (s: string) => {
const o = parseObj(s)
if (!o) return invalidToolParamMsg
const { query: queryStr } = o
if (typeof queryStr !== 'string') return 'Error: query was not a string'
const query = queryBuilder.file(workspaceContextService.getWorkspace().folders.map(f => f.uri), { filePattern: queryStr, })
const data = await searchService.fileSearch(query, CancellationToken.None)
const URIs = data.results.map(({ resource, results }) => resource.fsPath)
const URIs = data.results.map(({ resource, results }) => resource)
return URIs
},
search: async ({ query: queryStr }: ToolParamsObj<'search'>) => {
if (typeof queryStr !== 'string') return '(Error: query was not a string)'
search: async (s: string) => {
const o = parseObj(s)
if (!o) return '(could not search)'
const { query: queryStr } = o
if (typeof queryStr !== 'string') return 'Error: query was not a string'
const query = queryBuilder.text({ pattern: queryStr, }, workspaceContextService.getWorkspace().folders.map(f => f.uri))
const data = await searchService.textSearch(query, CancellationToken.None)
@ -164,6 +201,23 @@ export class ToolService implements IToolService {
}
this.toolResultToString = {
read_file: (URIs) => {
return URIs
},
list_dir: (URIs) => {
return URIs
},
pathname_search: (URIs) => {
if (typeof URIs === 'string') return URIs
return URIs.map(uri => uri.fsPath).join('\n')
},
search: (URIs) => {
if (typeof URIs === 'string') return URIs
return URIs.map(uri => uri.fsPath).join('\n')
},
}
}
@ -171,5 +225,5 @@ export class ToolService implements IToolService {
}
registerSingleton(IToolService, ToolService, InstantiationType.Eager);
registerSingleton(IToolsService, ToolsService, InstantiationType.Eager);

View file

@ -6,7 +6,7 @@
import Anthropic from '@anthropic-ai/sdk';
import { _InternalSendLLMChatMessageFnType } from '../../common/llmMessageTypes.js';
import { anthropicMaxPossibleTokens } from '../../common/voidSettingsTypes.js';
import { InternalToolInfo, voidTools } from '../../common/toolsService.js';
import { InternalToolInfo } from '../../common/toolsService.js';
@ -28,7 +28,7 @@ export const toAnthropicTool = (toolInfo: InternalToolInfo) => {
export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter }) => {
export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, tools }) => {
const thisConfig = settingsOfProvider.anthropic
@ -45,8 +45,8 @@ export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages,
messages: messages,
model: modelName,
max_tokens: maxTokens,
tools: [toAnthropicTool(voidTools.list_dir)]
});
tools: tools?.map(tool => toAnthropicTool(tool))
})
// when receive text