mirror of https://github.com/microsoft/autogen.git
stop setting name field when assistant message contains tool call (#3481)
This commit is contained in:
parent
40cfe07a95
commit
a44b86f26e
|
@ -335,7 +335,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
|||
|
||||
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
|
||||
var textContent = message.GetContent() ?? string.Empty;
|
||||
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
|
||||
|
||||
// don't include the name field when it's tool call message.
|
||||
// fix https://github.com/microsoft/autogen/issues/3437
|
||||
var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
|
||||
foreach (var tc in toolCall)
|
||||
{
|
||||
chatRequestMessage.ToolCalls.Add(tc);
|
||||
|
|
|
@ -322,7 +322,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
|||
|
||||
var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
|
||||
var textContent = message.GetContent() ?? null;
|
||||
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From };
|
||||
|
||||
// Don't set participant name for assistant when it is tool call
|
||||
// fix https://github.com/microsoft/autogen/issues/3437
|
||||
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);
|
||||
|
||||
return [chatRequestMessage];
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": [],
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
@ -184,7 +184,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": [],
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
@ -210,7 +210,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": [],
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
|
|
@ -27,6 +27,12 @@ public partial class OpenAIChatAgentTest
|
|||
return $"The weather in {location} is sunny.";
|
||||
}
|
||||
|
||||
[Function]
|
||||
public async Task<string> CalculateTaxAsync(string location, double income)
|
||||
{
|
||||
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task BasicConversationTestAsync()
|
||||
{
|
||||
|
@ -246,6 +252,65 @@ public partial class OpenAIChatAgentTest
|
|||
respond.GetContent()?.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task ItProduceValidContentAfterFunctionCall()
|
||||
{
|
||||
// https://github.com/microsoft/autogen/issues/3437
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
|
||||
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
|
||||
var options = new ChatCompletionOptions()
|
||||
{
|
||||
Temperature = 0.7f,
|
||||
MaxTokens = 1,
|
||||
};
|
||||
|
||||
var agentName = "assistant";
|
||||
|
||||
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
|
||||
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
|
||||
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
|
||||
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
|
||||
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
|
||||
|
||||
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
|
||||
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
|
||||
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
|
||||
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
|
||||
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
|
||||
|
||||
var chatHistory = new List<IMessage>()
|
||||
{
|
||||
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in London", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
|
||||
};
|
||||
|
||||
var agent = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "assistant",
|
||||
options: options)
|
||||
.RegisterMessageConnector();
|
||||
|
||||
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
|
||||
{
|
||||
MaxToken = 1024,
|
||||
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
|
||||
});
|
||||
}
|
||||
|
||||
private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
|
||||
{
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
|
||||
|
|
|
@ -276,7 +276,10 @@ public class OpenAIMessageTests
|
|||
var innerMessage = msgs.Last();
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
|
||||
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.ParticipantName.Should().Be("assistant");
|
||||
// when the message is a tool call message
|
||||
// the name field should not be set
|
||||
// please visit OpenAIChatRequestMessageConnector class for more information
|
||||
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.ToolCalls.Count().Should().Be(1);
|
||||
chatRequestMessage.Content.First().Text.Should().Be("textContent");
|
||||
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
|
||||
|
@ -307,7 +310,10 @@ public class OpenAIMessageTests
|
|||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
|
||||
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.ParticipantName.Should().Be("assistant");
|
||||
// when the message is a tool call message
|
||||
// the name field should not be set
|
||||
// please visit OpenAIChatRequestMessageConnector class for more information
|
||||
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.ToolCalls.Count().Should().Be(2);
|
||||
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
|
||||
{
|
||||
|
|
|
@ -81,7 +81,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": "",
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
@ -126,7 +126,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": "",
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
@ -152,7 +152,7 @@
|
|||
{
|
||||
"Role": "assistant",
|
||||
"Content": "",
|
||||
"Name": "assistant",
|
||||
"Name": null,
|
||||
"TooCall": [
|
||||
{
|
||||
"Type": "Function",
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using AutoGen.OpenAI.V1.Extension;
|
||||
|
@ -45,7 +46,11 @@ namespace AutoGen.OpenAI.V1.Tests
|
|||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
_output.WriteLine(message.FormatMessage());
|
||||
if (message is IMessage<object> envelope)
|
||||
{
|
||||
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
|
||||
_output.WriteLine(json);
|
||||
}
|
||||
}
|
||||
|
||||
throw;
|
||||
|
@ -149,9 +154,9 @@ You create math question and ask student to answer it.
|
|||
Then you check if the answer is correct.
|
||||
If the answer is wrong, you ask student to fix it",
|
||||
modelName: model)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
.RegisterMiddleware(Print)
|
||||
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
|
||||
return teacher;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest
|
|||
[Function]
|
||||
public async Task<string> GetWeatherAsync(string location)
|
||||
{
|
||||
return $"The weather in {location} is sunny.";
|
||||
return $"[GetWeather] The weather in {location} is sunny.";
|
||||
}
|
||||
|
||||
[Function]
|
||||
public async Task<string> CalculateTaxAsync(string location, double income)
|
||||
{
|
||||
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
|
@ -270,6 +276,64 @@ public partial class OpenAIChatAgentTest
|
|||
action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task ItProduceValidContentAfterFunctionCall()
|
||||
{
|
||||
// https://github.com/microsoft/autogen/issues/3437
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
|
||||
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
|
||||
var options = new ChatCompletionsOptions(deployName, [])
|
||||
{
|
||||
Temperature = 0.7f,
|
||||
MaxTokens = 1,
|
||||
};
|
||||
|
||||
var agentName = "assistant";
|
||||
|
||||
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
|
||||
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
|
||||
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
|
||||
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
|
||||
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
|
||||
|
||||
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
|
||||
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
|
||||
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
|
||||
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
|
||||
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
|
||||
|
||||
var chatHistory = new List<IMessage>()
|
||||
{
|
||||
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
|
||||
calculateTaxAggregateMessage,
|
||||
new TextMessage(Role.User, "what's the weather in London", from: "user"),
|
||||
getWeatherAggregateMessage,
|
||||
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
|
||||
};
|
||||
|
||||
var agent = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
name: "assistant",
|
||||
options: options)
|
||||
.RegisterMessageConnector();
|
||||
|
||||
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
|
||||
{
|
||||
MaxToken = 1024,
|
||||
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
|
||||
});
|
||||
}
|
||||
|
||||
private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
|
||||
{
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
|
||||
|
|
|
@ -278,7 +278,10 @@ public class OpenAIMessageTests
|
|||
var innerMessage = msgs.Last();
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Name.Should().Be("assistant");
|
||||
// when the message is a tool call message
|
||||
// the name field should not be set
|
||||
// please visit OpenAIChatRequestMessageConnector class for more information
|
||||
chatRequestMessage.Name.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.ToolCalls.Count().Should().Be(1);
|
||||
chatRequestMessage.Content.Should().Be("textContent");
|
||||
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
|
||||
|
@ -309,7 +312,11 @@ public class OpenAIMessageTests
|
|||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.Name.Should().Be("assistant");
|
||||
|
||||
// when the message is a tool call message
|
||||
// the name field should not be set
|
||||
// please visit OpenAIChatRequestMessageConnector class for more information
|
||||
chatRequestMessage.Name.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.ToolCalls.Count().Should().Be(2);
|
||||
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue