From f330723fdb4bf4d74278270be84c4e869cb21814 Mon Sep 17 00:00:00 2001 From: Rob Lourens Date: Thu, 10 Oct 2024 17:47:39 -0700 Subject: [PATCH] Get tool calling working across multiple requests --- chat-tools-sample/src/extension.ts | 40 ++++++++++++++++++++++++--- chat-tools-sample/src/toolsPrompt.tsx | 21 ++++++++------ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/chat-tools-sample/src/extension.ts b/chat-tools-sample/src/extension.ts index 41acf62d..37a4af32 100644 --- a/chat-tools-sample/src/extension.ts +++ b/chat-tools-sample/src/extension.ts @@ -188,6 +188,29 @@ df\`\`\``), context.subscriptions.push(toolUser); } +interface TsxToolUserMetadata { + toolCallRounds: ToolCallRound[]; + toolCallResults: Record; +} + +export function isTsxToolUserMetadata(obj: unknown): obj is TsxToolUserMetadata { + if (typeof obj !== 'object' || obj === null) { + return false; + } + + const metadata = obj as TsxToolUserMetadata; + if (!Array.isArray(metadata.toolCallRounds)) { + return false; + } + + if (typeof metadata.toolCallResults !== 'object' || metadata.toolCallResults === null) { + return false; + } + + // If you ever change the shape of ToolCallRound, remember to make this stricter + return true; +} + function registerChatParticipant2(context: vscode.ExtensionContext) { const handler: vscode.ChatRequestHandler = async (request: vscode.ChatRequest, chatContext: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken) => { const models = await vscode.lm.selectChatModels({ @@ -216,13 +239,13 @@ function registerChatParticipant2(context: vscode.ExtensionContext) { context: chatContext, request, toolCallRounds: [], - toolCallResults: new Map() + toolCallResults: {} }, { modelMaxPromptTokens: model.maxInputTokens }, model) const toolReferences = [...request.toolReferences]; - const accumulatedToolCalls = new Map(); + const accumulatedToolResults: Record = {}; const toolCallRounds: ToolCallRound[] = []; const runWithFunctions = async (): Promise => { const requestedTool = toolReferences.shift(); @@ -261,7 +284,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) { context: chatContext, request, toolCallRounds, - toolCallResults: accumulatedToolCalls + toolCallResults: accumulatedToolResults }, { modelMaxPromptTokens: model.maxInputTokens }, model)); @@ -269,7 +292,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) { const toolResultMetadata = result.metadatas.getAll(ToolResultMetadata) if (toolResultMetadata?.length) { // TODO flatten - toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolCalls.set(key, value))); + toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolResults[key] = value)); } // RE-enter @@ -278,6 +301,15 @@ function registerChatParticipant2(context: vscode.ExtensionContext) { }; await runWithFunctions(); + + return { + metadata: { + toolInfo: { + toolCallResults: accumulatedToolResults, + toolCallRounds + } satisfies TsxToolUserMetadata + } + } }; const toolUser = vscode.chat.createChatParticipant('chat-tools-sample.tools2', handler); diff --git a/chat-tools-sample/src/toolsPrompt.tsx b/chat-tools-sample/src/toolsPrompt.tsx index f576701d..226b3273 100644 --- a/chat-tools-sample/src/toolsPrompt.tsx +++ b/chat-tools-sample/src/toolsPrompt.tsx @@ -12,6 +12,7 @@ import { } from '@vscode/prompt-tsx'; import { Chunk, ToolMessage, ToolResult } from '@vscode/prompt-tsx/dist/base/promptElements'; import * as vscode from 'vscode'; +import { isTsxToolUserMetadata } from './extension'; export interface ToolCallRound { response: string; @@ -22,7 +23,7 @@ export interface ToolUserProps extends BasePromptElementProps { request: vscode.ChatRequest; context: vscode.ChatContext; toolCallRounds: ToolCallRound[]; - toolCallResults: Map; + toolCallResults: Record; } export class ToolUserPrompt extends PromptElement { @@ -64,8 +65,8 @@ export class ToolUserPrompt extends PromptElement { interface ToolCallsProps extends BasePromptElementProps { toolCallRounds: ToolCallRound[]; - toolCallResults: Map; - toolInvocationToken: vscode.ChatParticipantToolToken; + toolCallResults: Record; + toolInvocationToken: vscode.ChatParticipantToolToken | undefined; } const agentSupportedContentTypes = [promptTsxContentType, 'text/plain']; @@ -79,6 +80,7 @@ class ToolCalls extends PromptElement { return <> {await Promise.all(this.props.toolCallRounds.map(round => this.renderOneToolCallRound(round, sizing)))} + Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer. } @@ -86,11 +88,10 @@ class ToolCalls extends PromptElement { // TODO- prompt-tsx export this type? // TODO- at what level do the parameters get stringified? const assistantToolCalls: any[] = round.toolCalls.map(tc => ({ type: 'function', function: { name: tc.name, arguments: JSON.stringify(tc.parameters) }, id: tc.toolCallId })); - // TODO@prompt-tsx- don't remove "empty" assistant messages! const toolResultMap = new Map(); const budget = Math.floor(sizing.tokenBudget / round.toolCalls.length); - // TODO@prompt-tsx- this would be a bit easier with a ToolCall element, but we can only return one instance of a metadata right now + // TODO refactor into multiple elements const toolCallSizing: PromptSizing = { ...sizing, tokenBudget: budget, @@ -98,7 +99,7 @@ class ToolCalls extends PromptElement { return - {round.response || 'TODO'} + {round.response} {await Promise.all(round.toolCalls.map(async toolCall => { const result = await this.renderOneToolCall(toolCall, toolCallSizing, this.props.toolInvocationToken); if (result.toolResult) { @@ -106,7 +107,6 @@ class ToolCalls extends PromptElement { } return result.message; }))} - Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer. ; } @@ -128,7 +128,7 @@ class ToolCalls extends PromptElement { countTokens: async (content: string) => sizing.countTokens(content), }; - const toolResult = this.props.toolCallResults.get(toolCall.toolCallId) ?? + const toolResult = this.props.toolCallResults[toolCall.toolCallId] ?? await vscode.lm.invokeTool(toolCall.name, { parameters: toolCall.parameters, requestedContentTypes: [contentType], toolInvocationToken: toolInvocationToken, tokenOptions }, dummyCancellationToken); const message = {contentType === 'text/plain' ? @@ -168,6 +168,11 @@ class History extends PromptElement { ); } else if (message instanceof vscode.ChatResponseTurn) { + const toolInfo = message.result.metadata?.toolInfo; + if (isTsxToolUserMetadata(toolInfo)) { + return + } + return ( {chatResponseToString(message)}