stop setting name field when assistant message contains tool call (#3481)

This commit is contained in:
Xiaoyun Zhang 2024-09-05 13:54:30 -07:00 committed by GitHub
parent 40cfe07a95
commit a44b86f26e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 170 additions and 17 deletions

View File

@ -335,7 +335,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
// don't include the name field when it's tool call message.
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);

View File

@ -322,7 +322,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? null;
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From };
// Don't set participant name for assistant when it is tool call
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);
return [chatRequestMessage];
}

View File

@ -139,7 +139,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
@ -184,7 +184,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
@ -210,7 +210,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",

View File

@ -27,6 +27,12 @@ public partial class OpenAIChatAgentTest
return $"The weather in {location} is sunny.";
}
[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync()
{
@ -246,6 +252,65 @@ public partial class OpenAIChatAgentTest
respond.GetContent()?.Should().NotBeNullOrEmpty();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionOptions()
{
Temperature = 0.7f,
MaxTokens = 1,
};
var agentName = "assistant";
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};
var agent = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "assistant",
options: options)
.RegisterMessageConnector();
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}
private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");

View File

@ -276,7 +276,10 @@ public class OpenAIMessageTests
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.First().Text.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
@ -307,7 +310,10 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{

View File

@ -81,7 +81,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
@ -126,7 +126,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
@ -152,7 +152,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.V1.Extension;
@ -45,7 +46,11 @@ namespace AutoGen.OpenAI.V1.Tests
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
}
throw;
@ -149,9 +154,9 @@ You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
return teacher;
}

View File

@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest
[Function]
public async Task<string> GetWeatherAsync(string location)
{
return $"The weather in {location} is sunny.";
return $"[GetWeather] The weather in {location} is sunny.";
}
[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
@ -270,6 +276,64 @@ public partial class OpenAIChatAgentTest
action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [])
{
Temperature = 0.7f,
MaxTokens = 1,
};
var agentName = "assistant";
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};
var agent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}
private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");

View File

@ -278,7 +278,10 @@ public class OpenAIMessageTests
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Name.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
@ -309,7 +312,11 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{