Compare commits

...

4 Commits

Author SHA1 Message Date
Eric Zhu 456e0525e5
Merge branch 'main' into agentchat-response 2024-10-30 10:25:04 -07:00
Xiaoyun Zhang e63fd17ed5
[.Net] use file-scope (#3997)
* use file-scope

* reformat
2024-10-30 10:05:58 -07:00
Eric Zhu 3ca80a8d82
Merge branch 'main' into agentchat-response 2024-10-30 10:00:15 -07:00
Ryan Sweet 51cd5b8d1f
interface inheritance examples (#3989)
changes to AgentBase and HostBuilderExtensions to enable leveraging handlers from composition (interfaces) vs inheritance... see HelloAgents sample for usage

closes #3928
is related to #3925
2024-10-30 09:51:01 -07:00
18 changed files with 997 additions and 885 deletions

View File

@ -221,13 +221,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
csharp_prefer_simple_using_statement = true:suggestion
csharp_style_namespace_declarations = block_scoped:silent
csharp_style_namespace_declarations = file_scoped:warning
csharp_style_prefer_method_group_conversion = true:silent
csharp_style_prefer_top_level_statements = true:silent
csharp_style_prefer_primary_constructors = true:suggestion
csharp_style_expression_bodied_lambdas = true:silent
csharp_style_prefer_local_over_anonymous_function = true:suggestion
dotnet_diagnostic.CA2016.severity = suggestion
csharp_prefer_static_anonymous_function = true:suggestion
# disable check for generated code
[*.generated.cs]
@ -697,6 +698,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
dotnet_style_prefer_simplified_interpolation = true:suggestion
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
dotnet_style_namespace_match_folder = true:suggestion
dotnet_style_qualification_for_method = false:silent
[**/*.g.cs]
generated_code = true

View File

@ -125,7 +125,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelClientHostingExtensi
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution

View File

@ -18,10 +18,11 @@ namespace Hello
[TopicSubscription("HelloAgents")]
public class HelloAgent(
IAgentContext context,
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
context,
typeRegistry),
ISayHello,
IHandleConsole,
IHandle<NewMessageReceived>,
IHandle<ConversationClosed>
{

View File

@ -5,45 +5,44 @@ using System;
using System.Collections.Generic;
using System.Linq;
namespace AutoGen
namespace AutoGen;
public static class LLMConfigAPI
{
public static class LLMConfigAPI
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
{
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
var models = modelIDs ?? new[]
{
var models = modelIDs ?? new[]
{
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
}

View File

@ -1,4 +1,5 @@
using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
@ -17,6 +18,8 @@ public abstract class AgentBase : IAgentBase
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentContext _context;
public string Route { get; set; } = "base";
protected internal ILogger Logger => _context.Logger;
public IAgentContext Context => _context;
protected readonly EventTypes EventTypes;
@ -212,14 +215,39 @@ public abstract class AgentBase : IAgentBase
public Task CallHandler(CloudEvent item)
{
// Only send the event to the handler if the agent type is handling that type
if (EventTypes.EventsMap[GetType()].Contains(item.Type))
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
foreach (var key in EventTypes.EventsMap.Keys)
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
if (EventTypes.EventsMap[key].Contains(item.Type))
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
MethodInfo methodInfo;
try
{
// check that our target actually implements this interface, otherwise call the default static
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
{
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
}
else
{
// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
}
}
catch (Exception ex)
{
Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
throw; // TODO: ?
}
}
}
return Task.CompletedTask;
}

View File

@ -0,0 +1,46 @@
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;
public interface IHandleConsole : IHandle<Output>, IHandle<Input>
{
string Route { get; }
AgentId AgentId { get; }
ValueTask PublishEvent(CloudEvent item);
async Task IHandle<Output>.Handle(Output item)
{
// Assuming item has a property `Message` that we want to write to the console
Console.WriteLine(item.Message);
await ProcessOutput(item.Message);
var evt = new OutputWritten
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
async Task IHandle<Input>.Handle(Input item)
{
Console.WriteLine("Please enter input:");
string content = Console.ReadLine() ?? string.Empty;
await ProcessInput(content);
var evt = new InputProcessed
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
static Task ProcessOutput(string message)
{
// Implement your output processing logic here
return Task.CompletedTask;
}
static Task<string> ProcessInput(string message)
{
// Implement your input processing logic here
return Task.FromResult(message);
}
}

View File

@ -83,6 +83,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
message.Response.RequestId = request.OriginalRequestId;
request.Agent.ReceiveMessage(message);
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
if (!message.RegisterAgentTypeResponse.Success)
{

View File

@ -71,7 +71,52 @@ public static class HostBuilderExtensions
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => (t, t.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// get interfaces implemented by the agent and get the methods of the interface if they are named Handle
var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => t.GetInterfaces()
.Select(i => (t, i, i.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
//to dictionary of type t and paramter type of the method
.ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
// for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
foreach (var item in ifaceHandlersMap)
{
foreach (var iface in item)
{
if (eventsMap.TryGetValue(iface.Item2, out var events))
{
events.UnionWith(iface.Item3);
}
else
{
eventsMap[iface.Item2] = iface.Item3;
}
}
}
// merge the handlersMap into the eventsMap
foreach (var item in handlersMap)
{
if (eventsMap.TryGetValue(item.Key, out var events))
{
events.UnionWith(item.Value);
}
else
{
eventsMap[item.Key] = item.Value;
}
}
return new EventTypes(typeRegistry, types, eventsMap);
});
return new AgentApplicationBuilder(builder);

View File

@ -1,20 +1,19 @@
using Google.Protobuf;
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents
{
public interface IAgentBase
{
// Properties
AgentId AgentId { get; }
IAgentContext Context { get; }
namespace Microsoft.AutoGen.Agents;
// Methods
Task CallHandler(CloudEvent item);
Task<RpcResponse> HandleRequest(RpcRequest request);
void ReceiveMessage(Message message);
Task Store(AgentState state);
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
ValueTask PublishEvent(CloudEvent item);
}
public interface IAgentBase
{
// Properties
AgentId AgentId { get; }
IAgentContext Context { get; }
// Methods
Task CallHandler(CloudEvent item);
Task<RpcResponse> HandleRequest(RpcRequest request);
void ReceiveMessage(Message message);
Task Store(AgentState state);
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
ValueTask PublishEvent(CloudEvent item);
}

View File

@ -1,33 +1,32 @@
using Microsoft.Extensions.AI;
namespace Microsoft.Extensions.Hosting
{
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
namespace Microsoft.Extensions.Hosting;
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
}
}

View File

@ -14,206 +14,205 @@ using FluentAssertions;
using OpenAI;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.Tests
namespace AutoGen.OpenAI.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
}
throw;
_output.WriteLine(message.FormatMessage());
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
using FluentAssertions;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.V1.Tests
namespace AutoGen.OpenAI.V1.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
if (message is IMessage<object> envelope)
{
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
throw;
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -3,85 +3,84 @@
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
using Xunit; // Needed for Fact and Assert
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionCallTemplateEncodingTests
{
public class FunctionCallTemplateEncodingTests
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
{
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
var template = new FunctionCallTemplate
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
}

View File

@ -10,122 +10,121 @@ using FluentAssertions;
using OpenAI.Chat;
using Xunit;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionExample
{
public class FunctionExample
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
{
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
WriteIndented = true,
};
[Fact]
public void Add_Test()
{
var args = new
{
WriteIndented = true,
a = 1,
b = 2,
};
[Fact]
public void Add_Test()
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
{
var args = new
{
var args = new
args = new double[] { 1, 2, 3 },
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
[Fact]
public async Task DictionaryToString_Test()
{
var args = new
{
xargs = new Dictionary<string, string>
{
a = 1,
b = 2,
};
{ "a", "1" },
{ "b", "2" },
},
};
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
var args = new
{
args = new double[] { 1, 2, 3 },
};
a = 1,
b = 2,
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task DictionaryToString_Test()
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
var args = new
{
xargs = new Dictionary<string, string>
{
{ "a", "1" },
{ "b", "2" },
},
};
a = 1,
b = 2,
};
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
[Fact]
public void Query_Test()
{
var args = new
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
a = 1,
b = 2,
};
query = "hello",
k = 3,
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
a = 1,
b = 2,
};
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
[Fact]
public void Query_Test()
{
var args = new
{
query = "hello",
k = 3,
};
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
}

View File

@ -4,67 +4,66 @@
using System.Text.Json;
using AutoGen.Core;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public partial class FunctionExamples
{
public partial class FunctionExamples
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
return a + b;
}
return a + b;
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
WriteIndented = true,
});
WriteIndented = true,
});
return Task.FromResult(res);
}
return Task.FromResult(res);
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
}

View File

@ -7,73 +7,72 @@ using System.Threading.Tasks;
using AutoGen.BasicSample;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public class BasicSampleTest
{
public class BasicSampleTest
private readonly ITestOutputHelper _output;
public BasicSampleTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
Console.SetOut(new ConsoleWriter(_output));
}
public BasicSampleTest(ITestOutputHelper output)
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
_output = output;
Console.SetOut(new ConsoleWriter(_output));
this.output = output;
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
public override void WriteLine(string? m)
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
this.output = output;
}
public override void WriteLine(string? m)
{
output.WriteLine(m);
}
output.WriteLine(m);
}
}
}

View File

@ -3,18 +3,17 @@
using Xunit;
namespace AutoGen.Tests
{
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
namespace AutoGen.Tests;
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
}

View File

@ -9,219 +9,218 @@ using FluentAssertions;
using Xunit;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public partial class SingleAgentTest
{
public partial class SingleAgentTest
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
_output = output;
}
_output = output;
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
var llmConfig = new ConversableAgentConfig
{
Temperature = 0,
FunctionContracts = new[]
{
Temperature = 0,
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
config,
},
};
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
{
return $"[ECHO] {message}";
}
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
{
return $"[ECHO] {message}";
}
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
finalReply = reply;
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
else
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
reply.From.Should().Be(agent.Name);
finalReply = reply;
}
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
else
{
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
}
}