More cleanup

This commit is contained in:
Rob Lourens
2024-10-10 20:27:08 -07:00
parent 34e880ac89
commit e6b0934c19
3 changed files with 134 additions and 136 deletions

View File

@ -2,11 +2,12 @@ import * as vscode from 'vscode';
import { FindFilesTool, RunInTerminalTool, TabCountTool } from './tools';
import { renderPrompt } from '@vscode/prompt-tsx';
import { ToolCallRound, ToolResultMetadata, ToolUserPrompt } from './toolsPrompt';
import { registerTsxChatParticipant } from './tsxParticipant';
export function activate(context: vscode.ExtensionContext) {
registerChatTool(context);
registerChatParticipant(context);
registerChatParticipant2(context);
registerTsxChatParticipant(context);
}
function registerChatTool(context: vscode.ExtensionContext) {
@ -188,137 +189,6 @@ df\`\`\``),
context.subscriptions.push(toolUser);
}
interface TsxToolUserMetadata {
toolCallRounds: ToolCallRound[];
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
}
export function isTsxToolUserMetadata(obj: unknown): obj is TsxToolUserMetadata {
if (typeof obj !== 'object' || obj === null) {
return false;
}
const metadata = obj as TsxToolUserMetadata;
if (!Array.isArray(metadata.toolCallRounds)) {
return false;
}
if (typeof metadata.toolCallResults !== 'object' || metadata.toolCallResults === null) {
return false;
}
// If you ever change the shape of ToolCallRound, remember to make this stricter
return true;
}
function registerChatParticipant2(context: vscode.ExtensionContext) {
const handler: vscode.ChatRequestHandler = async (request: vscode.ChatRequest, chatContext: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken) => {
const models = await vscode.lm.selectChatModels({
vendor: 'copilot',
family: 'gpt-4o'
});
const model = models[0];
stream.markdown(`Available tools: ${vscode.lm.tools.map(tool => tool.id).join(', ')}\n\n`);
const allTools = vscode.lm.tools.map((tool): vscode.LanguageModelChatTool => {
return {
name: tool.id,
description: tool.description,
parametersSchema: tool.parametersSchema ?? {}
};
});
const options: vscode.LanguageModelChatRequestOptions = {
justification: 'Just because!',
};
let { messages, references } = await renderPrompt(
ToolUserPrompt,
{
context: chatContext,
request,
toolCallRounds: [],
toolCallResults: {}
},
{ modelMaxPromptTokens: model.maxInputTokens },
model);
references.forEach(ref => {
if (ref.anchor instanceof vscode.Uri || ref.anchor instanceof vscode.Location) {
stream.reference(ref.anchor);
}
});
const toolReferences = [...request.toolReferences];
const accumulatedToolResults: Record<string, vscode.LanguageModelToolResult> = {};
const toolCallRounds: ToolCallRound[] = [];
const runWithFunctions = async (): Promise<void> => {
const requestedTool = toolReferences.shift();
if (requestedTool) {
options.toolChoice = requestedTool.id;
options.tools = allTools.filter(tool => tool.name === requestedTool.id);
} else {
options.toolChoice = undefined;
options.tools = allTools;
}
const toolCalls: vscode.LanguageModelToolCallPart[] = [];
const response = await model.sendRequest(messages, options, token);
let responseStr = '';
for await (const part of response.stream) {
if (part instanceof vscode.LanguageModelTextPart) {
stream.markdown(part.value);
responseStr += part.value;
} else if (part instanceof vscode.LanguageModelToolCallPart) {
// TODO vscode should be doing this
part.parameters = JSON.parse(part.parameters);
toolCalls.push(part);
}
}
if (toolCalls.length) {
toolCallRounds.push({
response: responseStr,
toolCalls
});
const result = (await renderPrompt(
ToolUserPrompt,
{
context: chatContext,
request,
toolCallRounds,
toolCallResults: accumulatedToolResults
},
{ modelMaxPromptTokens: model.maxInputTokens },
model));
messages = result.messages;
const toolResultMetadata = result.metadatas.getAll(ToolResultMetadata)
if (toolResultMetadata?.length) {
toolResultMetadata.forEach(meta => accumulatedToolResults[meta.toolCallId] = meta.result);
}
return runWithFunctions();
}
};
await runWithFunctions();
return {
metadata: {
toolInfo: {
toolCallResults: accumulatedToolResults,
toolCallRounds
} satisfies TsxToolUserMetadata
}
}
};
const toolUser = vscode.chat.createChatParticipant('chat-tools-sample.tools2', handler);
toolUser.iconPath = new vscode.ThemeIcon('tools');
context.subscriptions.push(toolUser);
}
async function getContextMessage(references: ReadonlyArray<vscode.ChatPromptReference>): Promise<string> {
const contextParts = (await Promise.all(references.map(async ref => {
if (ref.value instanceof vscode.Uri) {

View File

@ -15,7 +15,7 @@ import {
PromptReference,
} from '@vscode/prompt-tsx';
import * as vscode from 'vscode';
import { isTsxToolUserMetadata } from './extension';
import { isTsxToolUserMetadata } from './tsxParticipant';
export interface ToolCallRound {
response: string;
@ -166,9 +166,9 @@ class History extends PromptElement<HistoryProps, void> {
</>
);
} else if (message instanceof vscode.ChatResponseTurn) {
const toolInfo = message.result.metadata?.toolInfo;
if (isTsxToolUserMetadata(toolInfo) && toolInfo.toolCallRounds.length > 0) {
return <ToolCalls toolCallResults={toolInfo.toolCallResults} toolCallRounds={toolInfo.toolCallRounds} toolInvocationToken={undefined} />;
const metadata = message.result.metadata;
if (isTsxToolUserMetadata(metadata) && metadata.toolCallsMetadata.toolCallRounds.length > 0) {
return <ToolCalls toolCallResults={metadata.toolCallsMetadata.toolCallResults} toolCallRounds={metadata.toolCallsMetadata.toolCallRounds} toolInvocationToken={undefined} />;
}
return <AssistantMessage>{chatResponseToString(message)}</AssistantMessage>;

View File

@ -0,0 +1,128 @@
import * as vscode from 'vscode';
import { FindFilesTool, RunInTerminalTool, TabCountTool } from './tools';
import { renderPrompt } from '@vscode/prompt-tsx';
import { ToolCallRound, ToolResultMetadata, ToolUserPrompt } from './toolsPrompt';
export interface TsxToolUserMetadata {
toolCallsMetadata: ToolCallsMetadata;
}
export interface ToolCallsMetadata {
toolCallRounds: ToolCallRound[];
toolCallResults: Record<string, vscode.LanguageModelToolResult>;
}
export function isTsxToolUserMetadata(obj: unknown): obj is TsxToolUserMetadata {
// If you change the metadata format, you would have to make this stricter or handle old objects in old ChatRequest metadata
return !!obj &&
!!(obj as TsxToolUserMetadata).toolCallsMetadata &&
Array.isArray((obj as TsxToolUserMetadata).toolCallsMetadata.toolCallRounds);
}
export function registerTsxChatParticipant(context: vscode.ExtensionContext) {
const handler: vscode.ChatRequestHandler = async (request: vscode.ChatRequest, chatContext: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken) => {
const models = await vscode.lm.selectChatModels({
vendor: 'copilot',
family: 'gpt-4o'
});
const model = models[0];
stream.markdown(`Available tools: ${vscode.lm.tools.map(tool => tool.id).join(', ')}\n\n`);
const allTools = vscode.lm.tools.map((tool): vscode.LanguageModelChatTool => {
return {
name: tool.id,
description: tool.description,
parametersSchema: tool.parametersSchema ?? {}
};
});
const options: vscode.LanguageModelChatRequestOptions = {
justification: 'Just because!',
};
let { messages, references } = await renderPrompt(
ToolUserPrompt,
{
context: chatContext,
request,
toolCallRounds: [],
toolCallResults: {}
},
{ modelMaxPromptTokens: model.maxInputTokens },
model);
references.forEach(ref => {
if (ref.anchor instanceof vscode.Uri || ref.anchor instanceof vscode.Location) {
stream.reference(ref.anchor);
}
});
const toolReferences = [...request.toolReferences];
const accumulatedToolResults: Record<string, vscode.LanguageModelToolResult> = {};
const toolCallRounds: ToolCallRound[] = [];
const runWithFunctions = async (): Promise<void> => {
const requestedTool = toolReferences.shift();
if (requestedTool) {
options.toolChoice = requestedTool.id;
options.tools = allTools.filter(tool => tool.name === requestedTool.id);
} else {
options.toolChoice = undefined;
options.tools = allTools;
}
const toolCalls: vscode.LanguageModelToolCallPart[] = [];
const response = await model.sendRequest(messages, options, token);
let responseStr = '';
for await (const part of response.stream) {
if (part instanceof vscode.LanguageModelTextPart) {
stream.markdown(part.value);
responseStr += part.value;
} else if (part instanceof vscode.LanguageModelToolCallPart) {
// TODO vscode should be doing this
part.parameters = JSON.parse(part.parameters);
toolCalls.push(part);
}
}
if (toolCalls.length) {
toolCallRounds.push({
response: responseStr,
toolCalls
});
const result = (await renderPrompt(
ToolUserPrompt,
{
context: chatContext,
request,
toolCallRounds,
toolCallResults: accumulatedToolResults
},
{ modelMaxPromptTokens: model.maxInputTokens },
model));
messages = result.messages;
const toolResultMetadata = result.metadatas.getAll(ToolResultMetadata)
if (toolResultMetadata?.length) {
toolResultMetadata.forEach(meta => accumulatedToolResults[meta.toolCallId] = meta.result);
}
return runWithFunctions();
}
};
await runWithFunctions();
return {
metadata: {
toolCallsMetadata: {
toolCallResults: accumulatedToolResults,
toolCallRounds
}
} satisfies TsxToolUserMetadata,
}
};
const toolUser = vscode.chat.createChatParticipant('chat-tools-sample.tools2', handler);
toolUser.iconPath = new vscode.ThemeIcon('tools');
context.subscriptions.push(toolUser);
}