mirror of https://github.com/microsoft/autogen.git
Merge branch 'main' into u/refactor
commit 9d052a1661
@@ -217,13 +217,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
csharp_prefer_simple_using_statement = true:suggestion
csharp_style_namespace_declarations = block_scoped:silent
csharp_style_namespace_declarations = file_scoped:warning
csharp_style_prefer_method_group_conversion = true:silent
csharp_style_prefer_top_level_statements = true:silent
csharp_style_prefer_primary_constructors = true:suggestion
csharp_style_expression_bodied_lambdas = true:silent
csharp_style_prefer_local_over_anonymous_function = true:suggestion
dotnet_diagnostic.CA2016.severity = suggestion
csharp_prefer_static_anonymous_function = true:suggestion

# disable check for generated code
[*.generated.cs]
@@ -693,6 +694,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
dotnet_style_prefer_simplified_interpolation = true:suggestion
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
dotnet_style_namespace_match_folder = true:suggestion
dotnet_style_qualification_for_method = false:silent

[**/*.g.cs]
generated_code = true
@@ -127,7 +127,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents.Te
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
EndProject
Global
	GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -27,10 +27,11 @@ namespace Hello
    [TopicSubscription("HelloAgents")]
    public class HelloAgent(
        IAgentContext context,
        [FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
        [FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
            context,
            typeRegistry),
        ISayHello,
        IHandleConsole,
        IHandle<NewMessageReceived>,
        IHandle<ConversationClosed>
    {
@@ -5,45 +5,44 @@ using System;
using System.Collections.Generic;
using System.Linq;

namespace AutoGen
namespace AutoGen;

public static class LLMConfigAPI
{
    public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
        string apiKey,
        IEnumerable<string>? modelIDs = null)
    {
        var models = modelIDs ?? new[]
        {
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-0613",
            "gpt-4-32k-0613",
            "gpt-4-1106-preview",
        };

        return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
    }

    public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
        string endpoint,
        string apiKey,
        IEnumerable<string> deploymentNames)
    {
        return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
    }

    /// <summary>
    /// Get a list of LLMConfig objects from a JSON file.
    /// </summary>
    internal static IEnumerable<ILLMConfig> ConfigListFromJson(
        string filePath,
        IEnumerable<string>? filterModels = null)
    {
        // Disable this API from documentation for now.
        throw new NotImplementedException();
    }
}
@@ -2,6 +2,7 @@
// AgentBase.cs

using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
@@ -20,6 +21,8 @@ public abstract class AgentBase : IAgentBase, IHandle
    private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
    private readonly IAgentContext _context;
    public string Route { get; set; } = "base";

    protected internal ILogger Logger => _context.Logger;
    public IAgentContext Context => _context;
    protected readonly EventTypes EventTypes;
@@ -215,14 +218,39 @@ public abstract class AgentBase : IAgentBase, IHandle
    public Task CallHandler(CloudEvent item)
    {
        // Only send the event to the handler if the agent type is handling that type
        if (EventTypes.EventsMap[GetType()].Contains(item.Type))
        {
            var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
            var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
            var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
            var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
            return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
        }
        // foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
        foreach (var key in EventTypes.EventsMap.Keys)
        {
            if (EventTypes.EventsMap[key].Contains(item.Type))
            {
                var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
                var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
                var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);

                MethodInfo methodInfo;
                try
                {
                    // check that our target actually implements this interface, otherwise call the default static
                    if (genericInterfaceType.IsAssignableFrom(this.GetType()))
                    {
                        methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
                                     ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
                        return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
                    }
                    else
                    {
                        // The error here is we have registered for an event that we do not have code to listen to
                        throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
                    }
                }
                catch (Exception ex)
                {
                    Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
                    throw; // TODO: ?
                }
            }
        }

        return Task.CompletedTask;
    }
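For orientation only, a minimal sketch of the handler shape that CallHandler dispatches to through IHandle<>; the base-class call and event type follow the HelloAgent sample earlier in this diff, while the class name and method body are assumptions for illustration and are not part of this commit:

using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.DependencyInjection;

public class SampleAgent(
    IAgentContext context,
    [FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(context, typeRegistry),
    IHandle<NewMessageReceived>
{
    // CallHandler unpacks the CloudEvent payload and invokes this method via the IHandle<NewMessageReceived> interface.
    public async Task Handle(NewMessageReceived item)
    {
        // Hypothetical handling logic; a real agent would process or publish the event here.
        await Task.CompletedTask;
    }
}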
@@ -0,0 +1,46 @@
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Agents;

public interface IHandleConsole : IHandle<Output>, IHandle<Input>
{
    string Route { get; }
    AgentId AgentId { get; }
    ValueTask PublishEvent(CloudEvent item);

    async Task IHandle<Output>.Handle(Output item)
    {
        // Assuming item has a property `Message` that we want to write to the console
        Console.WriteLine(item.Message);
        await ProcessOutput(item.Message);

        var evt = new OutputWritten
        {
            Route = "console"
        }.ToCloudEvent(AgentId.Key);
        await PublishEvent(evt);
    }
    async Task IHandle<Input>.Handle(Input item)
    {
        Console.WriteLine("Please enter input:");
        string content = Console.ReadLine() ?? string.Empty;

        await ProcessInput(content);

        var evt = new InputProcessed
        {
            Route = "console"
        }.ToCloudEvent(AgentId.Key);
        await PublishEvent(evt);
    }
    static Task ProcessOutput(string message)
    {
        // Implement your output processing logic here
        return Task.CompletedTask;
    }
    static Task<string> ProcessInput(string message)
    {
        // Implement your input processing logic here
        return Task.FromResult(message);
    }
}
@@ -86,6 +86,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
                    message.Response.RequestId = request.OriginalRequestId;
                    request.Agent.ReceiveMessage(message);
                    break;

                case Message.MessageOneofCase.RegisterAgentTypeResponse:
                    if (!message.RegisterAgentTypeResponse.Success)
                    {
@@ -74,7 +74,52 @@ public static class HostBuilderExtensions
                .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
                .Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
                .ToDictionary(item => item.t, item => item.Item2);
            // if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
            var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
                .SelectMany(assembly => assembly.GetTypes())
                .Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
                .Select(t => (t, t.GetMethods()
                    .Where(m => m.Name == "Handle")
                    .Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
                .ToDictionary(item => item.t, item => item.Item2);
            // get interfaces implemented by the agent and get the methods of the interface if they are named Handle
            var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
                .SelectMany(assembly => assembly.GetTypes())
                .Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
                .Select(t => t.GetInterfaces()
                    .Select(i => (t, i, i.GetMethods()
                        .Where(m => m.Name == "Handle")
                        .Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
                        //to dictionary of type t and paramter type of the method
                        .ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
            // for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
            foreach (var item in ifaceHandlersMap)
            {
                foreach (var iface in item)
                {
                    if (eventsMap.TryGetValue(iface.Item2, out var events))
                    {
                        events.UnionWith(iface.Item3);
                    }
                    else
                    {
                        eventsMap[iface.Item2] = iface.Item3;
                    }
                }
            }

            // merge the handlersMap into the eventsMap
            foreach (var item in handlersMap)
            {
                if (eventsMap.TryGetValue(item.Key, out var events))
                {
                    events.UnionWith(item.Value);
                }
                else
                {
                    eventsMap[item.Key] = item.Value;
                }
            }
            return new EventTypes(typeRegistry, types, eventsMap);
        });
        return new AgentApplicationBuilder(builder);
@@ -3,34 +3,33 @@
using Microsoft.Extensions.AI;

namespace Microsoft.Extensions.Hosting
namespace Microsoft.Extensions.Hosting;

public static class AIModelClient
{
    public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
    {
        var pipeline = (ChatClientBuilder pipeline) => pipeline
            .UseLogging()
            .UseFunctionInvocation()
            .UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);

        if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
        {
            builder.AddOllamaChatClient(serviceName, pipeline);
        }
        else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
        {
            builder.AddOpenAIChatClient(serviceName, pipeline);
        }
        else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
        {
            builder.AddAzureChatClient(serviceName, pipeline);
        }
        else
        {
            throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
        }
        return builder;
    }
}
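As a usage sketch only, not part of this commit: AddChatCompletionService above selects a chat client from the "{serviceName}:ModelType" configuration key, so a host could wire it up roughly as follows; the service name "llm" and the in-memory configuration value are assumptions for illustration:

using Microsoft.Extensions.Hosting;

var builder = Host.CreateApplicationBuilder(args);
builder.Configuration["llm:ModelType"] = "ollama"; // ollama, openai, azureopenai, or azureaiinference
builder.AddChatCompletionService("llm");
using var host = builder.Build();
await host.RunAsync();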
@@ -14,206 +14,205 @@ using FluentAssertions;
using OpenAI;
using Xunit.Abstractions;

namespace AutoGen.OpenAI.Tests
namespace AutoGen.OpenAI.Tests;

public partial class MathClassTest
{
    private readonly ITestOutputHelper _output;

    // as of 2024-05-20, aoai return 500 error when round > 1
    // I'm pretty sure that round > 5 was supported before
    // So this is probably some wield regression on aoai side
    // I'll keep this test case here for now, plus setting round to 1
    // so the test can still pass.
    // In the future, we should rewind this test case to round > 1 (previously was 5)
    private int round = 1;
    public MathClassTest(ITestOutputHelper output)
    {
        _output = output;
    }

    private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
    {
        try
        {
            var reply = agent.GenerateReplyAsync(messages, option, ct).Result;

            _output.WriteLine(reply.FormatMessage());
            return Task.FromResult(reply);
        }
        catch (Exception)
        {
            _output.WriteLine("Request failed");
            _output.WriteLine($"agent name: {agent.Name}");
            foreach (var message in messages)
            {
                _output.WriteLine(message.FormatMessage());
            }

            throw;
        }
    }

    [FunctionAttribute]
    public async Task<string> CreateMathQuestion(string question, int question_index)
    {
        return $@"[MATH_QUESTION]
Question {question_index}:
{question}

Student, please answer";
    }

    [FunctionAttribute]
    public async Task<string> AnswerQuestion(string answer)
    {
        return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
    }

    [FunctionAttribute]
    public async Task<string> AnswerIsCorrect(string message)
    {
        return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
    }

    [FunctionAttribute]
    public async Task<string> UpdateProgress(int correctAnswerCount)
    {
        if (correctAnswerCount >= this.round)
        {
            return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
        }
        else
        {
            return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
        }
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task OpenAIAgentMathChatTestAsync()
    {
        var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
        var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
        var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
        var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
        var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
        var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);

        var adminFunctionMiddleware = new FunctionCallMiddleware(
            functions: [this.UpdateProgressFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
            });
        var admin = new OpenAIChatAgent(
            chatClient: openaiClient.GetChatClient(deployName),
            name: "Admin",
            systemMessage: $@"You are admin. You update progress after each question is answered.")
            .RegisterMessageConnector()
            .RegisterStreamingMiddleware(adminFunctionMiddleware)
            .RegisterMiddleware(Print);

        var groupAdmin = new OpenAIChatAgent(
            chatClient: openaiClient.GetChatClient(deployName),
            name: "GroupAdmin",
            systemMessage: "You are group admin. You manage the group chat.")
            .RegisterMessageConnector()
            .RegisterMiddleware(Print);
        await RunMathChatAsync(teacher, student, admin, groupAdmin);
    }

    private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
    {
        var functionCallMiddleware = new FunctionCallMiddleware(
            functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
                { this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
            });

        var teacher = new OpenAIChatAgent(
            chatClient: client.GetChatClient(model),
            name: "Teacher",
            systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it")
            .RegisterMessageConnector()
            .RegisterStreamingMiddleware(functionCallMiddleware)
            .RegisterMiddleware(Print);

        return teacher;
    }

    private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
    {
        var functionCallMiddleware = new FunctionCallMiddleware(
            functions: [this.AnswerQuestionFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
            });
        var student = new OpenAIChatAgent(
            chatClient: client.GetChatClient(model),
            name: "Student",
            systemMessage: @"You are a student. You answer math question from teacher.")
            .RegisterMessageConnector()
            .RegisterStreamingMiddleware(functionCallMiddleware)
            .RegisterMiddleware(Print);

        return student;
    }

    private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
    {
        var teacher2Student = Transition.Create(teacher, student);
        var student2Teacher = Transition.Create(student, teacher);
        var teacher2Admin = Transition.Create(teacher, admin);
        var admin2Teacher = Transition.Create(admin, teacher);
        var workflow = new Graph(
            [
                teacher2Student,
                student2Teacher,
                teacher2Admin,
                admin2Teacher,
            ]);
        var group = new GroupChat(
            workflow: workflow,
            members: [
                admin,
                teacher,
                student,
            ],
            admin: groupAdmin);

        var groupChatManager = new GroupChatManager(group);
        var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);

        chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        // check if there's terminate chat message from admin
        chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
                .Count()
                .Should().Be(1);
    }
}
@@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
using FluentAssertions;
using Xunit.Abstractions;

namespace AutoGen.OpenAI.V1.Tests
namespace AutoGen.OpenAI.V1.Tests;

public partial class MathClassTest
{
    private readonly ITestOutputHelper _output;

    // as of 2024-05-20, aoai return 500 error when round > 1
    // I'm pretty sure that round > 5 was supported before
    // So this is probably some wield regression on aoai side
    // I'll keep this test case here for now, plus setting round to 1
    // so the test can still pass.
    // In the future, we should rewind this test case to round > 1 (previously was 5)
    private int round = 1;
    public MathClassTest(ITestOutputHelper output)
    {
        _output = output;
    }

    private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
    {
        try
        {
            var reply = agent.GenerateReplyAsync(messages, option, ct).Result;

            _output.WriteLine(reply.FormatMessage());
            return Task.FromResult(reply);
        }
        catch (Exception)
        {
            _output.WriteLine("Request failed");
            _output.WriteLine($"agent name: {agent.Name}");
            foreach (var message in messages)
            {
                if (message is IMessage<object> envelope)
                {
                    var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
                    _output.WriteLine(json);
                }
            }

            throw;
        }
    }

    [FunctionAttribute]
    public async Task<string> CreateMathQuestion(string question, int question_index)
    {
        return $@"[MATH_QUESTION]
Question {question_index}:
{question}

Student, please answer";
    }

    [FunctionAttribute]
    public async Task<string> AnswerQuestion(string answer)
    {
        return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
    }

    [FunctionAttribute]
    public async Task<string> AnswerIsCorrect(string message)
    {
        return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
    }

    [FunctionAttribute]
    public async Task<string> UpdateProgress(int correctAnswerCount)
    {
        if (correctAnswerCount >= this.round)
        {
            return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
        }
        else
        {
            return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
        }
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task OpenAIAgentMathChatTestAsync()
    {
        var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
        var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
        var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
        var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
        var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
        var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);

        var adminFunctionMiddleware = new FunctionCallMiddleware(
            functions: [this.UpdateProgressFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
            });
        var admin = new OpenAIChatAgent(
            openAIClient: openaiClient,
            modelName: deployName,
            name: "Admin",
            systemMessage: $@"You are admin. You update progress after each question is answered.")
            .RegisterMessageConnector()
            .RegisterStreamingMiddleware(adminFunctionMiddleware)
            .RegisterMiddleware(Print);

        var groupAdmin = new OpenAIChatAgent(
            openAIClient: openaiClient,
            modelName: deployName,
            name: "GroupAdmin",
            systemMessage: "You are group admin. You manage the group chat.")
            .RegisterMessageConnector()
            .RegisterMiddleware(Print);
        await RunMathChatAsync(teacher, student, admin, groupAdmin);
    }

    private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
    {
        var functionCallMiddleware = new FunctionCallMiddleware(
            functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
                { this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
            });

        var teacher = new OpenAIChatAgent(
            openAIClient: client,
            name: "Teacher",
            systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
            modelName: model)
            .RegisterMiddleware(Print)
            .RegisterMiddleware(new OpenAIChatRequestMessageConnector())
            .RegisterMiddleware(functionCallMiddleware);

        return teacher;
    }

    private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
    {
        var functionCallMiddleware = new FunctionCallMiddleware(
            functions: [this.AnswerQuestionFunctionContract],
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
            });
        var student = new OpenAIChatAgent(
            openAIClient: client,
            name: "Student",
            modelName: model,
            systemMessage: @"You are a student. You answer math question from teacher.")
            .RegisterMessageConnector()
            .RegisterStreamingMiddleware(functionCallMiddleware)
            .RegisterMiddleware(Print);

        return student;
    }

    private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
    {
        var teacher2Student = Transition.Create(teacher, student);
        var student2Teacher = Transition.Create(student, teacher);
        var teacher2Admin = Transition.Create(teacher, admin);
        var admin2Teacher = Transition.Create(admin, teacher);
        var workflow = new Graph(
            [
                teacher2Student,
                student2Teacher,
                teacher2Admin,
                admin2Teacher,
            ]);
        var group = new GroupChat(
            workflow: workflow,
            members: [
                admin,
                teacher,
                student,
            ],
            admin: groupAdmin);

        var groupChatManager = new GroupChatManager(group);
        var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);

        chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
                .Count()
                .Should().BeGreaterThanOrEqualTo(this.round);

        // check if there's terminate chat message from admin
        chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
                .Count()
                .Should().Be(1);
    }
}
@@ -4,85 +4,84 @@
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
using Xunit; // Needed for Fact and Assert

namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;

public class FunctionCallTemplateEncodingTests
{
    [Fact]
    public void FunctionDescription_Should_Encode_DoubleQuotes()
    {
        // Arrange
        var functionContracts = new List<SourceGeneratorFunctionContract>
        {
            new SourceGeneratorFunctionContract
            {
                Name = "TestFunction",
                Description = "This is a \"test\" function",
                Parameters = new SourceGeneratorParameterContract[]
                {
                    new SourceGeneratorParameterContract
                    {
                        Name = "param1",
                        Description = "This is a \"parameter\" description",
                        Type = "string",
                        IsOptional = false
                    }
                },
                ReturnType = "void"
            }
        };

        var template = new FunctionCallTemplate
        {
            NameSpace = "TestNamespace",
            ClassName = "TestClass",
            FunctionContracts = functionContracts
        };

        // Act
        var result = template.TransformText();

        // Assert
        Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
        Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
    }

    [Fact]
    public void ParameterDescription_Should_Encode_DoubleQuotes()
    {
        // Arrange
        var functionContracts = new List<SourceGeneratorFunctionContract>
        {
            new SourceGeneratorFunctionContract
            {
                Name = "TestFunction",
                Description = "This is a test function",
                Parameters = new SourceGeneratorParameterContract[]
                {
                    new SourceGeneratorParameterContract
                    {
                        Name = "param1",
                        Description = "This is a \"parameter\" description",
                        Type = "string",
                        IsOptional = false
                    }
                },
                ReturnType = "void"
            }
        };

        var template = new FunctionCallTemplate
        {
            NameSpace = "TestNamespace",
            ClassName = "TestClass",
            FunctionContracts = functionContracts
        };

        // Act
        var result = template.TransformText();

        // Assert
        Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
    }
}
@@ -10,122 +10,121 @@ using FluentAssertions;
using OpenAI.Chat;
using Xunit;

namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;

public class FunctionExample
{
    private readonly FunctionExamples functionExamples = new FunctionExamples();
    private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
    {
        WriteIndented = true,
    };

    [Fact]
    public void Add_Test()
    {
        var args = new
        {
            a = 1,
            b = 2,
        };

        this.VerifyFunction(functionExamples.AddWrapper, args, 3);
        this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
    }

    [Fact]
    public void Sum_Test()
    {
        var args = new
        {
            args = new double[] { 1, 2, 3 },
        };

        this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
        this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
    }

    [Fact]
    public async Task DictionaryToString_Test()
    {
        var args = new
        {
            xargs = new Dictionary<string, string>
            {
                { "a", "1" },
                { "b", "2" },
            },
        };

        await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
        this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
    }

    [Fact]
    public async Task TopLevelFunctionExampleAddTestAsync()
    {
        var example = new TopLevelStatementFunctionExample();
        var args = new
        {
            a = 1,
            b = 2,
        };

        await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
    }

    [Fact]
    public async Task FilescopeFunctionExampleAddTestAsync()
    {
        var example = new FilescopeNamespaceFunctionExample();
        var args = new
        {
            a = 1,
            b = 2,
        };

        await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
    }

    [Fact]
    public void Query_Test()
    {
        var args = new
        {
            query = "hello",
            k = 3,
        };

        this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
        this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
    }

    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    private void VerifyFunctionDefinition(ChatTool function)
    {
        var func = new
        {
            name = function.FunctionName,
            description = function.FunctionDescription.Replace(Environment.NewLine, ","),
            parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
        };

        Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
    }

    private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
    {
        var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
        var res = func(str);
        res.Should().BeEquivalentTo(expected);
    }

    private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
    {
        var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
        var res = await func(str);
        res.Should().BeEquivalentTo(expected);
    }
}
@@ -4,67 +4,66 @@
using System.Text.Json;
using AutoGen.Core;

namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;

public partial class FunctionExamples
{
    /// <summary>
    /// Add function
    /// </summary>
    /// <param name="a">a</param>
    /// <param name="b">b</param>
    [FunctionAttribute]
    public int Add(int a, int b)
    {
        return a + b;
    }

    /// <summary>
    /// Add two numbers.
    /// </summary>
    /// <param name="a">The first number.</param>
    /// <param name="b">The second number.</param>
    [Function]
    public Task<string> AddAsync(int a, int b)
    {
        return Task.FromResult($"{a} + {b} = {a + b}");
    }

    /// <summary>
    /// Sum function
    /// </summary>
    /// <param name="args">an array of double values</param>
    [FunctionAttribute]
    public double Sum(double[] args)
    {
        return args.Sum();
    }

    /// <summary>
    /// DictionaryToString function
    /// </summary>
    /// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
    [FunctionAttribute]
    public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
    {
        var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
        {
            WriteIndented = true,
        });

        return Task.FromResult(res);
    }

    /// <summary>
    /// query function
    /// </summary>
    /// <param name="query">query, required</param>
    /// <param name="k">top k, optional, default value is 3</param>
    /// <param name="thresold">thresold, optional, default value is 0.5</param>
    [FunctionAttribute]
    public string[] Query(string query, int k = 3, float thresold = 0.5f)
    {
        return Enumerable.Repeat(query, k).ToArray();
    }
}

@@ -7,73 +7,72 @@ using System.Threading.Tasks;
using AutoGen.BasicSample;
using Xunit.Abstractions;

namespace AutoGen.Tests
namespace AutoGen.Tests;

public class BasicSampleTest
{
    private readonly ITestOutputHelper _output;

    public BasicSampleTest(ITestOutputHelper output)
    {
        _output = output;
        Console.SetOut(new ConsoleWriter(_output));
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task AssistantAgentTestAsync()
    {
        await Example01_AssistantAgent.RunAsync();
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task TwoAgentMathClassTestAsync()
    {
        await Example02_TwoAgent_MathChat.RunAsync();
    }

    [ApiKeyFact("OPENAI_API_KEY")]
    public async Task AgentFunctionCallTestAsync()
    {
        await Example03_Agent_FunctionCall.RunAsync();
    }

    [ApiKeyFact("MISTRAL_API_KEY")]
    public async Task MistralClientAgent_TokenCount()
    {
        await Example14_MistralClientAgent_TokenCount.RunAsync();
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task DynamicGroupChatCalculateFibonacciAsync()
    {
        await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
        await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
    }

    [ApiKeyFact("OPENAI_API_KEY")]
    public async Task DalleAndGPT4VTestAsync()
    {
        await Example05_Dalle_And_GPT4V.RunAsync();
    }

    [ApiKeyFact("OPENAI_API_KEY")]
    public async Task GPT4ImageMessage()
    {
        await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
    }

    public class ConsoleWriter : StringWriter
    {
        private ITestOutputHelper output;
        public ConsoleWriter(ITestOutputHelper output)
        {
            this.output = output;
        }

        public override void WriteLine(string? m)
        {
            output.WriteLine(m);
        }
    }
}

@@ -3,18 +3,17 @@

using Xunit;

namespace AutoGen.Tests
namespace AutoGen.Tests;

public class GraphTests
{
    [Fact]
    public void GraphTest()
    {
        var graph1 = new Graph();
        Assert.NotNull(graph1);

        var graph2 = new Graph(null);
        Assert.NotNull(graph2);
    }
}

@@ -9,219 +9,218 @@ using FluentAssertions;
using Xunit;
using Xunit.Abstractions;

namespace AutoGen.Tests
namespace AutoGen.Tests;

public partial class SingleAgentTest
{
    private ITestOutputHelper _output;
    public SingleAgentTest(ITestOutputHelper output)
    {
        _output = output;
    }

    private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
    {
        var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
        var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
        var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
        return new AzureOpenAIConfig(endpoint, deployName, key);
    }

    private ILLMConfig CreateOpenAIGPT4VisionConfig()
    {
        var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
        return new OpenAIConfig(key, "gpt-4-vision-preview");
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task AssistantAgentFunctionCallTestAsync()
    {
        var config = this.CreateAzureOpenAIGPT35TurboConfig();

        var llmConfig = new ConversableAgentConfig
        {
            Temperature = 0,
            FunctionContracts = new[]
            {
                this.EchoAsyncFunctionContract,
            },
            ConfigList = new[]
            {
                config,
            },
        };

        var assistantAgent = new AssistantAgent(
            name: "assistant",
            llmConfig: llmConfig);

        await EchoFunctionCallTestAsync(assistantAgent);
    }

    [Fact]
    public async Task AssistantAgentDefaultReplyTestAsync()
    {
        var assistantAgent = new AssistantAgent(
            llmConfig: null,
            name: "assistant",
            defaultReply: "hello world");

        var reply = await assistantAgent.SendAsync("hi");

        reply.GetContent().Should().Be("hello world");
        reply.GetRole().Should().Be(Role.Assistant);
        reply.From.Should().Be(assistantAgent.Name);
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
    {
        var config = this.CreateAzureOpenAIGPT35TurboConfig();
        var llmConfig = new ConversableAgentConfig
        {
            FunctionContracts = new[]
            {
                this.EchoAsyncFunctionContract,
            },
            ConfigList = new[]
            {
                config,
            },
        };
        var assistantAgent = new AssistantAgent(
            name: "assistant",
            llmConfig: llmConfig,
            functionMap: new Dictionary<string, Func<string, Task<string>>>
            {
                { nameof(EchoAsync), this.EchoAsyncWrapper },
            });

        await EchoFunctionCallExecutionTestAsync(assistantAgent);
    }

    /// <summary>
    /// echo when asked.
    /// </summary>
    /// <param name="message">message to echo</param>
    [FunctionAttribute]
    public async Task<string> EchoAsync(string message)
    {
        return $"[ECHO] {message}";
    }

    /// <summary>
    /// return the label name with hightest inference cost
    /// </summary>
    /// <param name="labelName"></param>
    /// <returns></returns>
    [FunctionAttribute]
    public async Task<string> GetHighestLabel(string labelName, string color)
    {
        return $"[HIGHEST_LABEL] {labelName} {color}";
    }

    public async Task EchoFunctionCallTestAsync(IAgent agent)
    {
        //var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
        var helloWorld = new TextMessage(Role.User, "echo Hello world");

        var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });

        reply.From.Should().Be(agent.Name);
        reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
    }

    public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
    {
        //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
        var helloWorld = new TextMessage(Role.User, "echo Hello world");

        var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });

        reply.GetContent().Should().Be("[ECHO] Hello world");
        reply.From.Should().Be(agent.Name);
        reply.Should().BeOfType<ToolCallAggregateMessage>();
    }

    public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
    {
        //var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
        var helloWorld = new TextMessage(Role.User, "echo Hello world");
        var option = new GenerateReplyOptions
        {
            Temperature = 0,
        };
        var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
        var answer = "[ECHO] Hello world";
        IMessage? finalReply = default;
        await foreach (var reply in replyStream)
        {
            reply.From.Should().Be(agent.Name);
            finalReply = reply;
        }

        if (finalReply is ToolCallAggregateMessage aggregateMessage)
        {
            var toolCallResultMessage = aggregateMessage.Message2;
            toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
            toolCallResultMessage.From.Should().Be(agent.Name);
            toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
        }
        else
        {
            throw new Exception("unexpected message type");
        }
    }

    public async Task UpperCaseTestAsync(IAgent agent)
    {
        var message = new TextMessage(Role.User, "Please convert abcde to upper case.");

        var reply = await agent.SendAsync(chatHistory: new[] { message });

        reply.GetContent().Should().Contain("ABCDE");
        reply.From.Should().Be(agent.Name);
    }

    public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
    {
        var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
        var option = new GenerateReplyOptions
        {
            Temperature = 0,
        };
        var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
        var answer = "HELLO WORLD";
        TextMessage? finalReply = default;
        await foreach (var reply in replyStream)
        {
            if (reply is TextMessageUpdate update)
            {
                update.From.Should().Be(agent.Name);

                if (finalReply is null)
                {
                    finalReply = new TextMessage(update);
                }
                else
                {
                    finalReply.Update(update);
                }

                continue;
            }
            else if (reply is TextMessage textMessage)
            {
                finalReply = textMessage;
                continue;
            }

            throw new Exception("unexpected message type");
        }

        finalReply!.Content.Should().Contain(answer);
        finalReply!.Role.Should().Be(Role.Assistant);
        finalReply!.From.Should().Be(agent.Name);
    }
}

@@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .. import EVENT_LOGGER_NAME
from ..base import Response
from ..messages import (
    ChatMessage,
    HandoffMessage,
    InnerMessage,
    ResetMessage,
    StopMessage,
    TextMessage,
    ToolCallMessage,
    ToolCallResultMessages,
)
from ._base_chat_agent import BaseChatAgent

@@ -207,7 +211,14 @@ class AssistantAgent(BaseChatAgent):
        )
        self._model_context: List[LLMMessage] = []

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """The types of messages that the assistant agent produces."""
        if self._handoffs:
            return [TextMessage, HandoffMessage, StopMessage]
        return [TextMessage, StopMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Add messages to the model context.
        for msg in messages:
            if isinstance(msg, ResetMessage):

@@ -215,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
            else:
                self._model_context.append(UserMessage(content=msg.content, source=msg.source))

        # Inner messages.
        inner_messages: List[InnerMessage] = []

        # Generate an inference result based on the current model context.
        llm_messages = self._system_messages + self._model_context
        result = await self._model_client.create(

@@ -227,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
        # Run tool calls until the model produces a string response.
        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
            event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
            # Add the tool call message to the output.
            inner_messages.append(ToolCallMessage(content=result.content, source=self.name))

            # Execute the tool calls.
            results = await asyncio.gather(
                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
            )
            event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
            self._model_context.append(FunctionExecutionResultMessage(content=results))
            inner_messages.append(ToolCallResultMessages(content=results, source=self.name))

            # Detect handoff requests.
            handoffs: List[Handoff] = []

@@ -242,8 +260,13 @@
            if len(handoffs) > 0:
                if len(handoffs) > 1:
                    raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
                # Respond with a handoff message.
                return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
                # Return the output messages to signal the handoff.
                return Response(
                    chat_message=HandoffMessage(
                        content=handoffs[0].message, target=handoffs[0].target, source=self.name
                    ),
                    inner_messages=inner_messages,
                )

            # Generate an inference result based on the current model context.
            result = await self._model_client.create(

@@ -255,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
        # Detect stop request.
        request_stop = "terminate" in result.content.strip().lower()
        if request_stop:
            return StopMessage(content=result.content, source=self.name)
            return Response(
                chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
            )

        return TextMessage(content=result.content, source=self.name)
        return Response(
            chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
        )

    async def _execute_tool_call(
        self, tool_call: FunctionCall, cancellation_token: CancellationToken

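A minimal sketch of the concurrent tool-execution pattern used by AssistantAgent.on_messages above, isolated with plain asyncio so the control flow is visible; the _run_tool coroutine and its inputs are invented placeholders and are not part of this change.

import asyncio


async def _run_tool(name: str, payload: str) -> str:
    await asyncio.sleep(0)  # stand-in for real tool work
    return f"{name}:{payload}"


async def main() -> None:
    calls = [("echo", "hello"), ("add", "1,2")]
    # Run every requested tool call concurrently, mirroring the
    # asyncio.gather(...) loop in AssistantAgent.on_messages.
    results = await asyncio.gather(*[_run_tool(n, p) for n, p in calls])
    print(results)


asyncio.run(main())
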
@@ -1,11 +1,10 @@
from abc import ABC, abstractmethod
from typing import Sequence
from typing import List, Sequence

from autogen_core.base import CancellationToken

from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
from ..messages import ChatMessage, InnerMessage, TextMessage


class BaseChatAgent(ChatAgent, ABC):

@@ -30,9 +29,15 @@ class BaseChatAgent(ChatAgent, ABC):
        describe the agent's capabilities and how to interact with it."""
        return self._description

    @abstractmethod
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
        """Handle incoming messages and return a response message."""
    @property
    @abstractmethod
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """The types of messages that the agent produces."""
        ...

    @abstractmethod
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        """Handles incoming messages and returns a response."""
        ...

    async def run(

@@ -43,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
        termination_condition: TerminationCondition | None = None,
    ) -> TaskResult:
        """Run the agent with the given task and return the result."""
        group_chat = RoundRobinGroupChat(participants=[self])
        result = await group_chat.run(
            task=task,
            cancellation_token=cancellation_token,
            termination_condition=termination_condition,
        )
        return result
        if cancellation_token is None:
            cancellation_token = CancellationToken()
        first_message = TextMessage(content=task, source="user")
        response = await self.on_messages([first_message], cancellation_token)
        messages: List[InnerMessage | ChatMessage] = [first_message]
        if response.inner_messages is not None:
            messages += response.inner_messages
        messages.append(response.chat_message)
        return TaskResult(messages=messages)

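A minimal sketch of a subclass conforming to the updated BaseChatAgent contract above (declare produced_message_types and return a Response from on_messages); the ParrotAgent name is invented, and only imports that appear elsewhere in this diff are assumed.

from typing import List, Sequence

from autogen_core.base import CancellationToken

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, TextMessage


class ParrotAgent(BaseChatAgent):
    """Replies by repeating the last message it received."""

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        # Declare the chat message types this agent can emit.
        return [TextMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Wrap the reply in a Response; inner_messages stays None because
        # this agent produces no tool-call monologue.
        last = messages[-1].content if messages else ""
        return Response(chat_message=TextMessage(content=str(last), source=self.name))
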
@@ -3,6 +3,7 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks

from ..base import Response
from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent

@@ -20,7 +21,12 @@ class CodeExecutorAgent(BaseChatAgent):
        super().__init__(name=name, description=description)
        self._code_executor = code_executor

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """The types of messages that the code executor agent produces."""
        return [TextMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Extract code blocks from the messages.
        code_blocks: List[CodeBlock] = []
        for msg in messages:

@@ -29,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
        if code_blocks:
            # Execute the code blocks.
            result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
            return TextMessage(content=result.output, source=self.name)
            return Response(chat_message=TextMessage(content=result.output, source=self.name))
        else:
            return TextMessage(content="No code blocks found in the thread.", source=self.name)
            return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))

@@ -1,10 +1,11 @@
from ._chat_agent import ChatAgent
from ._chat_agent import ChatAgent, Response
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition

__all__ = [
    "ChatAgent",
    "Response",
    "Team",
    "TerminatedException",
    "TerminationCondition",

@@ -1,12 +1,24 @@
from typing import Protocol, Sequence, runtime_checkable
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable

from autogen_core.base import CancellationToken

from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition


@dataclass(kw_only=True)
class Response:
    """A response from calling :meth:`ChatAgent.on_messages`."""

    chat_message: ChatMessage
    """A chat message produced by the agent as the response."""

    inner_messages: List[InnerMessage] | None = None
    """Inner messages produced by the agent."""


@runtime_checkable
class ChatAgent(TaskRunner, Protocol):
    """Protocol for a chat agent."""

@@ -24,8 +36,13 @@ class ChatAgent(TaskRunner, Protocol):
        describe the agent's capabilities and how to interact with it."""
        ...

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
        """Handle incoming messages and return a response message."""
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """The types of messages that the agent produces."""
        ...

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        """Handles incoming messages and returns a response."""
        ...

    async def run(

@@ -3,7 +3,7 @@ from typing import Protocol, Sequence

from autogen_core.base import CancellationToken

from ..messages import ChatMessage
from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition

@@ -11,7 +11,7 @@ from ._termination import TerminationCondition
class TaskResult:
    """Result of running a task."""

    messages: Sequence[ChatMessage]
    messages: Sequence[InnerMessage | ChatMessage]
    """Messages produced by the task."""

@@ -1,6 +1,7 @@
from typing import List

from autogen_core.components import Image
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from pydantic import BaseModel

@@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
    """The content for the reset message."""


class ToolCallMessage(BaseMessage):
    """A message signaling the use of tools."""

    content: List[FunctionCall]
    """The tool calls."""


class ToolCallResultMessages(BaseMessage):
    """A message signaling the results of tool calls."""

    content: List[FunctionExecutionResult]
    """The tool call results."""


InnerMessage = ToolCallMessage | ToolCallResultMessages
"""Messages for intra-agent monologues."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
"""A message used by agents in a team."""
"""Messages for agent-to-agent communication."""


__all__ = [

@@ -60,5 +79,7 @@ __all__ = [
    "StopMessage",
    "HandoffMessage",
    "ResetMessage",
    "ToolCallMessage",
    "ToolCallResultMessages",
    "ChatMessage",
]

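A minimal sketch of consuming a TaskResult whose messages now mix InnerMessage (tool-call traffic) and ChatMessage entries, as exercised by the updated tests later in this diff; the summarize helper and the result value are assumptions, not part of this change.

from autogen_agentchat.messages import TextMessage, ToolCallMessage, ToolCallResultMessages


def summarize(result) -> None:
    # result is assumed to be a TaskResult from agent.run(...) or team.run(...).
    for message in result.messages:
        if isinstance(message, ToolCallMessage):
            print(f"{message.source} requested {len(message.content)} tool call(s)")
        elif isinstance(message, ToolCallResultMessages):
            print(f"{message.source} got {len(message.content)} tool result(s)")
        elif isinstance(message, TextMessage):
            print(f"{message.source}: {message.content}")
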
@@ -15,7 +15,7 @@ from autogen_core.base import (
from autogen_core.components import ClosureAgent, TypeSubscription

from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import ChatMessage, TextMessage
from ...messages import ChatMessage, InnerMessage, TextMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer

@@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
    def _create_participant_factory(
        self,
        parent_topic_type: str,
        output_topic_type: str,
        agent: ChatAgent,
    ) -> Callable[[], ChatAgentContainer]:
        def _factory() -> ChatAgentContainer:
            id = AgentInstantiationContext.current_agent_id()
            assert id == AgentId(type=agent.name, key=self._team_id)
            container = ChatAgentContainer(parent_topic_type, agent)
            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
            assert container.id == id
            return container

@@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
        group_chat_manager_topic_type = group_chat_manager_agent_type.type
        group_topic_type = "round_robin_group_topic"
        team_topic_type = "team_topic"
        output_topic_type = "output_topic"

        # Register participants.
        participant_topic_types: List[str] = []

@@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
            await ChatAgentContainer.register(
                runtime,
                type=agent_type,
                factory=self._create_participant_factory(group_topic_type, participant),
                factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
            )
            # Add subscriptions for the participant.
            await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))

@@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
            TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
        )

        group_chat_messages: List[ChatMessage] = []
        output_messages: List[InnerMessage | ChatMessage] = []

        async def collect_group_chat_messages(
        async def collect_output_messages(
            _runtime: AgentRuntime,
            id: AgentId,
            message: GroupChatPublishEvent,
            message: InnerMessage | ChatMessage,
            ctx: MessageContext,
        ) -> None:
            group_chat_messages.append(message.agent_message)
            output_messages.append(message)

        await ClosureAgent.register(
            runtime,
            type="collect_group_chat_messages",
            closure=collect_group_chat_messages,
            type="collect_output_messages",
            closure=collect_output_messages,
            subscriptions=lambda: [
                TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
                TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
            ],
        )

@@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
        # Run the team by publishing the task to the team topic and then requesting the result.
        team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
        group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
        first_chat_message = TextMessage(content=task, source="user")
        output_messages.append(first_chat_message)
        await runtime.publish_message(
            GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
            GroupChatPublishEvent(agent_message=first_chat_message),
            topic_id=team_topic_id,
        )
        await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)

@@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
        await runtime.stop_when_idle()

        # Return the result.
        return TaskResult(messages=group_chat_messages)
        return TaskResult(messages=output_messages)

@@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):

    Args:
        parent_topic_type (str): The topic type of the parent orchestrator.
        output_topic_type (str): The topic type for the output.
        agent (ChatAgent): The agent to delegate message handling to.
    """

    def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
    def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
        super().__init__(description=agent.description)
        self._parent_topic_type = parent_topic_type
        self._output_topic_type = output_topic_type
        self._agent = agent
        self._message_buffer: List[ChatMessage] = []

@@ -36,13 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
        to the delegate agent and publish the response."""
        # Pass the messages in the buffer to the delegate agent.
        response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
        if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
            raise ValueError(
                f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
                f"Expected one of: {self._agent.produced_message_types}. "
                f"Check the agent's produced_message_types property."
            )

        # Publish inner messages to the output topic.
        if response.inner_messages is not None:
            for inner_message in response.inner_messages:
                await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))

        # Publish the response.
        self._message_buffer.clear()
        await self.publish_message(
            GroupChatPublishEvent(agent_message=response, source=self.id),
            GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
            topic_id=DefaultTopicId(type=self._parent_topic_type),
        )

        # Publish the response to the output topic.
        await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))

    async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
        raise ValueError(f"Unhandled message in agent container: {type(message)}")

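A minimal sketch of the produced_message_types validation that ChatAgentContainer performs above, written as a standalone helper against the public ChatAgent and Response types from this diff; the validate_response name is invented.

from autogen_agentchat.base import ChatAgent, Response


def validate_response(agent: ChatAgent, response: Response) -> None:
    # Raise if the agent returned a chat message type it did not declare.
    if not any(isinstance(response.chat_message, t) for t in agent.produced_message_types):
        raise ValueError(
            f"{agent.name} produced {type(response.chat_message).__name__}, "
            f"expected one of {agent.produced_message_types}"
        )
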
@@ -86,6 +86,10 @@ class Swarm(BaseGroupChat):
        super().__init__(
            participants, termination_condition=termination_condition, group_chat_manager_class=SwarmGroupChatManager
        )
        # The first participant must be able to produce handoff messages.
        first_participant = self._participants[0]
        if HandoffMessage not in first_participant.produced_message_types:
            raise ValueError("The first participant must be able to produce a handoff messages.")

    def _create_group_chat_manager_factory(
        self,

@@ -7,7 +7,7 @@ import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient

@@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    result = await tool_use_agent.run("task")
    assert len(result.messages) == 3
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert isinstance(result.messages[1], TextMessage)
    assert isinstance(result.messages[2], StopMessage)
    assert isinstance(result.messages[1], ToolCallMessage)
    assert isinstance(result.messages[2], ToolCallResultMessages)
    assert isinstance(result.messages[3], TextMessage)


@pytest.mark.asyncio

@@ -158,8 +159,9 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        handoffs=[handoff],
    )
    assert HandoffMessage in tool_use_agent.produced_message_types
    response = await tool_use_agent.on_messages(
        [TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
    )
    assert isinstance(response, HandoffMessage)
    assert response.target == "agent2"
    assert isinstance(response.chat_message, HandoffMessage)
    assert response.chat_message.target == "agent2"

@@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
    CodeExecutorAgent,
    Handoff,
)
from autogen_agentchat.base import Response
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
    ChatMessage,
    HandoffMessage,
    StopMessage,
    TextMessage,
    ToolCallMessage,
    ToolCallResultMessages,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import (

@@ -62,14 +65,18 @@ class _EchoAgent(BaseChatAgent):
        super().__init__(name, description)
        self._last_message: str | None = None

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        if len(messages) > 0:
            assert isinstance(messages[0], TextMessage)
            self._last_message = messages[0].content
            return TextMessage(content=messages[0].content, source=self.name)
            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
        else:
            assert self._last_message is not None
            return TextMessage(content=self._last_message, source=self.name)
            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))


class _StopAgent(_EchoAgent):

@@ -78,11 +85,15 @@ class _StopAgent(_EchoAgent):
        self._count = 0
        self._stop_at = stop_at

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage, StopMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        self._count += 1
        if self._count < self._stop_at:
            return await super().on_messages(messages, cancellation_token)
        return StopMessage(content="TERMINATE", source=self.name)
        return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))


def _pass_function(input: str) -> str:

@@ -222,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
    )

    assert len(result.messages) == 4
    assert len(result.messages) == 6
    assert isinstance(result.messages[0], TextMessage)  # task
    assert isinstance(result.messages[1], TextMessage)  # tool use agent response
    assert isinstance(result.messages[2], TextMessage)  # echo agent response
    assert isinstance(result.messages[3], StopMessage)  # tool use agent response
    assert isinstance(result.messages[1], ToolCallMessage)  # tool call
    assert isinstance(result.messages[2], ToolCallResultMessages)  # tool call result
    assert isinstance(result.messages[3], TextMessage)  # tool use agent response
    assert isinstance(result.messages[4], TextMessage)  # echo agent response
    assert isinstance(result.messages[5], StopMessage)  # tool use agent response

    context = tool_use_agent._model_context  # pyright: ignore
    assert context[0].content == "Write a program that prints 'Hello, world!'"

@@ -415,8 +428,16 @@ class _HandOffAgent(BaseChatAgent):
        super().__init__(name, description)
        self._next_agent = next_agent

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
        return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [HandoffMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        return Response(
            chat_message=HandoffMessage(
                content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
            )
        )


@pytest.mark.asyncio

@@ -501,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
    agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
    team = Swarm([agnet1, agent2])
    result = await team.run("task", termination_condition=StopMessageTermination())
    assert len(result.messages) == 5
    assert len(result.messages) == 7
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "handoff to agent2"
    assert result.messages[2].content == "Transferred to agent1."
    assert result.messages[3].content == "Hello"
    assert result.messages[4].content == "TERMINATE"
    assert isinstance(result.messages[1], ToolCallMessage)
    assert isinstance(result.messages[2], ToolCallResultMessages)
    assert result.messages[3].content == "handoff to agent2"
    assert result.messages[4].content == "Transferred to agent1."
    assert result.messages[5].content == "Hello"
    assert result.messages[6].content == "TERMINATE"

@@ -248,9 +248,10 @@
   ],
   "source": [
    "import asyncio\n",
    "from typing import Sequence\n",
    "from typing import List, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import BaseChatAgent\n",
    "from autogen_agentchat.base import Response\n",
    "from autogen_agentchat.messages import (\n",
    "    ChatMessage,\n",
    "    StopMessage,\n",

@@ -262,11 +263,15 @@
    "    def __init__(self, name: str) -> None:\n",
    "        super().__init__(name, \"A human user.\")\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
    "    @property\n",
    "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
    "        return [TextMessage, StopMessage]\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
    "        if \"TERMINATE\" in user_input:\n",
    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
    "        return TextMessage(content=user_input, source=self.name)\n",
    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
    "\n",
    "\n",
    "user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",

@@ -312,7 +317,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
   "version": "3.11.5"
  }
 },
 "nbformat": 4,

@@ -38,13 +38,14 @@
   "outputs": [],
   "source": [
    "import asyncio\n",
    "from typing import Sequence\n",
    "from typing import List, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import (\n",
    "    BaseChatAgent,\n",
    "    CodingAssistantAgent,\n",
    "    ToolUseAssistantAgent,\n",
    ")\n",
    "from autogen_agentchat.base import Response\n",
    "from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
    "from autogen_agentchat.task import StopMessageTermination\n",
    "from autogen_agentchat.teams import SelectorGroupChat\n",

@@ -71,11 +72,15 @@
    "    def __init__(self, name: str) -> None:\n",
    "        super().__init__(name, \"A human user.\")\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
    "    @property\n",
    "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
    "        return [TextMessage, StopMessage]\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
    "        if \"TERMINATE\" in user_input:\n",
    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
    "        return TextMessage(content=user_input, source=self.name)"
    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))"
   ]
  },
  {

@@ -269,7 +274,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
