mirror of
https://github.com/microsoft/vscode-extension-samples.git
synced 2026-06-13 07:10:26 +08:00
Get tool calling working across multiple requests
This commit is contained in:
@ -188,6 +188,29 @@ df\`\`\``),
|
||||
context.subscriptions.push(toolUser);
|
||||
}
|
||||
|
||||
interface TsxToolUserMetadata {
|
||||
toolCallRounds: ToolCallRound[];
|
||||
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
|
||||
}
|
||||
|
||||
export function isTsxToolUserMetadata(obj: unknown): obj is TsxToolUserMetadata {
|
||||
if (typeof obj !== 'object' || obj === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const metadata = obj as TsxToolUserMetadata;
|
||||
if (!Array.isArray(metadata.toolCallRounds)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeof metadata.toolCallResults !== 'object' || metadata.toolCallResults === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If you ever change the shape of ToolCallRound, remember to make this stricter
|
||||
return true;
|
||||
}
|
||||
|
||||
function registerChatParticipant2(context: vscode.ExtensionContext) {
|
||||
const handler: vscode.ChatRequestHandler = async (request: vscode.ChatRequest, chatContext: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken) => {
|
||||
const models = await vscode.lm.selectChatModels({
|
||||
@ -216,13 +239,13 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
|
||||
context: chatContext,
|
||||
request,
|
||||
toolCallRounds: [],
|
||||
toolCallResults: new Map()
|
||||
toolCallResults: {}
|
||||
},
|
||||
{ modelMaxPromptTokens: model.maxInputTokens },
|
||||
model)
|
||||
|
||||
const toolReferences = [...request.toolReferences];
|
||||
const accumulatedToolCalls = new Map<string, vscode.LanguageModelToolResult>();
|
||||
const accumulatedToolResults: Record<string, vscode.LanguageModelToolResult> = {};
|
||||
const toolCallRounds: ToolCallRound[] = [];
|
||||
const runWithFunctions = async (): Promise<void> => {
|
||||
const requestedTool = toolReferences.shift();
|
||||
@ -261,7 +284,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
|
||||
context: chatContext,
|
||||
request,
|
||||
toolCallRounds,
|
||||
toolCallResults: accumulatedToolCalls
|
||||
toolCallResults: accumulatedToolResults
|
||||
},
|
||||
{ modelMaxPromptTokens: model.maxInputTokens },
|
||||
model));
|
||||
@ -269,7 +292,7 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
|
||||
const toolResultMetadata = result.metadatas.getAll(ToolResultMetadata)
|
||||
if (toolResultMetadata?.length) {
|
||||
// TODO flatten
|
||||
toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolCalls.set(key, value)));
|
||||
toolResultMetadata.forEach(meta => meta.resultMap.forEach((value, key) => accumulatedToolResults[key] = value));
|
||||
}
|
||||
|
||||
// RE-enter
|
||||
@ -278,6 +301,15 @@ function registerChatParticipant2(context: vscode.ExtensionContext) {
|
||||
};
|
||||
|
||||
await runWithFunctions();
|
||||
|
||||
return {
|
||||
metadata: {
|
||||
toolInfo: {
|
||||
toolCallResults: accumulatedToolResults,
|
||||
toolCallRounds
|
||||
} satisfies TsxToolUserMetadata
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const toolUser = vscode.chat.createChatParticipant('chat-tools-sample.tools2', handler);
|
||||
|
||||
@ -12,6 +12,7 @@ import {
|
||||
} from '@vscode/prompt-tsx';
|
||||
import { Chunk, ToolMessage, ToolResult } from '@vscode/prompt-tsx/dist/base/promptElements';
|
||||
import * as vscode from 'vscode';
|
||||
import { isTsxToolUserMetadata } from './extension';
|
||||
|
||||
export interface ToolCallRound {
|
||||
response: string;
|
||||
@ -22,7 +23,7 @@ export interface ToolUserProps extends BasePromptElementProps {
|
||||
request: vscode.ChatRequest;
|
||||
context: vscode.ChatContext;
|
||||
toolCallRounds: ToolCallRound[];
|
||||
toolCallResults: Map<string, vscode.LanguageModelToolResult>;
|
||||
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
|
||||
}
|
||||
|
||||
export class ToolUserPrompt extends PromptElement<ToolUserProps, void> {
|
||||
@ -64,8 +65,8 @@ export class ToolUserPrompt extends PromptElement<ToolUserProps, void> {
|
||||
|
||||
interface ToolCallsProps extends BasePromptElementProps {
|
||||
toolCallRounds: ToolCallRound[];
|
||||
toolCallResults: Map<string, vscode.LanguageModelToolResult>;
|
||||
toolInvocationToken: vscode.ChatParticipantToolToken;
|
||||
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
|
||||
toolInvocationToken: vscode.ChatParticipantToolToken | undefined;
|
||||
}
|
||||
|
||||
const agentSupportedContentTypes = [promptTsxContentType, 'text/plain'];
|
||||
@ -79,6 +80,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
|
||||
|
||||
return <>
|
||||
{await Promise.all(this.props.toolCallRounds.map(round => this.renderOneToolCallRound(round, sizing)))}
|
||||
<UserMessage priority={100}>Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer.</UserMessage>
|
||||
</>
|
||||
}
|
||||
|
||||
@ -86,11 +88,10 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
|
||||
// TODO- prompt-tsx export this type?
|
||||
// TODO- at what level do the parameters get stringified?
|
||||
const assistantToolCalls: any[] = round.toolCalls.map(tc => ({ type: 'function', function: { name: tc.name, arguments: JSON.stringify(tc.parameters) }, id: tc.toolCallId }));
|
||||
// TODO@prompt-tsx- don't remove "empty" assistant messages!
|
||||
|
||||
const toolResultMap = new Map<string, vscode.LanguageModelToolResult>();
|
||||
const budget = Math.floor(sizing.tokenBudget / round.toolCalls.length);
|
||||
// TODO@prompt-tsx- this would be a bit easier with a ToolCall element, but we can only return one instance of a metadata right now
|
||||
// TODO refactor into multiple elements
|
||||
const toolCallSizing: PromptSizing = {
|
||||
...sizing,
|
||||
tokenBudget: budget,
|
||||
@ -98,7 +99,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
|
||||
|
||||
return <Chunk>
|
||||
<meta value={new ToolResultMetadata(toolResultMap)}></meta>
|
||||
<AssistantMessage toolCalls={assistantToolCalls}>{round.response || 'TODO'}</AssistantMessage>
|
||||
<AssistantMessage toolCalls={assistantToolCalls}>{round.response}</AssistantMessage>
|
||||
{await Promise.all(round.toolCalls.map(async toolCall => {
|
||||
const result = await this.renderOneToolCall(toolCall, toolCallSizing, this.props.toolInvocationToken);
|
||||
if (result.toolResult) {
|
||||
@ -106,7 +107,6 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
|
||||
}
|
||||
return result.message;
|
||||
}))}
|
||||
<UserMessage priority={100}>Above is the result of calling one or more tools. The user cannot see the results, so you should explain them to the user if referencing them in your answer.</UserMessage>
|
||||
</Chunk>;
|
||||
}
|
||||
|
||||
@ -128,7 +128,7 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
|
||||
countTokens: async (content: string) => sizing.countTokens(content),
|
||||
};
|
||||
|
||||
const toolResult = this.props.toolCallResults.get(toolCall.toolCallId) ??
|
||||
const toolResult = this.props.toolCallResults[toolCall.toolCallId] ??
|
||||
await vscode.lm.invokeTool(toolCall.name, { parameters: toolCall.parameters, requestedContentTypes: [contentType], toolInvocationToken: toolInvocationToken, tokenOptions }, dummyCancellationToken);
|
||||
const message = <ToolMessage toolCallId={toolCall.toolCallId}>
|
||||
{contentType === 'text/plain' ?
|
||||
@ -168,6 +168,11 @@ class History extends PromptElement<HistoryProps, void> {
|
||||
</>
|
||||
);
|
||||
} else if (message instanceof vscode.ChatResponseTurn) {
|
||||
const toolInfo = message.result.metadata?.toolInfo;
|
||||
if (isTsxToolUserMetadata(toolInfo)) {
|
||||
return <ToolCalls toolCallResults={toolInfo.toolCallResults} toolCallRounds={toolInfo.toolCallRounds} toolInvocationToken={undefined}></ToolCalls>
|
||||
}
|
||||
|
||||
return (
|
||||
<AssistantMessage>
|
||||
{chatResponseToString(message)}
|
||||
|
||||
Reference in New Issue
Block a user