mirror of https://github.com/microsoft/autogen.git
Compare commits
4 Commits
e11af571b3
...
456e0525e5
Author | SHA1 | Date |
---|---|---|
Eric Zhu | 456e0525e5 | |
Xiaoyun Zhang | e63fd17ed5 | |
Eric Zhu | 3ca80a8d82 | |
Ryan Sweet | 51cd5b8d1f |
|
@ -221,13 +221,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
|
|||
|
||||
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
|
||||
csharp_prefer_simple_using_statement = true:suggestion
|
||||
csharp_style_namespace_declarations = block_scoped:silent
|
||||
csharp_style_namespace_declarations = file_scoped:warning
|
||||
csharp_style_prefer_method_group_conversion = true:silent
|
||||
csharp_style_prefer_top_level_statements = true:silent
|
||||
csharp_style_prefer_primary_constructors = true:suggestion
|
||||
csharp_style_expression_bodied_lambdas = true:silent
|
||||
csharp_style_prefer_local_over_anonymous_function = true:suggestion
|
||||
dotnet_diagnostic.CA2016.severity = suggestion
|
||||
csharp_prefer_static_anonymous_function = true:suggestion
|
||||
|
||||
# disable check for generated code
|
||||
[*.generated.cs]
|
||||
|
@ -697,6 +698,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
|
|||
dotnet_style_prefer_simplified_interpolation = true:suggestion
|
||||
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
|
||||
dotnet_style_namespace_match_folder = true:suggestion
|
||||
dotnet_style_qualification_for_method = false:silent
|
||||
|
||||
[**/*.g.cs]
|
||||
generated_code = true
|
||||
|
|
|
@ -125,7 +125,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelClientHostingExtensi
|
|||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
|
|
|
@ -18,10 +18,11 @@ namespace Hello
|
|||
[TopicSubscription("HelloAgents")]
|
||||
public class HelloAgent(
|
||||
IAgentContext context,
|
||||
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
|
||||
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
|
||||
context,
|
||||
typeRegistry),
|
||||
ISayHello,
|
||||
IHandleConsole,
|
||||
IHandle<NewMessageReceived>,
|
||||
IHandle<ConversationClosed>
|
||||
{
|
||||
|
|
|
@ -5,45 +5,44 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace AutoGen
|
||||
namespace AutoGen;
|
||||
|
||||
public static class LLMConfigAPI
|
||||
{
|
||||
public static class LLMConfigAPI
|
||||
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
|
||||
string apiKey,
|
||||
IEnumerable<string>? modelIDs = null)
|
||||
{
|
||||
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
|
||||
string apiKey,
|
||||
IEnumerable<string>? modelIDs = null)
|
||||
var models = modelIDs ?? new[]
|
||||
{
|
||||
var models = modelIDs ?? new[]
|
||||
{
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
};
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
};
|
||||
|
||||
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
|
||||
}
|
||||
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
|
||||
}
|
||||
|
||||
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
|
||||
string endpoint,
|
||||
string apiKey,
|
||||
IEnumerable<string> deploymentNames)
|
||||
{
|
||||
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
|
||||
}
|
||||
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
|
||||
string endpoint,
|
||||
string apiKey,
|
||||
IEnumerable<string> deploymentNames)
|
||||
{
|
||||
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get a list of LLMConfig objects from a JSON file.
|
||||
/// </summary>
|
||||
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
|
||||
string filePath,
|
||||
IEnumerable<string>? filterModels = null)
|
||||
{
|
||||
// Disable this API from documentation for now.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
/// <summary>
|
||||
/// Get a list of LLMConfig objects from a JSON file.
|
||||
/// </summary>
|
||||
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
|
||||
string filePath,
|
||||
IEnumerable<string>? filterModels = null)
|
||||
{
|
||||
// Disable this API from documentation for now.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using System.Diagnostics;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Threading.Channels;
|
||||
|
@ -17,6 +18,8 @@ public abstract class AgentBase : IAgentBase
|
|||
|
||||
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
|
||||
private readonly IAgentContext _context;
|
||||
public string Route { get; set; } = "base";
|
||||
|
||||
protected internal ILogger Logger => _context.Logger;
|
||||
public IAgentContext Context => _context;
|
||||
protected readonly EventTypes EventTypes;
|
||||
|
@ -212,14 +215,39 @@ public abstract class AgentBase : IAgentBase
|
|||
public Task CallHandler(CloudEvent item)
|
||||
{
|
||||
// Only send the event to the handler if the agent type is handling that type
|
||||
if (EventTypes.EventsMap[GetType()].Contains(item.Type))
|
||||
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
|
||||
foreach (var key in EventTypes.EventsMap.Keys)
|
||||
{
|
||||
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
|
||||
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
|
||||
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
|
||||
var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
|
||||
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
|
||||
if (EventTypes.EventsMap[key].Contains(item.Type))
|
||||
{
|
||||
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
|
||||
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
|
||||
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
|
||||
|
||||
MethodInfo methodInfo;
|
||||
try
|
||||
{
|
||||
// check that our target actually implements this interface, otherwise call the default static
|
||||
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
|
||||
{
|
||||
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
|
||||
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
|
||||
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The error here is we have registered for an event that we do not have code to listen to
|
||||
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
|
||||
throw; // TODO: ?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public interface IHandleConsole : IHandle<Output>, IHandle<Input>
|
||||
{
|
||||
string Route { get; }
|
||||
AgentId AgentId { get; }
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
|
||||
async Task IHandle<Output>.Handle(Output item)
|
||||
{
|
||||
// Assuming item has a property `Message` that we want to write to the console
|
||||
Console.WriteLine(item.Message);
|
||||
await ProcessOutput(item.Message);
|
||||
|
||||
var evt = new OutputWritten
|
||||
{
|
||||
Route = "console"
|
||||
}.ToCloudEvent(AgentId.Key);
|
||||
await PublishEvent(evt);
|
||||
}
|
||||
async Task IHandle<Input>.Handle(Input item)
|
||||
{
|
||||
Console.WriteLine("Please enter input:");
|
||||
string content = Console.ReadLine() ?? string.Empty;
|
||||
|
||||
await ProcessInput(content);
|
||||
|
||||
var evt = new InputProcessed
|
||||
{
|
||||
Route = "console"
|
||||
}.ToCloudEvent(AgentId.Key);
|
||||
await PublishEvent(evt);
|
||||
}
|
||||
static Task ProcessOutput(string message)
|
||||
{
|
||||
// Implement your output processing logic here
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
static Task<string> ProcessInput(string message)
|
||||
{
|
||||
// Implement your input processing logic here
|
||||
return Task.FromResult(message);
|
||||
}
|
||||
}
|
|
@ -83,6 +83,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
|
|||
message.Response.RequestId = request.OriginalRequestId;
|
||||
request.Agent.ReceiveMessage(message);
|
||||
break;
|
||||
|
||||
case Message.MessageOneofCase.RegisterAgentTypeResponse:
|
||||
if (!message.RegisterAgentTypeResponse.Success)
|
||||
{
|
||||
|
|
|
@ -71,7 +71,52 @@ public static class HostBuilderExtensions
|
|||
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
|
||||
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
|
||||
.ToDictionary(item => item.t, item => item.Item2);
|
||||
// if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
|
||||
var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
|
||||
.SelectMany(assembly => assembly.GetTypes())
|
||||
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
|
||||
.Select(t => (t, t.GetMethods()
|
||||
.Where(m => m.Name == "Handle")
|
||||
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
|
||||
.ToDictionary(item => item.t, item => item.Item2);
|
||||
// get interfaces implemented by the agent and get the methods of the interface if they are named Handle
|
||||
var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
|
||||
.SelectMany(assembly => assembly.GetTypes())
|
||||
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
|
||||
.Select(t => t.GetInterfaces()
|
||||
.Select(i => (t, i, i.GetMethods()
|
||||
.Where(m => m.Name == "Handle")
|
||||
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
|
||||
//to dictionary of type t and paramter type of the method
|
||||
.ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
|
||||
// for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
|
||||
foreach (var item in ifaceHandlersMap)
|
||||
{
|
||||
foreach (var iface in item)
|
||||
{
|
||||
if (eventsMap.TryGetValue(iface.Item2, out var events))
|
||||
{
|
||||
events.UnionWith(iface.Item3);
|
||||
}
|
||||
else
|
||||
{
|
||||
eventsMap[iface.Item2] = iface.Item3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge the handlersMap into the eventsMap
|
||||
foreach (var item in handlersMap)
|
||||
{
|
||||
if (eventsMap.TryGetValue(item.Key, out var events))
|
||||
{
|
||||
events.UnionWith(item.Value);
|
||||
}
|
||||
else
|
||||
{
|
||||
eventsMap[item.Key] = item.Value;
|
||||
}
|
||||
}
|
||||
return new EventTypes(typeRegistry, types, eventsMap);
|
||||
});
|
||||
return new AgentApplicationBuilder(builder);
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents
|
||||
{
|
||||
public interface IAgentBase
|
||||
{
|
||||
// Properties
|
||||
AgentId AgentId { get; }
|
||||
IAgentContext Context { get; }
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
// Methods
|
||||
Task CallHandler(CloudEvent item);
|
||||
Task<RpcResponse> HandleRequest(RpcRequest request);
|
||||
void ReceiveMessage(Message message);
|
||||
Task Store(AgentState state);
|
||||
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
}
|
||||
public interface IAgentBase
|
||||
{
|
||||
// Properties
|
||||
AgentId AgentId { get; }
|
||||
IAgentContext Context { get; }
|
||||
|
||||
// Methods
|
||||
Task CallHandler(CloudEvent item);
|
||||
Task<RpcResponse> HandleRequest(RpcRequest request);
|
||||
void ReceiveMessage(Message message);
|
||||
Task Store(AgentState state);
|
||||
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
}
|
||||
|
|
|
@ -1,33 +1,32 @@
|
|||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Extensions.Hosting
|
||||
{
|
||||
public static class AIModelClient
|
||||
{
|
||||
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
|
||||
{
|
||||
var pipeline = (ChatClientBuilder pipeline) => pipeline
|
||||
.UseLogging()
|
||||
.UseFunctionInvocation()
|
||||
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
|
||||
namespace Microsoft.Extensions.Hosting;
|
||||
|
||||
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
|
||||
{
|
||||
builder.AddOllamaChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
|
||||
{
|
||||
builder.AddOpenAIChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
|
||||
{
|
||||
builder.AddAzureChatClient(serviceName, pipeline);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
|
||||
}
|
||||
return builder;
|
||||
public static class AIModelClient
|
||||
{
|
||||
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
|
||||
{
|
||||
var pipeline = (ChatClientBuilder pipeline) => pipeline
|
||||
.UseLogging()
|
||||
.UseFunctionInvocation()
|
||||
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
|
||||
|
||||
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
|
||||
{
|
||||
builder.AddOllamaChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
|
||||
{
|
||||
builder.AddOpenAIChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
|
||||
{
|
||||
builder.AddAzureChatClient(serviceName, pipeline);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,206 +14,205 @@ using FluentAssertions;
|
|||
using OpenAI;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.OpenAI.Tests
|
||||
namespace AutoGen.OpenAI.Tests;
|
||||
|
||||
public partial class MathClassTest
|
||||
{
|
||||
public partial class MathClassTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
}
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
_output = output;
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
catch (Exception)
|
||||
{
|
||||
try
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
_output.WriteLine(message.FormatMessage());
|
||||
}
|
||||
|
||||
throw;
|
||||
_output.WriteLine(message.FormatMessage());
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
Question {question_index}:
|
||||
{question}
|
||||
|
||||
Student, please answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
The answer is {answer}
|
||||
teacher please check answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
{message}
|
||||
please update progress";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
{GroupChatExtension.TERMINATE}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
the number of resolved question is {correctAnswerCount}
|
||||
teacher, please create the next math question";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
|
||||
var teacher = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
var teacher = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
You create math question and ask student to answer it.
|
||||
Then you check if the answer is correct.
|
||||
If the answer is wrong, you ask student to fix it")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return teacher;
|
||||
}
|
||||
return teacher;
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Student",
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Student",
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return student;
|
||||
}
|
||||
return student;
|
||||
}
|
||||
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
|
|||
using FluentAssertions;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.OpenAI.V1.Tests
|
||||
namespace AutoGen.OpenAI.V1.Tests;
|
||||
|
||||
public partial class MathClassTest
|
||||
{
|
||||
public partial class MathClassTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
}
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
_output = output;
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
catch (Exception)
|
||||
{
|
||||
try
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
if (message is IMessage<object> envelope)
|
||||
{
|
||||
if (message is IMessage<object> envelope)
|
||||
{
|
||||
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
|
||||
_output.WriteLine(json);
|
||||
}
|
||||
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
|
||||
_output.WriteLine(json);
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
Question {question_index}:
|
||||
{question}
|
||||
|
||||
Student, please answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
The answer is {answer}
|
||||
teacher please check answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
{message}
|
||||
please update progress";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
{GroupChatExtension.TERMINATE}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
the number of resolved question is {correctAnswerCount}
|
||||
teacher, please create the next math question";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
|
||||
var teacher = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
var teacher = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
You create math question and ask student to answer it.
|
||||
Then you check if the answer is correct.
|
||||
If the answer is wrong, you ask student to fix it",
|
||||
modelName: model)
|
||||
.RegisterMiddleware(Print)
|
||||
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
modelName: model)
|
||||
.RegisterMiddleware(Print)
|
||||
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
|
||||
return teacher;
|
||||
}
|
||||
return teacher;
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Student",
|
||||
modelName: model,
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Student",
|
||||
modelName: model,
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return student;
|
||||
}
|
||||
return student;
|
||||
}
|
||||
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,85 +3,84 @@
|
|||
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
|
||||
using Xunit; // Needed for Fact and Assert
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public class FunctionCallTemplateEncodingTests
|
||||
{
|
||||
public class FunctionCallTemplateEncodingTests
|
||||
[Fact]
|
||||
public void FunctionDescription_Should_Encode_DoubleQuotes()
|
||||
{
|
||||
[Fact]
|
||||
public void FunctionDescription_Should_Encode_DoubleQuotes()
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
new SourceGeneratorFunctionContract
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
Name = "TestFunction",
|
||||
Description = "This is a \"test\" function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
{
|
||||
Name = "TestFunction",
|
||||
Description = "This is a \"test\" function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParameterDescription_Should_Encode_DoubleQuotes()
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParameterDescription_Should_Encode_DoubleQuotes()
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
Name = "TestFunction",
|
||||
Description = "This is a test function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
{
|
||||
Name = "TestFunction",
|
||||
Description = "This is a test function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,122 +10,121 @@ using FluentAssertions;
|
|||
using OpenAI.Chat;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public class FunctionExample
|
||||
{
|
||||
public class FunctionExample
|
||||
private readonly FunctionExamples functionExamples = new FunctionExamples();
|
||||
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
|
||||
{
|
||||
private readonly FunctionExamples functionExamples = new FunctionExamples();
|
||||
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
|
||||
WriteIndented = true,
|
||||
};
|
||||
|
||||
[Fact]
|
||||
public void Add_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
WriteIndented = true,
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
[Fact]
|
||||
public void Add_Test()
|
||||
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
|
||||
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Sum_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
args = new double[] { 1, 2, 3 },
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
|
||||
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DictionaryToString_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
xargs = new Dictionary<string, string>
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
{ "a", "1" },
|
||||
{ "b", "2" },
|
||||
},
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
|
||||
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
|
||||
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Sum_Test()
|
||||
[Fact]
|
||||
public async Task TopLevelFunctionExampleAddTestAsync()
|
||||
{
|
||||
var example = new TopLevelStatementFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
args = new double[] { 1, 2, 3 },
|
||||
};
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
|
||||
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DictionaryToString_Test()
|
||||
[Fact]
|
||||
public async Task FilescopeFunctionExampleAddTestAsync()
|
||||
{
|
||||
var example = new FilescopeNamespaceFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
xargs = new Dictionary<string, string>
|
||||
{
|
||||
{ "a", "1" },
|
||||
{ "b", "2" },
|
||||
},
|
||||
};
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
|
||||
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TopLevelFunctionExampleAddTestAsync()
|
||||
[Fact]
|
||||
public void Query_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
var example = new TopLevelStatementFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
query = "hello",
|
||||
k = 3,
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
|
||||
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task FilescopeFunctionExampleAddTestAsync()
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
private void VerifyFunctionDefinition(ChatTool function)
|
||||
{
|
||||
var func = new
|
||||
{
|
||||
var example = new FilescopeNamespaceFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
name = function.FunctionName,
|
||||
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
|
||||
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Query_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
query = "hello",
|
||||
k = 3,
|
||||
};
|
||||
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
|
||||
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
|
||||
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
private void VerifyFunctionDefinition(ChatTool function)
|
||||
{
|
||||
var func = new
|
||||
{
|
||||
name = function.FunctionName,
|
||||
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
|
||||
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
|
||||
};
|
||||
|
||||
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
|
||||
}
|
||||
|
||||
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
|
||||
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = await func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = await func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,67 +4,66 @@
|
|||
using System.Text.Json;
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public partial class FunctionExamples
|
||||
{
|
||||
public partial class FunctionExamples
|
||||
/// <summary>
|
||||
/// Add function
|
||||
/// </summary>
|
||||
/// <param name="a">a</param>
|
||||
/// <param name="b">b</param>
|
||||
[FunctionAttribute]
|
||||
public int Add(int a, int b)
|
||||
{
|
||||
/// <summary>
|
||||
/// Add function
|
||||
/// </summary>
|
||||
/// <param name="a">a</param>
|
||||
/// <param name="b">b</param>
|
||||
[FunctionAttribute]
|
||||
public int Add(int a, int b)
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
return a + b;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add two numbers.
|
||||
/// </summary>
|
||||
/// <param name="a">The first number.</param>
|
||||
/// <param name="b">The second number.</param>
|
||||
[Function]
|
||||
public Task<string> AddAsync(int a, int b)
|
||||
{
|
||||
return Task.FromResult($"{a} + {b} = {a + b}");
|
||||
}
|
||||
/// <summary>
|
||||
/// Add two numbers.
|
||||
/// </summary>
|
||||
/// <param name="a">The first number.</param>
|
||||
/// <param name="b">The second number.</param>
|
||||
[Function]
|
||||
public Task<string> AddAsync(int a, int b)
|
||||
{
|
||||
return Task.FromResult($"{a} + {b} = {a + b}");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sum function
|
||||
/// </summary>
|
||||
/// <param name="args">an array of double values</param>
|
||||
[FunctionAttribute]
|
||||
public double Sum(double[] args)
|
||||
{
|
||||
return args.Sum();
|
||||
}
|
||||
/// <summary>
|
||||
/// Sum function
|
||||
/// </summary>
|
||||
/// <param name="args">an array of double values</param>
|
||||
[FunctionAttribute]
|
||||
public double Sum(double[] args)
|
||||
{
|
||||
return args.Sum();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DictionaryToString function
|
||||
/// </summary>
|
||||
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
|
||||
[FunctionAttribute]
|
||||
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
|
||||
/// <summary>
|
||||
/// DictionaryToString function
|
||||
/// </summary>
|
||||
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
|
||||
[FunctionAttribute]
|
||||
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
|
||||
{
|
||||
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
|
||||
{
|
||||
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
|
||||
{
|
||||
WriteIndented = true,
|
||||
});
|
||||
WriteIndented = true,
|
||||
});
|
||||
|
||||
return Task.FromResult(res);
|
||||
}
|
||||
return Task.FromResult(res);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// query function
|
||||
/// </summary>
|
||||
/// <param name="query">query, required</param>
|
||||
/// <param name="k">top k, optional, default value is 3</param>
|
||||
/// <param name="thresold">thresold, optional, default value is 0.5</param>
|
||||
[FunctionAttribute]
|
||||
public string[] Query(string query, int k = 3, float thresold = 0.5f)
|
||||
{
|
||||
return Enumerable.Repeat(query, k).ToArray();
|
||||
}
|
||||
/// <summary>
|
||||
/// query function
|
||||
/// </summary>
|
||||
/// <param name="query">query, required</param>
|
||||
/// <param name="k">top k, optional, default value is 3</param>
|
||||
/// <param name="thresold">thresold, optional, default value is 0.5</param>
|
||||
[FunctionAttribute]
|
||||
public string[] Query(string query, int k = 3, float thresold = 0.5f)
|
||||
{
|
||||
return Enumerable.Repeat(query, k).ToArray();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,73 +7,72 @@ using System.Threading.Tasks;
|
|||
using AutoGen.BasicSample;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
public class BasicSampleTest
|
||||
{
|
||||
public class BasicSampleTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
public BasicSampleTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
Console.SetOut(new ConsoleWriter(_output));
|
||||
}
|
||||
|
||||
public BasicSampleTest(ITestOutputHelper output)
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentTestAsync()
|
||||
{
|
||||
await Example01_AssistantAgent.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task TwoAgentMathClassTestAsync()
|
||||
{
|
||||
await Example02_TwoAgent_MathChat.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task AgentFunctionCallTestAsync()
|
||||
{
|
||||
await Example03_Agent_FunctionCall.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("MISTRAL_API_KEY")]
|
||||
public async Task MistralClientAgent_TokenCount()
|
||||
{
|
||||
await Example14_MistralClientAgent_TokenCount.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task DynamicGroupChatCalculateFibonacciAsync()
|
||||
{
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task DalleAndGPT4VTestAsync()
|
||||
{
|
||||
await Example05_Dalle_And_GPT4V.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task GPT4ImageMessage()
|
||||
{
|
||||
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
|
||||
}
|
||||
|
||||
public class ConsoleWriter : StringWriter
|
||||
{
|
||||
private ITestOutputHelper output;
|
||||
public ConsoleWriter(ITestOutputHelper output)
|
||||
{
|
||||
_output = output;
|
||||
Console.SetOut(new ConsoleWriter(_output));
|
||||
this.output = output;
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentTestAsync()
|
||||
public override void WriteLine(string? m)
|
||||
{
|
||||
await Example01_AssistantAgent.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task TwoAgentMathClassTestAsync()
|
||||
{
|
||||
await Example02_TwoAgent_MathChat.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task AgentFunctionCallTestAsync()
|
||||
{
|
||||
await Example03_Agent_FunctionCall.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("MISTRAL_API_KEY")]
|
||||
public async Task MistralClientAgent_TokenCount()
|
||||
{
|
||||
await Example14_MistralClientAgent_TokenCount.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task DynamicGroupChatCalculateFibonacciAsync()
|
||||
{
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task DalleAndGPT4VTestAsync()
|
||||
{
|
||||
await Example05_Dalle_And_GPT4V.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task GPT4ImageMessage()
|
||||
{
|
||||
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
|
||||
}
|
||||
|
||||
public class ConsoleWriter : StringWriter
|
||||
{
|
||||
private ITestOutputHelper output;
|
||||
public ConsoleWriter(ITestOutputHelper output)
|
||||
{
|
||||
this.output = output;
|
||||
}
|
||||
|
||||
public override void WriteLine(string? m)
|
||||
{
|
||||
output.WriteLine(m);
|
||||
}
|
||||
output.WriteLine(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,18 +3,17 @@
|
|||
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
{
|
||||
public class GraphTests
|
||||
{
|
||||
[Fact]
|
||||
public void GraphTest()
|
||||
{
|
||||
var graph1 = new Graph();
|
||||
Assert.NotNull(graph1);
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
var graph2 = new Graph(null);
|
||||
Assert.NotNull(graph2);
|
||||
}
|
||||
public class GraphTests
|
||||
{
|
||||
[Fact]
|
||||
public void GraphTest()
|
||||
{
|
||||
var graph1 = new Graph();
|
||||
Assert.NotNull(graph1);
|
||||
|
||||
var graph2 = new Graph(null);
|
||||
Assert.NotNull(graph2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,219 +9,218 @@ using FluentAssertions;
|
|||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
public partial class SingleAgentTest
|
||||
{
|
||||
public partial class SingleAgentTest
|
||||
private ITestOutputHelper _output;
|
||||
public SingleAgentTest(ITestOutputHelper output)
|
||||
{
|
||||
private ITestOutputHelper _output;
|
||||
public SingleAgentTest(ITestOutputHelper output)
|
||||
{
|
||||
_output = output;
|
||||
}
|
||||
_output = output;
|
||||
}
|
||||
|
||||
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
return new AzureOpenAIConfig(endpoint, deployName, key);
|
||||
}
|
||||
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
return new AzureOpenAIConfig(endpoint, deployName, key);
|
||||
}
|
||||
|
||||
private ILLMConfig CreateOpenAIGPT4VisionConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
|
||||
return new OpenAIConfig(key, "gpt-4-vision-preview");
|
||||
}
|
||||
private ILLMConfig CreateOpenAIGPT4VisionConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
|
||||
return new OpenAIConfig(key, "gpt-4-vision-preview");
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
{
|
||||
Temperature = 0,
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
Temperature = 0,
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig);
|
||||
|
||||
await EchoFunctionCallTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AssistantAgentDefaultReplyTestAsync()
|
||||
{
|
||||
var assistantAgent = new AssistantAgent(
|
||||
llmConfig: null,
|
||||
name: "assistant",
|
||||
defaultReply: "hello world");
|
||||
|
||||
var reply = await assistantAgent.SendAsync("hi");
|
||||
|
||||
reply.GetContent().Should().Be("hello world");
|
||||
reply.GetRole().Should().Be(Role.Assistant);
|
||||
reply.From.Should().Be(assistantAgent.Name);
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig,
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ nameof(EchoAsync), this.EchoAsyncWrapper },
|
||||
});
|
||||
config,
|
||||
},
|
||||
};
|
||||
|
||||
await EchoFunctionCallExecutionTestAsync(assistantAgent);
|
||||
}
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig);
|
||||
|
||||
/// <summary>
|
||||
/// echo when asked.
|
||||
/// </summary>
|
||||
/// <param name="message">message to echo</param>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> EchoAsync(string message)
|
||||
await EchoFunctionCallTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AssistantAgentDefaultReplyTestAsync()
|
||||
{
|
||||
var assistantAgent = new AssistantAgent(
|
||||
llmConfig: null,
|
||||
name: "assistant",
|
||||
defaultReply: "hello world");
|
||||
|
||||
var reply = await assistantAgent.SendAsync("hi");
|
||||
|
||||
reply.GetContent().Should().Be("hello world");
|
||||
reply.GetRole().Should().Be(Role.Assistant);
|
||||
reply.From.Should().Be(assistantAgent.Name);
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
{
|
||||
return $"[ECHO] {message}";
|
||||
}
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig,
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ nameof(EchoAsync), this.EchoAsyncWrapper },
|
||||
});
|
||||
|
||||
/// <summary>
|
||||
/// return the label name with hightest inference cost
|
||||
/// </summary>
|
||||
/// <param name="labelName"></param>
|
||||
/// <returns></returns>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> GetHighestLabel(string labelName, string color)
|
||||
await EchoFunctionCallExecutionTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// echo when asked.
|
||||
/// </summary>
|
||||
/// <param name="message">message to echo</param>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> EchoAsync(string message)
|
||||
{
|
||||
return $"[ECHO] {message}";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// return the label name with hightest inference cost
|
||||
/// </summary>
|
||||
/// <param name="labelName"></param>
|
||||
/// <returns></returns>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> GetHighestLabel(string labelName, string color)
|
||||
{
|
||||
return $"[HIGHEST_LABEL] {labelName} {color}";
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallTestAsync(IAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.GetContent().Should().Be("[ECHO] Hello world");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
return $"[HIGHEST_LABEL] {labelName} {color}";
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallTestAsync(IAgent agent)
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
finalReply = reply;
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
|
||||
if (finalReply is ToolCallAggregateMessage aggregateMessage)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.GetContent().Should().Be("[ECHO] Hello world");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||
var toolCallResultMessage = aggregateMessage.Message2;
|
||||
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
|
||||
toolCallResultMessage.From.Should().Be(agent.Name);
|
||||
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
|
||||
else
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
reply.From.Should().Be(agent.Name);
|
||||
finalReply = reply;
|
||||
}
|
||||
|
||||
if (finalReply is ToolCallAggregateMessage aggregateMessage)
|
||||
{
|
||||
var toolCallResultMessage = aggregateMessage.Message2;
|
||||
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
|
||||
toolCallResultMessage.From.Should().Be(agent.Name);
|
||||
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpperCaseTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message });
|
||||
|
||||
reply.GetContent().Should().Contain("ABCDE");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
}
|
||||
|
||||
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
|
||||
var answer = "HELLO WORLD";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
if (reply is TextMessageUpdate update)
|
||||
{
|
||||
update.From.Should().Be(agent.Name);
|
||||
|
||||
if (finalReply is null)
|
||||
{
|
||||
finalReply = new TextMessage(update);
|
||||
}
|
||||
else
|
||||
{
|
||||
finalReply.Update(update);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
else if (reply is TextMessage textMessage)
|
||||
{
|
||||
finalReply = textMessage;
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
|
||||
finalReply!.Content.Should().Contain(answer);
|
||||
finalReply!.Role.Should().Be(Role.Assistant);
|
||||
finalReply!.From.Should().Be(agent.Name);
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpperCaseTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message });
|
||||
|
||||
reply.GetContent().Should().Contain("ABCDE");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
}
|
||||
|
||||
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
|
||||
var answer = "HELLO WORLD";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
if (reply is TextMessageUpdate update)
|
||||
{
|
||||
update.From.Should().Be(agent.Name);
|
||||
|
||||
if (finalReply is null)
|
||||
{
|
||||
finalReply = new TextMessage(update);
|
||||
}
|
||||
else
|
||||
{
|
||||
finalReply.Update(update);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
else if (reply is TextMessage textMessage)
|
||||
{
|
||||
finalReply = textMessage;
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
|
||||
finalReply!.Content.Should().Contain(answer);
|
||||
finalReply!.Role.Should().Be(Role.Assistant);
|
||||
finalReply!.From.Should().Be(agent.Name);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue