Get tool calling working across multiple requests

This commit is contained in:
Rob Lourens
2024-10-10 17:47:39 -07:00
parent ddc6407f74
commit f330723fdb
2 changed files with 49 additions and 12 deletions

View File

@ -188,6 +188,29 @@ df\`\`\``),
context.subscriptions.push(toolUser);
}
interface TsxToolUserMetadata {
toolCallRounds: ToolCallRound[];
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
}
export function isTsxToolUserMetadata(obj: unknown): obj is TsxToolUserMetadata {
if (typeof obj !== 'object' || obj === null) {
return false;
}
const metadata = obj as TsxToolUserMetadata;
if (!Array.isArray(metadata.toolCallRounds)) {
return false;
}
if (typeof metadata.toolCallResults !== 'object' || metadata.toolCallResults === null) {
return false;
}
// If you ever change the shape of ToolCallRound, remember to make this stricter
return true;
}
function registerChatParticipant2(context: vscode.ExtensionContext) {
const handler: vscode.ChatRequestHandler = async (request: vscode.ChatRequest, chatContext: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken) => {
const models = await vscode.lm.selectChatModels({
@ -216,13 +239,13 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
context: chatContext,
request,
toolCallRounds: [],
toolCallResults: new Map()
toolCallResults: {}
},
{ modelMaxPromptTokens: model.maxInputTokens },
model)
const toolReferences = [...request.toolReferences];
const accumulatedToolCalls = new Map<string, vscode.LanguageModelToolResult>();
const accumulatedToolResults: Record<string, vscode.LanguageModelToolResult> = {};
const toolCallRounds: ToolCallRound[] = [];
const runWithFunctions = async (): Promise<void> => {
const requestedTool = toolReferences.shift();
@ -261,7 +284,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
context: chatContext,
request,
toolCallRounds,
toolCallResults: accumulatedToolCalls
toolCallResults: accumulatedToolResults
},
{ modelMaxPromptTokens: model.maxInputTokens },
model));
@ -269,7 +292,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
const toolResultMetadata = result.metadatas.getAll(ToolResultMetadata)
if (toolResultMetadata?.length) {
// TODO flatten
toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolCalls.set(key, value)));
toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolResults[key] = value));
}
// RE-enter
@ -278,6 +301,15 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
};
await runWithFunctions();
return {
metadata: {
toolInfo: {
toolCallResults: accumulatedToolResults,
toolCallRounds
} satisfies TsxToolUserMetadata
}
}
};
const toolUser = vscode.chat.createChatParticipant('chat-tools-sample.tools2', handler);

View File

@ -12,6 +12,7 @@ import {
} from '@vscode/prompt-tsx';
import { Chunk, ToolMessage, ToolResult } from '@vscode/prompt-tsx/dist/base/promptElements';
import * as vscode from 'vscode';
import { isTsxToolUserMetadata } from './extension';
export interface ToolCallRound {
response: string;
@ -22,7 +23,7 @@ export interface ToolUserProps extends BasePromptElementProps {
request: vscode.ChatRequest;
context: vscode.ChatContext;
toolCallRounds: ToolCallRound[];
toolCallResults: Map<string, vscode.LanguageModelToolResult>;
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
}
export class ToolUserPrompt extends PromptElement<ToolUserProps, void> {
@ -64,8 +65,8 @@ export class ToolUserPrompt extends PromptElement<ToolUserProps, void> {
interface ToolCallsProps extends BasePromptElementProps {
toolCallRounds: ToolCallRound[];
toolCallResults: Map<string, vscode.LanguageModelToolResult>;
toolInvocationToken: vscode.ChatParticipantToolToken;
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
toolInvocationToken: vscode.ChatParticipantToolToken | undefined;
}
const agentSupportedContentTypes = [promptTsxContentType, 'text/plain'];
@ -79,6 +80,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
return <>
{await Promise.all(this.props.toolCallRounds.map(round => this.renderOneToolCallRound(round, sizing)))}
<UserMessage priority={100}>Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer.</UserMessage>
</>
}
@ -86,11 +88,10 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
// TODO- prompt-tsx export this type?
// TODO- at what level do the parameters get stringified?
const assistantToolCalls: any[] = round.toolCalls.map(tc => ({ type: 'function', function: { name: tc.name, arguments: JSON.stringify(tc.parameters) }, id: tc.toolCallId }));
// TODO@prompt-tsx- don't remove "empty" assistant messages!
const toolResultMap = new Map<string, vscode.LanguageModelToolResult>();
const budget = Math.floor(sizing.tokenBudget / round.toolCalls.length);
// TODO@prompt-tsx- this would be a bit easier with a ToolCall element, but we can only return one instance of a metadata right now
// TODO refactor into multiple elements
const toolCallSizing: PromptSizing = {
...sizing,
tokenBudget: budget,
@ -98,7 +99,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
return <Chunk>
<meta value={new ToolResultMetadata(toolResultMap)}></meta>
<AssistantMessage toolCalls={assistantToolCalls}>{round.response || 'TODO'}</AssistantMessage>
<AssistantMessage toolCalls={assistantToolCalls}>{round.response}</AssistantMessage>
{await Promise.all(round.toolCalls.map(async toolCall => {
const result = await this.renderOneToolCall(toolCall, toolCallSizing, this.props.toolInvocationToken);
if (result.toolResult) {
@ -106,7 +107,6 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
}
return result.message;
}))}
<UserMessage priority={100}>Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer.</UserMessage>
</Chunk>;
}
@ -128,7 +128,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
countTokens: async (content: string) => sizing.countTokens(content),
};
const toolResult = this.props.toolCallResults.get(toolCall.toolCallId) ??
const toolResult = this.props.toolCallResults[toolCall.toolCallId] ??
await vscode.lm.invokeTool(toolCall.name, { parameters: toolCall.parameters, requestedContentTypes: [contentType], toolInvocationToken: toolInvocationToken, tokenOptions }, dummyCancellationToken);
const message = <ToolMessage toolCallId={toolCall.toolCallId}>
{contentType === 'text/plain' ?
@ -168,6 +168,11 @@ class History extends PromptElement<HistoryProps, void> {
</>
);
} else if (message instanceof vscode.ChatResponseTurn) {
const toolInfo = message.result.metadata?.toolInfo;
if (isTsxToolUserMetadata(toolInfo)) {
return <ToolCalls toolCallResults={toolInfo.toolCallResults} toolCallRounds={toolInfo.toolCallRounds} toolInvocationToken={undefined}></ToolCalls>
}
return (
<AssistantMessage>
{chatResponseToString(message)}