update types for tool support

This commit is contained in:
Andrew Pareles 2025-02-14 01:05:10 -08:00
parent 0bcd88dad6
commit 9cfcf396c1
6 changed files with 53 additions and 52 deletions

View file

@ -3,6 +3,7 @@
* Licensed under the Apache License, Version 2.0. See LICENSE.txt for more information.
*--------------------------------------------------------------------------------------*/
import { InternalToolInfo } from './toolsService.js'
import { FeatureName, ProviderName, SettingsOfProvider } from './voidSettingsTypes.js'
@ -44,9 +45,11 @@ type _InternalSendFIMMessage = {
type SendLLMType = {
messagesType: 'chatMessages';
messages: LLMChatMessage[];
tools?: InternalToolInfo[];
} | {
messagesType: 'FIMMessage';
messages: _InternalSendFIMMessage;
tools?: undefined;
}
// service types
@ -96,6 +99,8 @@ export type _InternalSendLLMChatMessageFnType = (
modelName: string;
_setAborter: (aborter: () => void) => void;
tools?: InternalToolInfo[],
messages: _InternalLLMChatMessage[];
}
) => void

View file

@ -15,6 +15,7 @@ import { ISearchService } from '../../../../workbench/services/search/common/sea
// we do this using Anthropic's style and convert to OpenAI style later
export type InternalToolInfo = {
name: string,
description: string,
params: {
[paramName: string]: { type: string, description: string | undefined } // name -> type
@ -23,13 +24,14 @@ export type InternalToolInfo = {
}
// helper
const pagination = {
const paginationHelper = {
desc: `Very large results may be paginated (indicated in the result). Pagination fails gracefully if out of bounds or invalid page number.`,
param: { pageNumber: { type: 'number', description: 'The page number (optional, default is 1).' }, }
} as const
export const contextTools = {
read_file: {
name: 'read_file',
description: 'Returns file contents of a given URI.',
params: {
uri: { type: 'string', description: undefined },
@ -38,28 +40,31 @@ export const contextTools = {
},
list_dir: {
description: `Returns all file names and folder names in a given URI. ${pagination.desc}`,
name: 'list_dir',
description: `Returns all file names and folder names in a given URI. ${paginationHelper.desc}`,
params: {
uri: { type: 'string', description: undefined },
...pagination.param
...paginationHelper.param
},
required: ['uri'],
},
pathname_search: {
description: `Returns all pathnames that match a given grep query. You should use this when looking for a file with a specific name or path. This does NOT search file content. ${pagination.desc}`,
name: 'pathname_search',
description: `Returns all pathnames that match a given grep query. You should use this when looking for a file with a specific name or path. This does NOT search file content. ${paginationHelper.desc}`,
params: {
query: { type: 'string', description: undefined },
...pagination.param,
...paginationHelper.param,
},
required: ['query']
},
search: {
description: `Returns all code excerpts containing the given string or grep query. This does NOT search pathname. As a follow-up, you may want to use read_file to view the full file contents of the results. ${pagination.desc}`,
name: 'search',
description: `Returns all code excerpts containing the given string or grep query. This does NOT search pathname. As a follow-up, you may want to use read_file to view the full file contents of the results. ${paginationHelper.desc}`,
params: {
query: { type: 'string', description: undefined },
...pagination.param,
...paginationHelper.param,
},
required: ['query'],
},
@ -69,26 +74,18 @@ export const contextTools = {
// // RAG
// },
} as const satisfies { [name: string]: InternalToolInfo }
}
export type ContextToolName = keyof typeof contextTools
type ContextToolParamNames<T extends ContextToolName> = keyof typeof contextTools[T]['params']
type ContextToolParams<T extends ContextToolName> = { [paramName in ContextToolParamNames<T>]: unknown }
type AllContextToolCallFns = {
[ToolName in ContextToolName]: ((p: (ContextToolParams<ToolName>)) => Promise<string>)
}
// TODO check to make sure in workspace
// TODO check to make sure is not gitignored
async function generateDirectoryTreeMd(fileService: IFileService, rootURI: URI): Promise<string> {
let output = ''
function traverseChildren(children: IFileStat[], depth: number) {
@ -116,7 +113,6 @@ const validateURI = (uriStr: unknown) => {
export interface IToolService {
readonly _serviceBrand: undefined;
callContextTool: <T extends ContextToolName>(toolName: T, params: ContextToolParams<T>) => Promise<string>
}
export const IToolService = createDecorator<IToolService>('ToolService');
@ -125,7 +121,7 @@ export class ToolService implements IToolService {
readonly _serviceBrand: undefined;
contextToolCallFns: AllContextToolCallFns
public contextToolCallFns
constructor(
@IFileService fileService: IFileService,
@ -138,31 +134,33 @@ export class ToolService implements IToolService {
const queryBuilder = instantiationService.createInstance(QueryBuilder);
this.contextToolCallFns = {
read_file: async ({ uri: uriStr }) => {
read_file: async ({ uri: uriStr }: ContextToolParams<'read_file'>) => {
const uri = validateURI(uriStr)
const fileContents = await VSReadFileRaw(fileService, uri)
return fileContents ?? '(could not read file)'
},
list_dir: async ({ uri: uriStr }) => {
list_dir: async ({ uri: uriStr }: ContextToolParams<'list_dir'>) => {
const uri = validateURI(uriStr)
// TODO!!!! check to make sure in workspace
// TODO check to make sure is not gitignored
const treeStr = await generateDirectoryTreeMd(fileService, uri)
return treeStr
},
pathname_search: async ({ query: queryStr }) => {
pathname_search: async ({ query: queryStr }: ContextToolParams<'pathname_search'>) => {
if (typeof queryStr !== 'string') return '(Error: query was not a string)'
const query = queryBuilder.file(workspaceContextService.getWorkspace().folders.map(f => f.uri), { filePattern: queryStr, });
const data = await searchService.fileSearch(query, CancellationToken.None);
const str = data.results.map(({ resource, results }) => resource.fsPath).join('\n')
return str
const URIs = data.results.map(({ resource, results }) => resource.fsPath)
return URIs
},
search: async ({ query: queryStr }) => {
search: async ({ query: queryStr }: ContextToolParams<'search'>) => {
if (typeof queryStr !== 'string') return '(Error: query was not a string)'
const query = queryBuilder.text({ pattern: queryStr, }, workspaceContextService.getWorkspace().folders.map(f => f.uri));
const data = await searchService.textSearch(query, CancellationToken.None);
const str = data.results.map(({ resource, results }) => resource)
return str as any
const URIs = data.results.map(({ resource, results }) => resource)
return URIs
},
}
@ -172,12 +170,6 @@ export class ToolService implements IToolService {
}
callContextTool: IToolService['callContextTool'] = (toolName, params) => {
return this.contextToolCallFns[toolName](params)
}
}
registerSingleton(IToolService, ToolService, InstantiationType.Eager);

View file

@ -11,10 +11,10 @@ import { InternalToolInfo } from '../../common/toolsService.js';
export const toAnthropicTool = (toolName: string, toolInfo: InternalToolInfo) => {
const { description, params, required } = toolInfo
export const toAnthropicTool = (toolInfo: InternalToolInfo) => {
const { name, description, params, required } = toolInfo
return {
name: toolName,
name: name,
description: description,
input_schema: {
type: 'object',
@ -45,6 +45,7 @@ export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages,
messages: messages,
model: modelName,
max_tokens: maxTokens,
// tools: [toAnthropicTool(contextTools.list_dir)]
});
@ -60,12 +61,9 @@ export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages,
if (e.type === 'content_block_start') {
if (e.content_block.type !== 'tool_use') return
const index = e.index
const tool = e.content_block
if (!toolCallOfIndex[index])
toolCallOfIndex[index] = { name: '', args: '' }
toolCallOfIndex[index].name += tool.name ?? ''
toolCallOfIndex[index].args += tool.input ?? ''
if (!toolCallOfIndex[index]) toolCallOfIndex[index] = { name: '', args: '' }
toolCallOfIndex[index].name += e.content_block.name ?? ''
toolCallOfIndex[index].args += e.content_block.input ?? ''
}
else if (e.type === 'content_block_delta') {
if (e.delta.type !== 'input_json_delta') return
@ -79,7 +77,7 @@ export const sendAnthropicChat: _InternalSendLLMChatMessageFnType = ({ messages,
stream.on('finalMessage', (response) => {
// stringify the response's content
const content = response.content.map(c => c.type === 'text' ? c.text : '').join('\n')
const tools = response.content.map(c => c.type === 'tool_use' ? { name: c.name, input: c.input } : null)
const tools = response.content.map(c => c.type === 'tool_use' ? { name: c.name, input: c.input } : null).filter(c => !!c)
console.log("TOOLS!!!!", typeof tools[0]?.input, JSON.stringify(tools, null, 2))

View file

@ -99,6 +99,11 @@ export const sendOllamaChat: _InternalSendLLMChatMessageFnType = ({ messages, on
// iterate through the stream
for await (const chunk of stream) {
const newText = chunk.message.content;
// chunk.message.tool_calls[0].function.arguments
fullText += newText;
onText({ newText, fullText });
}

View file

@ -14,12 +14,12 @@ import { InternalToolInfo } from '../../common/toolsService.js';
// prompting - https://platform.openai.com/docs/guides/reasoning#advice-on-prompting
export const toOpenAITool = (toolName: string, toolInfo: InternalToolInfo) => {
const { description, params, required } = toolInfo
export const toOpenAITool = (toolInfo: InternalToolInfo) => {
const { name, description, params, required } = toolInfo
return {
type: 'function',
function: {
name: toolName,
name: name,
description: description,
parameters: {
type: 'object',

View file

@ -61,6 +61,7 @@ export const sendLLMMessage = ({
settingsOfProvider,
providerName,
modelName,
tools,
}: SendLLMMessageParams,
metricsService: IMetricsService
@ -141,27 +142,27 @@ export const sendLLMMessage = ({
case 'deepseek':
case 'openAICompatible':
if (messagesType === 'FIMMessage') sendOpenAIFIM({ messages: messages_, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendOpenAIChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendOpenAIChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
case 'ollama':
if (messagesType === 'FIMMessage') sendOllamaFIM({ messages: messages_, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName })
else /* */ sendOllamaChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName })
if (messagesType === 'FIMMessage') sendOllamaFIM({ messages: messages_, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendOllamaChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
case 'anthropic':
if (messagesType === 'FIMMessage') onFinalMessage({ fullText: 'TODO - Anthropic FIM' })
else /* */ sendAnthropicChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendAnthropicChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
case 'gemini':
if (messagesType === 'FIMMessage') onFinalMessage({ fullText: 'TODO - Gemini FIM' })
else /* */ sendGeminiChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendGeminiChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
case 'groq':
if (messagesType === 'FIMMessage') onFinalMessage({ fullText: 'TODO - Groq FIM' })
else /* */ sendGroqChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendGroqChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
case 'mistral':
if (messagesType === 'FIMMessage') onFinalMessage({ fullText: 'TODO - Mistral FIM' })
else /* */ sendMistralChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName });
else /* */ sendMistralChat({ messages: messagesArr, onText, onFinalMessage, onError, settingsOfProvider, modelName, _setAborter, providerName, tools });
break;
default:
onError({ message: `Error: Void provider was "${providerName}", which is not recognized.`, fullError: null })