diff --git a/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs index f1bea485c..3587d1b0d 100644 --- a/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI.V1/Middleware/OpenAIChatRequestMessageConnector.cs @@ -335,7 +335,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); var textContent = message.GetContent() ?? string.Empty; - var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From }; + + // don't include the name field when it's tool call message. + // fix https://github.com/microsoft/autogen/issues/3437 + var chatRequestMessage = new ChatRequestAssistantMessage(textContent); foreach (var tc in toolCall) { chatRequestMessage.ToolCalls.Add(tc); diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 2297d123b..fd55a1350 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -322,7 +322,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); var textContent = message.GetContent() ?? null; - var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From }; + + // Don't set participant name for assistant when it is tool call + // fix https://github.com/microsoft/autogen/issues/3437 + var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent); return [chatRequestMessage]; } diff --git a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index 3574e593d..55bd6502b 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -139,7 +139,7 @@ { "Role": "assistant", "Content": [], - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", @@ -184,7 +184,7 @@ { "Role": "assistant", "Content": [], - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", @@ -210,7 +210,7 @@ { "Role": "assistant", "Content": [], - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs index 04f4d3d4d..992bf9b60 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs @@ -27,6 +27,12 @@ public partial class OpenAIChatAgentTest return $"The weather in {location} is sunny."; } + [Function] + public async Task CalculateTaxAsync(string location, double income) + { + return $"[CalculateTax] The tax in {location} for income {income} is 1000."; + } + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task BasicConversationTestAsync() { @@ -246,6 +252,65 @@ public partial class OpenAIChatAgentTest respond.GetContent()?.Should().NotBeNullOrEmpty(); } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItProduceValidContentAfterFunctionCall() + { + // https://github.com/microsoft/autogen/issues/3437 + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionOptions() + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var agentName = "assistant"; + + var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}"); + var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny."); + var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName); + var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName); + var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName); + + var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}"); + var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000."); + var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName); + var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName); + var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName); + + var chatHistory = new List() + { + new TextMessage(Role.User, "What's the weather in Seattle", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in Paris", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in New York", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in London", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"), + }; + + var agent = new OpenAIChatAgent( + chatClient: openaiClient.GetChatClient(deployName), + name: "assistant", + options: options) + .RegisterMessageConnector(); + + var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions + { + MaxToken = 1024, + Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract], + }); + } + private OpenAIClient CreateOpenAIClientFromAzureOpenAI() { var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs index a05f440a1..3a2048c2f 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs @@ -276,7 +276,10 @@ public class OpenAIMessageTests var innerMessage = msgs.Last(); innerMessage!.Should().BeOfType>(); var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope)innerMessage!).Content; - chatRequestMessage.ParticipantName.Should().Be("assistant"); + // when the message is a tool call message + // the name field should not be set + // please visit OpenAIChatRequestMessageConnector class for more information + chatRequestMessage.ParticipantName.Should().BeNullOrEmpty(); chatRequestMessage.ToolCalls.Count().Should().Be(1); chatRequestMessage.Content.First().Text.Should().Be("textContent"); chatRequestMessage.ToolCalls.First().Should().BeOfType(); @@ -307,7 +310,10 @@ public class OpenAIMessageTests innerMessage!.Should().BeOfType>(); var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope)innerMessage!).Content; chatRequestMessage.Content.Should().BeNullOrEmpty(); - chatRequestMessage.ParticipantName.Should().Be("assistant"); + // when the message is a tool call message + // the name field should not be set + // please visit OpenAIChatRequestMessageConnector class for more information + chatRequestMessage.ParticipantName.Should().BeNullOrEmpty(); chatRequestMessage.ToolCalls.Count().Should().Be(2); for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) { diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index e8e9af84d..877bc57bf 100644 --- a/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -81,7 +81,7 @@ { "Role": "assistant", "Content": "", - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", @@ -126,7 +126,7 @@ { "Role": "assistant", "Content": "", - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", @@ -152,7 +152,7 @@ { "Role": "assistant", "Content": "", - "Name": "assistant", + "Name": null, "TooCall": [ { "Type": "Function", diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs index a1f9541f4..d6055fb78 100644 --- a/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs +++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using AutoGen.OpenAI.V1.Extension; @@ -45,7 +46,11 @@ namespace AutoGen.OpenAI.V1.Tests _output.WriteLine($"agent name: {agent.Name}"); foreach (var message in messages) { - _output.WriteLine(message.FormatMessage()); + if (message is IMessage envelope) + { + var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true }); + _output.WriteLine(json); + } } throw; @@ -149,9 +154,9 @@ You create math question and ask student to answer it. Then you check if the answer is correct. If the answer is wrong, you ask student to fix it", modelName: model) - .RegisterMessageConnector() - .RegisterStreamingMiddleware(functionCallMiddleware) - .RegisterMiddleware(Print); + .RegisterMiddleware(Print) + .RegisterMiddleware(new OpenAIChatRequestMessageConnector()) + .RegisterMiddleware(functionCallMiddleware); return teacher; } diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs index 0957cc9f4..1000339c6 100644 --- a/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs @@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest [Function] public async Task GetWeatherAsync(string location) { - return $"The weather in {location} is sunny."; + return $"[GetWeather] The weather in {location} is sunny."; + } + + [Function] + public async Task CalculateTaxAsync(string location, double income) + { + return $"[CalculateTax] The tax in {location} for income {income} is 1000."; } [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] @@ -270,6 +276,64 @@ public partial class OpenAIChatAgentTest action.Should().ThrowExactly().WithMessage("Messages should not be provided in options"); } + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItProduceValidContentAfterFunctionCall() + { + // https://github.com/microsoft/autogen/issues/3437 + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionsOptions(deployName, []) + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var agentName = "assistant"; + + var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}"); + var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny."); + var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName); + var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName); + var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName); + + var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}"); + var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000."); + var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName); + var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName); + var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName); + + var chatHistory = new List() + { + new TextMessage(Role.User, "What's the weather in Seattle", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in Paris", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in New York", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"), + calculateTaxAggregateMessage, + new TextMessage(Role.User, "what's the weather in London", from: "user"), + getWeatherAggregateMessage, + new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"), + }; + + var agent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions + { + MaxToken = 1024, + Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract], + }); + } + private OpenAIClient CreateOpenAIClientFromAzureOpenAI() { var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); diff --git a/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs index 3050c4e8e..876416fdc 100644 --- a/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs @@ -278,7 +278,10 @@ public class OpenAIMessageTests var innerMessage = msgs.Last(); innerMessage!.Should().BeOfType>(); var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; - chatRequestMessage.Name.Should().Be("assistant"); + // when the message is a tool call message + // the name field should not be set + // please visit OpenAIChatRequestMessageConnector class for more information + chatRequestMessage.Name.Should().BeNullOrEmpty(); chatRequestMessage.ToolCalls.Count().Should().Be(1); chatRequestMessage.Content.Should().Be("textContent"); chatRequestMessage.ToolCalls.First().Should().BeOfType(); @@ -309,7 +312,11 @@ public class OpenAIMessageTests innerMessage!.Should().BeOfType>(); var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; chatRequestMessage.Content.Should().BeNullOrEmpty(); - chatRequestMessage.Name.Should().Be("assistant"); + + // when the message is a tool call message + // the name field should not be set + // please visit OpenAIChatRequestMessageConnector class for more information + chatRequestMessage.Name.Should().BeNullOrEmpty(); chatRequestMessage.ToolCalls.Count().Should().Be(2); for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) {