Merge branch 'main' into u/refactor

This commit is contained in:
XiaoYun Zhang 2024-10-30 11:15:43 -07:00
commit 9d052a1661
31 changed files with 1197 additions and 946 deletions

View File

@ -217,13 +217,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
csharp_prefer_simple_using_statement = true:suggestion
csharp_style_namespace_declarations = block_scoped:silent
csharp_style_namespace_declarations = file_scoped:warning
csharp_style_prefer_method_group_conversion = true:silent
csharp_style_prefer_top_level_statements = true:silent
csharp_style_prefer_primary_constructors = true:suggestion
csharp_style_expression_bodied_lambdas = true:silent
csharp_style_prefer_local_over_anonymous_function = true:suggestion
dotnet_diagnostic.CA2016.severity = suggestion
csharp_prefer_static_anonymous_function = true:suggestion
# disable check for generated code
[*.generated.cs]
@ -693,6 +694,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
dotnet_style_prefer_simplified_interpolation = true:suggestion
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
dotnet_style_namespace_match_folder = true:suggestion
dotnet_style_qualification_for_method = false:silent
[**/*.g.cs]
generated_code = true

View File

@ -127,7 +127,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents.Te
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution

View File

@ -27,10 +27,11 @@ namespace Hello
[TopicSubscription("HelloAgents")]
public class HelloAgent(
IAgentContext context,
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
context,
typeRegistry),
ISayHello,
IHandleConsole,
IHandle<NewMessageReceived>,
IHandle<ConversationClosed>
{

View File

@ -5,45 +5,44 @@ using System;
using System.Collections.Generic;
using System.Linq;
namespace AutoGen
namespace AutoGen;
public static class LLMConfigAPI
{
public static class LLMConfigAPI
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
{
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
var models = modelIDs ?? new[]
{
var models = modelIDs ?? new[]
{
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
}

View File

@ -2,6 +2,7 @@
// AgentBase.cs
using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
@ -20,6 +21,8 @@ public abstract class AgentBase : IAgentBase, IHandle
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentContext _context;
public string Route { get; set; } = "base";
protected internal ILogger Logger => _context.Logger;
public IAgentContext Context => _context;
protected readonly EventTypes EventTypes;
@ -215,14 +218,39 @@ public abstract class AgentBase : IAgentBase, IHandle
public Task CallHandler(CloudEvent item)
{
// Only send the event to the handler if the agent type is handling that type
if (EventTypes.EventsMap[GetType()].Contains(item.Type))
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
foreach (var key in EventTypes.EventsMap.Keys)
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
if (EventTypes.EventsMap[key].Contains(item.Type))
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
MethodInfo methodInfo;
try
{
// check that our target actually implements this interface, otherwise call the default static
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
{
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
}
else
{
// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
}
}
catch (Exception ex)
{
Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
throw; // TODO: ?
}
}
}
return Task.CompletedTask;
}

View File

@ -0,0 +1,46 @@
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;
public interface IHandleConsole : IHandle<Output>, IHandle<Input>
{
string Route { get; }
AgentId AgentId { get; }
ValueTask PublishEvent(CloudEvent item);
async Task IHandle<Output>.Handle(Output item)
{
// Assuming item has a property `Message` that we want to write to the console
Console.WriteLine(item.Message);
await ProcessOutput(item.Message);
var evt = new OutputWritten
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
async Task IHandle<Input>.Handle(Input item)
{
Console.WriteLine("Please enter input:");
string content = Console.ReadLine() ?? string.Empty;
await ProcessInput(content);
var evt = new InputProcessed
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
static Task ProcessOutput(string message)
{
// Implement your output processing logic here
return Task.CompletedTask;
}
static Task<string> ProcessInput(string message)
{
// Implement your input processing logic here
return Task.FromResult(message);
}
}

View File

@ -86,6 +86,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
message.Response.RequestId = request.OriginalRequestId;
request.Agent.ReceiveMessage(message);
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
if (!message.RegisterAgentTypeResponse.Success)
{

View File

@ -74,7 +74,52 @@ public static class HostBuilderExtensions
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => (t, t.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// get interfaces implemented by the agent and get the methods of the interface if they are named Handle
var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => t.GetInterfaces()
.Select(i => (t, i, i.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
//to dictionary of type t and paramter type of the method
.ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
// for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
foreach (var item in ifaceHandlersMap)
{
foreach (var iface in item)
{
if (eventsMap.TryGetValue(iface.Item2, out var events))
{
events.UnionWith(iface.Item3);
}
else
{
eventsMap[iface.Item2] = iface.Item3;
}
}
}
// merge the handlersMap into the eventsMap
foreach (var item in handlersMap)
{
if (eventsMap.TryGetValue(item.Key, out var events))
{
events.UnionWith(item.Value);
}
else
{
eventsMap[item.Key] = item.Value;
}
}
return new EventTypes(typeRegistry, types, eventsMap);
});
return new AgentApplicationBuilder(builder);

View File

@ -3,34 +3,33 @@
using Microsoft.Extensions.AI;
namespace Microsoft.Extensions.Hosting
{
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
namespace Microsoft.Extensions.Hosting;
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
}
}

View File

@ -14,206 +14,205 @@ using FluentAssertions;
using OpenAI;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.Tests
namespace AutoGen.OpenAI.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
}
throw;
_output.WriteLine(message.FormatMessage());
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
using FluentAssertions;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.V1.Tests
namespace AutoGen.OpenAI.V1.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
if (message is IMessage<object> envelope)
{
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
throw;
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -4,85 +4,84 @@
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
using Xunit; // Needed for Fact and Assert
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionCallTemplateEncodingTests
{
public class FunctionCallTemplateEncodingTests
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
{
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
var template = new FunctionCallTemplate
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
}

View File

@ -10,122 +10,121 @@ using FluentAssertions;
using OpenAI.Chat;
using Xunit;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionExample
{
public class FunctionExample
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
{
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
WriteIndented = true,
};
[Fact]
public void Add_Test()
{
var args = new
{
WriteIndented = true,
a = 1,
b = 2,
};
[Fact]
public void Add_Test()
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
{
var args = new
{
var args = new
args = new double[] { 1, 2, 3 },
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
[Fact]
public async Task DictionaryToString_Test()
{
var args = new
{
xargs = new Dictionary<string, string>
{
a = 1,
b = 2,
};
{ "a", "1" },
{ "b", "2" },
},
};
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
var args = new
{
args = new double[] { 1, 2, 3 },
};
a = 1,
b = 2,
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task DictionaryToString_Test()
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
var args = new
{
xargs = new Dictionary<string, string>
{
{ "a", "1" },
{ "b", "2" },
},
};
a = 1,
b = 2,
};
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
[Fact]
public void Query_Test()
{
var args = new
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
a = 1,
b = 2,
};
query = "hello",
k = 3,
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
a = 1,
b = 2,
};
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
[Fact]
public void Query_Test()
{
var args = new
{
query = "hello",
k = 3,
};
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
}

View File

@ -4,67 +4,66 @@
using System.Text.Json;
using AutoGen.Core;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public partial class FunctionExamples
{
public partial class FunctionExamples
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
return a + b;
}
return a + b;
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
WriteIndented = true,
});
WriteIndented = true,
});
return Task.FromResult(res);
}
return Task.FromResult(res);
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
}

View File

@ -7,73 +7,72 @@ using System.Threading.Tasks;
using AutoGen.BasicSample;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public class BasicSampleTest
{
public class BasicSampleTest
private readonly ITestOutputHelper _output;
public BasicSampleTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
Console.SetOut(new ConsoleWriter(_output));
}
public BasicSampleTest(ITestOutputHelper output)
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
_output = output;
Console.SetOut(new ConsoleWriter(_output));
this.output = output;
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
public override void WriteLine(string? m)
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
this.output = output;
}
public override void WriteLine(string? m)
{
output.WriteLine(m);
}
output.WriteLine(m);
}
}
}

View File

@ -3,18 +3,17 @@
using Xunit;
namespace AutoGen.Tests
{
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
namespace AutoGen.Tests;
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
}

View File

@ -9,219 +9,218 @@ using FluentAssertions;
using Xunit;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public partial class SingleAgentTest
{
public partial class SingleAgentTest
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
_output = output;
}
_output = output;
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
var llmConfig = new ConversableAgentConfig
{
Temperature = 0,
FunctionContracts = new[]
{
Temperature = 0,
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
config,
},
};
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
{
return $"[ECHO] {message}";
}
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
{
return $"[ECHO] {message}";
}
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
finalReply = reply;
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
else
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
reply.From.Should().Be(agent.Name);
finalReply = reply;
}
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
else
{
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
}
}

View File

@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from .. import EVENT_LOGGER_NAME
from ..base import Response
from ..messages import (
ChatMessage,
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
)
from ._base_chat_agent import BaseChatAgent
@ -207,7 +211,14 @@ class AssistantAgent(BaseChatAgent):
)
self._model_context: List[LLMMessage] = []
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ResetMessage):
@ -215,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# Inner messages.
inner_messages: List[InnerMessage] = []
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
@ -227,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
# Detect handoff requests.
handoffs: List[Handoff] = []
@ -242,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Respond with a handoff message.
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
# Return the output messages to signal the handoff.
return Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
# Generate an inference result based on the current model context.
result = await self._model_client.create(
@ -255,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return StopMessage(content=result.content, source=self.name)
return Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
return TextMessage(content=result.content, source=self.name)
return Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken

View File

@ -1,11 +1,10 @@
from abc import ABC, abstractmethod
from typing import Sequence
from typing import List, Sequence
from autogen_core.base import CancellationToken
from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
from ..messages import ChatMessage, InnerMessage, TextMessage
class BaseChatAgent(ChatAgent, ABC):
@ -30,9 +29,15 @@ class BaseChatAgent(ChatAgent, ABC):
describe the agent's capabilities and how to interact with it."""
return self._description
@property
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the agent produces."""
...
@abstractmethod
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response."""
...
async def run(
@ -43,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
group_chat = RoundRobinGroupChat(participants=[self])
result = await group_chat.run(
task=task,
cancellation_token=cancellation_token,
termination_condition=termination_condition,
)
return result
if cancellation_token is None:
cancellation_token = CancellationToken()
first_message = TextMessage(content=task, source="user")
response = await self.on_messages([first_message], cancellation_token)
messages: List[InnerMessage | ChatMessage] = [first_message]
if response.inner_messages is not None:
messages += response.inner_messages
messages.append(response.chat_message)
return TaskResult(messages=messages)

View File

@ -3,6 +3,7 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
from ..base import Response
from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
@ -20,7 +21,12 @@ class CodeExecutorAgent(BaseChatAgent):
super().__init__(name=name, description=description)
self._code_executor = code_executor
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the code executor agent produces."""
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Extract code blocks from the messages.
code_blocks: List[CodeBlock] = []
for msg in messages:
@ -29,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
if code_blocks:
# Execute the code blocks.
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
return TextMessage(content=result.output, source=self.name)
return Response(chat_message=TextMessage(content=result.output, source=self.name))
else:
return TextMessage(content="No code blocks found in the thread.", source=self.name)
return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))

View File

@ -1,10 +1,11 @@
from ._chat_agent import ChatAgent
from ._chat_agent import ChatAgent, Response
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
__all__ = [
"ChatAgent",
"Response",
"Team",
"TerminatedException",
"TerminationCondition",

View File

@ -1,12 +1,24 @@
from typing import Protocol, Sequence, runtime_checkable
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken
from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
@dataclass(kw_only=True)
class Response:
"""A response from calling :meth:`ChatAgent.on_messages`."""
chat_message: ChatMessage
"""A chat message produced by the agent as the response."""
inner_messages: List[InnerMessage] | None = None
"""Inner messages produced by the agent."""
@runtime_checkable
class ChatAgent(TaskRunner, Protocol):
"""Protocol for a chat agent."""
@ -24,8 +36,13 @@ class ChatAgent(TaskRunner, Protocol):
describe the agent's capabilities and how to interact with it."""
...
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the agent produces."""
...
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handles incoming messages and returns a response."""
...
async def run(

View File

@ -3,7 +3,7 @@ from typing import Protocol, Sequence
from autogen_core.base import CancellationToken
from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition
@ -11,7 +11,7 @@ from ._termination import TerminationCondition
class TaskResult:
"""Result of running a task."""
messages: Sequence[ChatMessage]
messages: Sequence[InnerMessage | ChatMessage]
"""Messages produced by the task."""

View File

@ -1,6 +1,7 @@
from typing import List
from autogen_core.components import Image
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from pydantic import BaseModel
@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
"""The content for the reset message."""
class ToolCallMessage(BaseMessage):
"""A message signaling the use of tools."""
content: List[FunctionCall]
"""The tool calls."""
class ToolCallResultMessages(BaseMessage):
"""A message signaling the results of tool calls."""
content: List[FunctionExecutionResult]
"""The tool call results."""
InnerMessage = ToolCallMessage | ToolCallResultMessages
"""Messages for intra-agent monologues."""
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
"""A message used by agents in a team."""
"""Messages for agent-to-agent communication."""
__all__ = [
@ -60,5 +79,7 @@ __all__ = [
"StopMessage",
"HandoffMessage",
"ResetMessage",
"ToolCallMessage",
"ToolCallResultMessages",
"ChatMessage",
]

View File

@ -15,7 +15,7 @@ from autogen_core.base import (
from autogen_core.components import ClosureAgent, TypeSubscription
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import ChatMessage, TextMessage
from ...messages import ChatMessage, InnerMessage, TextMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer
@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
def _create_participant_factory(
self,
parent_topic_type: str,
output_topic_type: str,
agent: ChatAgent,
) -> Callable[[], ChatAgentContainer]:
def _factory() -> ChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId(type=agent.name, key=self._team_id)
container = ChatAgentContainer(parent_topic_type, agent)
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
assert container.id == id
return container
@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
group_chat_manager_topic_type = group_chat_manager_agent_type.type
group_topic_type = "round_robin_group_topic"
team_topic_type = "team_topic"
output_topic_type = "output_topic"
# Register participants.
participant_topic_types: List[str] = []
@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
await ChatAgentContainer.register(
runtime,
type=agent_type,
factory=self._create_participant_factory(group_topic_type, participant),
factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
)
# Add subscriptions for the participant.
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
)
group_chat_messages: List[ChatMessage] = []
output_messages: List[InnerMessage | ChatMessage] = []
async def collect_group_chat_messages(
async def collect_output_messages(
_runtime: AgentRuntime,
id: AgentId,
message: GroupChatPublishEvent,
message: InnerMessage | ChatMessage,
ctx: MessageContext,
) -> None:
group_chat_messages.append(message.agent_message)
output_messages.append(message)
await ClosureAgent.register(
runtime,
type="collect_group_chat_messages",
closure=collect_group_chat_messages,
type="collect_output_messages",
closure=collect_output_messages,
subscriptions=lambda: [
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
],
)
@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
# Run the team by publishing the task to the team topic and then requesting the result.
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
first_chat_message = TextMessage(content=task, source="user")
output_messages.append(first_chat_message)
await runtime.publish_message(
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
GroupChatPublishEvent(agent_message=first_chat_message),
topic_id=team_topic_id,
)
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
await runtime.stop_when_idle()
# Return the result.
return TaskResult(messages=group_chat_messages)
return TaskResult(messages=output_messages)

View File

@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
output_topic_type (str): The topic type for the output.
agent (ChatAgent): The agent to delegate message handling to.
"""
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._output_topic_type = output_topic_type
self._agent = agent
self._message_buffer: List[ChatMessage] = []
@ -36,13 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
raise ValueError(
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
f"Expected one of: {self._agent.produced_message_types}. "
f"Check the agent's produced_message_types property."
)
# Publish inner messages to the output topic.
if response.inner_messages is not None:
for inner_message in response.inner_messages:
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Publish the response.
self._message_buffer.clear()
await self.publish_message(
GroupChatPublishEvent(agent_message=response, source=self.id),
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
# Publish the response to the output topic.
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@ -86,6 +86,10 @@ class Swarm(BaseGroupChat):
super().__init__(
participants, termination_condition=termination_condition, group_chat_manager_class=SwarmGroupChatManager
)
# The first participant must be able to produce handoff messages.
first_participant = self._participants[0]
if HandoffMessage not in first_participant.produced_message_types:
raise ValueError("The first participant must be able to produce a handoff messages.")
def _create_group_chat_manager_factory(
self,

View File

@ -7,7 +7,7 @@ import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
result = await tool_use_agent.run("task")
assert len(result.messages) == 3
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], TextMessage)
assert isinstance(result.messages[2], StopMessage)
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert isinstance(result.messages[3], TextMessage)
@pytest.mark.asyncio
@ -158,8 +159,9 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
handoffs=[handoff],
)
assert HandoffMessage in tool_use_agent.produced_message_types
response = await tool_use_agent.on_messages(
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
)
assert isinstance(response, HandoffMessage)
assert response.target == "agent2"
assert isinstance(response.chat_message, HandoffMessage)
assert response.chat_message.target == "agent2"

View File

@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
CodeExecutorAgent,
Handoff,
)
from autogen_agentchat.base import Response
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import (
@ -62,14 +65,18 @@ class _EchoAgent(BaseChatAgent):
super().__init__(name, description)
self._last_message: str | None = None
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
if len(messages) > 0:
assert isinstance(messages[0], TextMessage)
self._last_message = messages[0].content
return TextMessage(content=messages[0].content, source=self.name)
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
else:
assert self._last_message is not None
return TextMessage(content=self._last_message, source=self.name)
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
class _StopAgent(_EchoAgent):
@ -78,11 +85,15 @@ class _StopAgent(_EchoAgent):
self._count = 0
self._stop_at = stop_at
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
self._count += 1
if self._count < self._stop_at:
return await super().on_messages(messages, cancellation_token)
return StopMessage(content="TERMINATE", source=self.name)
return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
def _pass_function(input: str) -> str:
@ -222,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
)
assert len(result.messages) == 4
assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], TextMessage) # tool use agent response
assert isinstance(result.messages[2], TextMessage) # echo agent response
assert isinstance(result.messages[3], StopMessage) # tool use agent response
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
@ -415,8 +428,16 @@ class _HandOffAgent(BaseChatAgent):
super().__init__(name, description)
self._next_agent = next_agent
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [HandoffMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(
chat_message=HandoffMessage(
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
)
)
@pytest.mark.asyncio
@ -501,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agnet1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
assert len(result.messages) == 5
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert result.messages[1].content == "handoff to agent2"
assert result.messages[2].content == "Transferred to agent1."
assert result.messages[3].content == "Hello"
assert result.messages[4].content == "TERMINATE"
assert isinstance(result.messages[1], ToolCallMessage)
assert isinstance(result.messages[2], ToolCallResultMessages)
assert result.messages[3].content == "handoff to agent2"
assert result.messages[4].content == "Transferred to agent1."
assert result.messages[5].content == "Hello"
assert result.messages[6].content == "TERMINATE"

View File

@ -248,9 +248,10 @@
],
"source": [
"import asyncio\n",
"from typing import Sequence\n",
"from typing import List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import (\n",
" ChatMessage,\n",
" StopMessage,\n",
@ -262,11 +263,15 @@
" def __init__(self, name: str) -> None:\n",
" super().__init__(name, \"A human user.\")\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
" @property\n",
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage, StopMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
" if \"TERMINATE\" in user_input:\n",
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
" return TextMessage(content=user_input, source=self.name)\n",
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
" return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
"\n",
"\n",
"user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
@ -312,7 +317,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.11.5"
}
},
"nbformat": 4,

View File

@ -38,13 +38,14 @@
"outputs": [],
"source": [
"import asyncio\n",
"from typing import Sequence\n",
"from typing import List, Sequence\n",
"\n",
"from autogen_agentchat.agents import (\n",
" BaseChatAgent,\n",
" CodingAssistantAgent,\n",
" ToolUseAssistantAgent,\n",
")\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
@ -71,11 +72,15 @@
" def __init__(self, name: str) -> None:\n",
" super().__init__(name, \"A human user.\")\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
" @property\n",
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage, StopMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
" if \"TERMINATE\" in user_input:\n",
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
" return TextMessage(content=user_input, source=self.name)"
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
" return Response(chat_message=TextMessage(content=user_input, source=self.name))"
]
},
{
@ -269,7 +274,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.11.5"
}
},
"nbformat": 4,