Compare commits

...

16 Commits

Author SHA1 Message Date
Mohammad Mazraeh 9028bb6d18
remove extra files
Signed-off-by: Mohammad Mazraeh <mazraeh.mohammad@gmail.com>
2024-10-31 03:12:22 -07:00
Mohammad Mazraeh a6f0c7cc8a python/packages/autogen-core/samples/distributed-group-chat/public/avatars/group_chat_manager.png: convert to Git LFS 2024-10-31 03:09:14 -07:00
Mohammad Mazraeh 2b20a74ef0 python/packages/autogen-core/samples/distributed-group-chat/public/avatars/editor.png: convert to Git LFS 2024-10-31 03:09:06 -07:00
Mohammad Mazraeh ba0be1f78f python/packages/autogen-core/samples/distributed-group-chat/public/avatars/writer.png: convert to Git LFS 2024-10-31 03:09:00 -07:00
Mohammad Mazraeh ccf94bfdb5 python/packages/autogen-core/samples/distributed-group-chat/public/logo.png: convert to Git LFS 2024-10-31 03:08:52 -07:00
Mohammad Mazraeh 07603e1054 python/packages/autogen-core/samples/distributed-group-chat/public/favicon.png: convert to Git LFS 2024-10-31 03:08:33 -07:00
Mohammad Mazraeh a6bf6eff49 python/packages/autogen-core/samples/distributed-group-chat/public/avatars/user.png: convert to Git LFS 2024-10-31 03:07:57 -07:00
Mohammad Mazraeh d98e8b2fdf
fix pyright issue
Signed-off-by: Mohammad Mazraeh <mazraeh.mohammad@gmail.com>
2024-10-31 02:49:57 -07:00
Mohammad Mazraeh e83217ca1a
add video and some cleanup
Signed-off-by: Mohammad Mazraeh <mazraeh.mohammad@gmail.com>
2024-10-31 02:31:47 -07:00
Mohammad Mazraeh 9f0e747e93
Merge branch 'main' into add-chainlit-to-distributed-group-chat 2024-10-31 08:08:01 +00:00
Rohan Thacker 3c63f6f3ef
Corrected typo in get_capabilities in _model_info.py (#4002) 2024-10-30 13:39:45 -07:00
Xiaoyun Zhang 6bea055b26
[.Net] Add a generic `IHandle` interface so AgentRuntime doesn't need to deal with typed handler (#3985)
* add IHandle for object type

* rename handle -> handleObject

* remove duplicate file header setting

* update

* remove AgentId

* fix format
2024-10-30 11:53:37 -07:00
Eric Zhu 3d51ab76ae
Formalize `ChatAgent` response as a dataclass with inner messages (#3990) 2024-10-30 10:27:57 -07:00
Xiaoyun Zhang e63fd17ed5
[.Net] use file-scope (#3997)
* use file-scope

* reformat
2024-10-30 10:05:58 -07:00
Ryan Sweet 51cd5b8d1f
interface inheritance examples (#3989)
changes to AgentBase and HostBuilderExtensions to enable leveraging handlers from composition (interfaces) vs inheritance... see HelloAgents sample for usage

closes #3928
is related to #3925
2024-10-30 09:51:01 -07:00
Eric Zhu 4a49844996
`ChatAgent` declares the types of messages it produces (#3991)
* `ChatAgent` declares the types of messages it produces
2024-10-30 05:32:11 -07:00
130 changed files with 1596 additions and 1084 deletions

View File

@ -193,10 +193,6 @@ csharp_using_directive_placement = outside_namespace:error
csharp_prefer_static_local_function = true:warning
csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning
# Header template
file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
dotnet_diagnostic.IDE0073.severity = error
# enable format error
dotnet_diagnostic.IDE0055.severity = error
@ -221,13 +217,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
csharp_prefer_simple_using_statement = true:suggestion
csharp_style_namespace_declarations = block_scoped:silent
csharp_style_namespace_declarations = file_scoped:warning
csharp_style_prefer_method_group_conversion = true:silent
csharp_style_prefer_top_level_statements = true:silent
csharp_style_prefer_primary_constructors = true:suggestion
csharp_style_expression_bodied_lambdas = true:silent
csharp_style_prefer_local_over_anonymous_function = true:suggestion
dotnet_diagnostic.CA2016.severity = suggestion
csharp_prefer_static_anonymous_function = true:suggestion
# disable check for generated code
[*.generated.cs]
@ -556,8 +553,8 @@ dotnet_diagnostic.IDE0060.severity = warning
dotnet_diagnostic.IDE0062.severity = warning
# IDE0073: File header
dotnet_diagnostic.IDE0073.severity = suggestion
file_header_template = Copyright (c) Microsoft. All rights reserved.
dotnet_diagnostic.IDE0073.severity = warning
file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
# IDE1006: Required naming style
dotnet_diagnostic.IDE1006.severity = warning
@ -697,6 +694,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
dotnet_style_prefer_simplified_interpolation = true:suggestion
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
dotnet_style_namespace_match_folder = true:suggestion
dotnet_style_qualification_for_method = false:silent
[**/*.g.cs]
generated_code = true

View File

@ -123,9 +123,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgent", "samples\Hello
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelClientHostingExtensions", "src\Microsoft.AutoGen\Extensions\AIModelClientHostingExtensions\AIModelClientHostingExtensions.csproj", "{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents.Tests", "Microsoft.AutoGen.Agents.Tests\Microsoft.AutoGen.Agents.Tests.csproj", "{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@ -337,6 +339,10 @@ Global
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Debug|Any CPU.Build.0 = Debug|Any CPU
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Release|Any CPU.ActiveCfg = Release|Any CPU
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Release|Any CPU.Build.0 = Release|Any CPU
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Release|Any CPU.Build.0 = Release|Any CPU
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -401,6 +407,7 @@ Global
{A20B9894-F352-4338-872A-F215A241D43D} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
{8F7560CF-EEBB-4333-A69F-838CA40FD85D} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{64EF61E7-00A6-4E5E-9808-62E10993A0E5} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution

View File

@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentBaseTests.cs
using FluentAssertions;
using Google.Protobuf.Reflection;
using Microsoft.AutoGen.Abstractions;
using Moq;
using Xunit;
namespace Microsoft.AutoGen.Agents.Tests;
public class AgentBaseTests
{
[Fact]
public async Task ItInvokeRightHandlerTestAsync()
{
var mockContext = new Mock<IAgentContext>();
var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []));
await agent.HandleObject("hello world");
await agent.HandleObject(42);
agent.ReceivedItems.Should().HaveCount(2);
agent.ReceivedItems[0].Should().Be("hello world");
agent.ReceivedItems[1].Should().Be(42);
}
/// <summary>
/// The test agent is a simple agent that is used for testing purposes.
/// </summary>
public class TestAgent : AgentBase, IHandle<string>, IHandle<int>
{
public TestAgent(IAgentContext context, EventTypes eventTypes) : base(context, eventTypes)
{
}
public Task Handle(string item)
{
ReceivedItems.Add(item);
return Task.CompletedTask;
}
public Task Handle(int item)
{
ReceivedItems.Add(item);
return Task.CompletedTask;
}
public List<object> ReceivedItems { get; private set; } = [];
}
}

View File

@ -0,0 +1,14 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsTestProject>True</IsTestProject>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\src\Microsoft.AutoGen\Agents\Microsoft.AutoGen.Agents.csproj" />
</ItemGroup>
</Project>

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentCodeSnippet.cs
using AutoGen.Core;
namespace AutoGen.BasicSample.CodeSnippet;

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// UserProxyAgentCodeSnippet.cs
using AutoGen.Core;
namespace AutoGen.BasicSample.CodeSnippet;

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example06_UserProxyAgent.cs
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
//await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Connect_To_Azure_OpenAI.cs
#region using_statement
using System.ClientModel;

View File

@ -1,4 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Microsoft.Extensions.Hosting;
var app = await Microsoft.AutoGen.Runtime.Host.StartAsync(local: true);

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
var builder = DistributedApplication.CreateBuilder(args);
var backend = builder.AddProject<Projects.Backend>("backend");

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// HelloAIAgent.cs
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;
using Microsoft.Extensions.AI;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Hello;
using Microsoft.AspNetCore.Builder;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,11 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
// send a message to the agent
// step 1: create in-memory agent runtime
// step 2: register HelloAgent to that agent runtime
// step 3: start the agent runtime
// step 4: send a message to the agent
// step 5: wait for the agent runtime to shutdown
var app = await AgentsApp.PublishMessageAsync("HelloAgents", new NewMessageReceived
{
Message = "World"
@ -18,10 +27,11 @@ namespace Hello
[TopicSubscription("HelloAgents")]
public class HelloAgent(
IAgentContext context,
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
context,
typeRegistry),
ISayHello,
IHandleConsole,
IHandle<NewMessageReceived>,
IHandle<ConversationClosed>
{

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Microsoft.AutoGen.Runtime;
var builder = WebApplication.CreateBuilder(args);

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Developer.cs
using DevTeam.Shared;
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DeveloperPrompts.cs
namespace DevTeam.Agents;
public static class DeveloperSkills

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DeveloperLead.cs
using DevTeam.Shared;
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DeveloperLeadPrompts.cs
namespace DevTeam.Agents;
public static class DevLeadSkills
{

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// PMPrompts.cs
namespace DevTeam.Agents;
public static class PMSkills
{

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ProductManager.cs
using DevTeam.Shared;
using Microsoft.AutoGen.Abstractions;
using Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using DevTeam.Agents;
using Microsoft.AutoGen.Agents;
using Microsoft.AutoGen.Extensions.SemanticKernel;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
var builder = DistributedApplication.CreateBuilder(args);
builder.AddAzureProvisioning();

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AzureGenie.cs
using DevTeam.Backend;
using DevTeam.Shared;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Hubber.cs
using System.Text.Json;
using DevTeam;
using DevTeam.Backend;

View File

@ -1,7 +1,5 @@
// TODO: Reimplement using ACA Sessions
// using DevTeam.Events;
// using Microsoft.AutoGen.Abstractions;
// using Microsoft.AutoGen.Agents;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Sandbox.cs
// namespace DevTeam.Backend;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using Azure.Identity;
using DevTeam.Backend;
using DevTeam.Options;

View File

@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AzureService.cs
using System.Text;
using Azure;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GithubAuthService.cs
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Security.Cryptography;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GithubService.cs
using System.Text;
using Azure.Storage.Files.Shares;
using DevTeam.Options;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GithubWebHookProcessor.cs
using System.Globalization;
using DevTeam.Shared;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// EventExtensions.cs
using System.Globalization;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DevPlan.cs
namespace DevTeam;
public class DevLeadPlan
{

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AzureOptions.cs
using System.ComponentModel.DataAnnotations;
namespace DevTeam.Options;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GithubOptions.cs
using System.ComponentModel.DataAnnotations;
namespace DevTeam.Options;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// ParseExtensions.cs
namespace DevTeam;

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionRequest.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentExtension.cs
using System;
using System.Collections.Generic;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// GroupChat.cs
using System;
using System.Collections.Generic;

View File

@ -5,45 +5,44 @@ using System;
using System.Collections.Generic;
using System.Linq;
namespace AutoGen
namespace AutoGen;
public static class LLMConfigAPI
{
public static class LLMConfigAPI
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
{
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
var models = modelIDs ?? new[]
{
var models = modelIDs ?? new[]
{
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
}

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// LMStudioConfig.cs
using System;
using System.ClientModel;
using OpenAI;

View File

@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentId.cs
namespace Microsoft.AutoGen.Abstractions;
public partial class AgentId
{
public AgentId(string type, string key)
{
Type = type;
Key = key;
}
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatHistoryItem.cs
namespace Microsoft.AutoGen.Abstractions;
[Serializable]

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatState.cs
using Google.Protobuf;
namespace Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatUserType.cs
namespace Microsoft.AutoGen.Abstractions;
public enum ChatUserType

View File

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentBase.cs
using Google.Protobuf;
namespace Microsoft.AutoGen.Abstractions;
public interface IAgentBase
{
// Properties
AgentId AgentId { get; }
IAgentContext Context { get; }
// Methods
Task CallHandler(CloudEvent item);
Task<RpcResponse> HandleRequest(RpcRequest request);
void ReceiveMessage(Message message);
Task Store(AgentState state);
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
ValueTask PublishEvent(CloudEvent item);
}

View File

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentContext.cs
using System.Diagnostics;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Abstractions;
public interface IAgentContext
{
AgentId AgentId { get; }
IAgentBase? AgentInstance { get; set; }
DistributedContextPropagator DistributedContextPropagator { get; } // TODO: Remove this. An abstraction should not have a dependency on DistributedContextPropagator.
ILogger Logger { get; } // TODO: Remove this. An abstraction should not have a dependency on ILogger.
ValueTask Store(AgentState value);
ValueTask<AgentState> Read(AgentId agentId);
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request);
ValueTask PublishEventAsync(CloudEvent @event);
}

View File

@ -1,8 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentWorkerRuntime.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;
public interface IAgentWorkerRuntime
{
ValueTask PublishEvent(CloudEvent evt);

View File

@ -1,6 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IHandle.cs
namespace Microsoft.AutoGen.Abstractions;
public interface IHandle<T>
public interface IHandle
{
Task HandleObject(object item);
}
public interface IHandle<T> : IHandle
{
Task Handle(T item);
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageExtensions.cs
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TopicSubscriptionAttribute.cs
namespace Microsoft.AutoGen.Abstractions;
[AttributeUsage(AttributeTargets.All)]

View File

@ -1,4 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentBase.cs
using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading.Channels;
@ -8,7 +12,7 @@ using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Agents;
public abstract class AgentBase : IAgentBase
public abstract class AgentBase : IAgentBase, IHandle
{
public static readonly ActivitySource s_source = new("AutoGen.Agent");
public AgentId AgentId => _context.AgentId;
@ -17,6 +21,8 @@ public abstract class AgentBase : IAgentBase
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentContext _context;
public string Route { get; set; } = "base";
protected internal ILogger Logger => _context.Logger;
public IAgentContext Context => _context;
protected readonly EventTypes EventTypes;
@ -212,16 +218,59 @@ public abstract class AgentBase : IAgentBase
public Task CallHandler(CloudEvent item)
{
// Only send the event to the handler if the agent type is handling that type
if (EventTypes.EventsMap[GetType()].Contains(item.Type))
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
foreach (var key in EventTypes.EventsMap.Keys)
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
if (EventTypes.EventsMap[key].Contains(item.Type))
{
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
MethodInfo methodInfo;
try
{
// check that our target actually implements this interface, otherwise call the default static
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
{
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
}
else
{
// The error here is we have registered for an event that we do not have code to listen to
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
}
}
catch (Exception ex)
{
Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
throw; // TODO: ?
}
}
}
return Task.CompletedTask;
}
public virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
public Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
public virtual Task HandleObject(object item)
{
// get all Handle<T> methods
var handleTMethods = this.GetType().GetMethods().Where(m => m.Name == "Handle" && m.GetParameters().Length == 1).ToList();
// get the one that matches the type of the item
var handleTMethod = handleTMethods.FirstOrDefault(m => m.GetParameters()[0].ParameterType == item.GetType());
// if we found one, invoke it
if (handleTMethod != null)
{
return (Task)handleTMethod.Invoke(this, [item])!;
}
// otherwise, complain
throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}");
}
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentBaseExtensions.cs
using System.Diagnostics;
namespace Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentContext.cs
using System.Diagnostics;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.Logging;
@ -10,14 +13,14 @@ internal sealed class AgentContext(AgentId agentId, IAgentWorkerRuntime runtime,
public AgentId AgentId { get; } = agentId;
public ILogger Logger { get; } = logger;
public AgentBase? AgentInstance { get; set; }
public IAgentBase? AgentInstance { get; set; }
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response)
{
response.RequestId = request.RequestId;
await _runtime.SendResponse(response);
}
public async ValueTask SendRequestAsync(AgentBase agent, RpcRequest request)
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request)
{
await _runtime.SendRequest(agent, request).ConfigureAwait(false);
}

View File

@ -1,15 +0,0 @@
using RpcAgentId = Microsoft.AutoGen.Abstractions.AgentId;
namespace Microsoft.AutoGen.Agents;
public sealed record class AgentId(string Type, string Key)
{
public static implicit operator RpcAgentId(AgentId agentId) => new()
{
Type = agentId.Type,
Key = agentId.Key
};
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Type, agentId.Key);
public override string ToString() => $"{Type}/{Key}";
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentWorker.cs
using System.Diagnostics;
using Google.Protobuf;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,4 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// InferenceAgent.cs
using Google.Protobuf;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.AI;
namespace Microsoft.AutoGen.Agents.Client;
public abstract class InferenceAgent<T> : AgentBase where T : IMessage, new()

View File

@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// SKAiAgent.cs
using System.Globalization;
using System.Text;
using Microsoft.AutoGen.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Memory;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ConsoleAgent.cs
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.DependencyInjection;

View File

@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IHandleConsole.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;
public interface IHandleConsole : IHandle<Output>, IHandle<Input>
{
string Route { get; }
AgentId AgentId { get; }
ValueTask PublishEvent(CloudEvent item);
async Task IHandle<Output>.Handle(Output item)
{
// Assuming item has a property `Message` that we want to write to the console
Console.WriteLine(item.Message);
await ProcessOutput(item.Message);
var evt = new OutputWritten
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
async Task IHandle<Input>.Handle(Input item)
{
Console.WriteLine("Please enter input:");
string content = Console.ReadLine() ?? string.Empty;
await ProcessInput(content);
var evt = new InputProcessed
{
Route = "console"
}.ToCloudEvent(AgentId.Key);
await PublishEvent(evt);
}
static Task ProcessOutput(string message)
{
// Implement your output processing logic here
return Task.CompletedTask;
}
static Task<string> ProcessInput(string message)
{
// Implement your input processing logic here
return Task.FromResult(message);
}
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FileAgent.cs
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IOAgent.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WebAPIAgent.cs
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// App.cs
using System.Diagnostics.CodeAnalysis;
using Google.Protobuf;
using Microsoft.AspNetCore.Builder;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcAgentWorkerRuntime.cs
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Reflection;
@ -83,6 +86,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
message.Response.RequestId = request.OriginalRequestId;
request.Agent.ReceiveMessage(message);
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
if (!message.RegisterAgentTypeResponse.Success)
{

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// HostBuilderExtensions.cs
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
@ -71,7 +74,52 @@ public static class HostBuilderExtensions
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => (t, t.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
.ToDictionary(item => item.t, item => item.Item2);
// get interfaces implemented by the agent and get the methods of the interface if they are named Handle
var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
.SelectMany(assembly => assembly.GetTypes())
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
.Select(t => t.GetInterfaces()
.Select(i => (t, i, i.GetMethods()
.Where(m => m.Name == "Handle")
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
//to dictionary of type t and paramter type of the method
.ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
// for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
foreach (var item in ifaceHandlersMap)
{
foreach (var iface in item)
{
if (eventsMap.TryGetValue(iface.Item2, out var events))
{
events.UnionWith(iface.Item3);
}
else
{
eventsMap[iface.Item2] = iface.Item3;
}
}
}
// merge the handlersMap into the eventsMap
foreach (var item in handlersMap)
{
if (eventsMap.TryGetValue(item.Key, out var events))
{
events.UnionWith(item.Value);
}
else
{
eventsMap[item.Key] = item.Value;
}
}
return new EventTypes(typeRegistry, types, eventsMap);
});
return new AgentApplicationBuilder(builder);

View File

@ -1,20 +0,0 @@
using Google.Protobuf;
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Agents
{
public interface IAgentBase
{
// Properties
AgentId AgentId { get; }
IAgentContext Context { get; }
// Methods
Task CallHandler(CloudEvent item);
Task<RpcResponse> HandleRequest(RpcRequest request);
void ReceiveMessage(Message message);
Task Store(AgentState state);
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
ValueTask PublishEvent(CloudEvent item);
}
}

View File

@ -1,18 +0,0 @@
using System.Diagnostics;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Agents;
public interface IAgentContext
{
AgentId AgentId { get; }
AgentBase? AgentInstance { get; set; }
DistributedContextPropagator DistributedContextPropagator { get; }
ILogger Logger { get; }
ValueTask Store(AgentState value);
ValueTask<AgentState> Read(AgentId agentId);
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
ValueTask SendRequestAsync(AgentBase agent, RpcRequest request);
ValueTask PublishEventAsync(CloudEvent @event);
}

View File

@ -12,7 +12,7 @@
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="../Abstractions/Microsoft.AutoGen.Abstractions.csproj" />
<ProjectReference Include="..\Abstractions\Microsoft.AutoGen.Abstractions.csproj" />
<ProjectReference Include="..\Runtime\Microsoft.AutoGen.Runtime.csproj" />
</ItemGroup>

View File

@ -1,33 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AIModelClientHostingExtensions.cs
using Microsoft.Extensions.AI;
namespace Microsoft.Extensions.Hosting
{
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
namespace Microsoft.Extensions.Hosting;
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
public static class AIModelClient
{
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
{
var pipeline = (ChatClientBuilder pipeline) => pipeline
.UseLogging()
.UseFunctionInvocation()
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
{
builder.AddOllamaChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
{
builder.AddOpenAIChatClient(serviceName, pipeline);
}
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
{
builder.AddAzureChatClient(serviceName, pipeline);
}
else
{
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
}
return builder;
}
}

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AIClientOptions.cs
using System.ComponentModel.DataAnnotations;
namespace Microsoft.Extensions.Hosting;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ServiceCollectionChatCompletionExtensions.cs
using System.ClientModel;
using System.Data.Common;
using Azure;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// CloudEventExtensions.cs
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// QdrantOptions.cs
using System.ComponentModel.DataAnnotations;
namespace Microsoft.AutoGen.Extensions.SemanticKernel;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// SemanticKernelHostingExtensions.cs
using System.Text.Json;
using Azure.AI.OpenAI;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentWorkerHostingExtensions.cs
using System.Diagnostics;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentWorkerRegistryGrain.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Runtime;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Host.cs
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Hosting;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentWorkerRegistryGrain.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Runtime;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IWorkerAgentGrain.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Runtime;

View File

@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IWorkerGateway.cs
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OrleansRuntimeHostingExtenions.cs
using System.Configuration;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Configuration;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkerAgentGrain.cs
using Microsoft.AutoGen.Abstractions;
namespace Microsoft.AutoGen.Runtime;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkerGateway.cs
using System.Collections.Concurrent;
using Grpc.Core;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkerGatewayService.cs
using Grpc.Core;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkerProcessConnection.cs
using System.Threading.Channels;
using Grpc.Core;
using Microsoft.AutoGen.Abstractions;

View File

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Extensions.cs
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Diagnostics.HealthChecks;
using Microsoft.Extensions.DependencyInjection;

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatRequestMessageTests.cs
using System;
using System.Collections.Generic;

View File

@ -14,206 +14,205 @@ using FluentAssertions;
using OpenAI;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.Tests
namespace AutoGen.OpenAI.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
}
throw;
_output.WriteLine(message.FormatMessage());
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
chatClient: client.GetChatClient(model),
name: "Student",
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
using FluentAssertions;
using Xunit.Abstractions;
namespace AutoGen.OpenAI.V1.Tests
namespace AutoGen.OpenAI.V1.Tests;
public partial class MathClassTest
{
public partial class MathClassTest
private readonly ITestOutputHelper _output;
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
}
// as of 2024-05-20, aoai return 500 error when round > 1
// I'm pretty sure that round > 5 was supported before
// So this is probably some wield regression on aoai side
// I'll keep this test case here for now, plus setting round to 1
// so the test can still pass.
// In the future, we should rewind this test case to round > 1 (previously was 5)
private int round = 1;
public MathClassTest(ITestOutputHelper output)
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
{
try
{
_output = output;
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
catch (Exception)
{
try
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
_output.WriteLine(reply.FormatMessage());
return Task.FromResult(reply);
}
catch (Exception)
{
_output.WriteLine("Request failed");
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
if (message is IMessage<object> envelope)
{
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
throw;
}
throw;
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
}
[FunctionAttribute]
public async Task<string> CreateMathQuestion(string question, int question_index)
{
return $@"[MATH_QUESTION]
Question {question_index}:
{question}
Student, please answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
[FunctionAttribute]
public async Task<string> AnswerQuestion(string answer)
{
return $@"[MATH_ANSWER]
The answer is {answer}
teacher please check answer";
}
}
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
[FunctionAttribute]
public async Task<string> AnswerIsCorrect(string message)
{
return $@"[ANSWER_IS_CORRECT]
{message}
please update progress";
}
}
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
[FunctionAttribute]
public async Task<string> UpdateProgress(int correctAnswerCount)
{
if (correctAnswerCount >= this.round)
{
if (correctAnswerCount >= this.round)
{
return $@"[UPDATE_PROGRESS]
return $@"[UPDATE_PROGRESS]
{GroupChatExtension.TERMINATE}";
}
else
{
return $@"[UPDATE_PROGRESS]
}
else
{
return $@"[UPDATE_PROGRESS]
the number of resolved question is {correctAnswerCount}
teacher, please create the next math question";
}
}
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task OpenAIAgentMathChatTestAsync()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var adminFunctionMiddleware = new FunctionCallMiddleware(
functions: [this.UpdateProgressFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
});
var admin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "Admin",
systemMessage: $@"You are admin. You update progress after each question is answered.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(adminFunctionMiddleware)
.RegisterMiddleware(Print);
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
var groupAdmin = new OpenAIChatAgent(
openAIClient: openaiClient,
modelName: deployName,
name: "GroupAdmin",
systemMessage: "You are group admin. You manage the group chat.")
.RegisterMessageConnector()
.RegisterMiddleware(Print);
await RunMathChatAsync(teacher, student, admin, groupAdmin);
}
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
});
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
var teacher = new OpenAIChatAgent(
openAIClient: client,
name: "Teacher",
systemMessage: @"You are a preschool math teacher.
You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
modelName: model)
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);
return teacher;
}
return teacher;
}
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
{
var functionCallMiddleware = new FunctionCallMiddleware(
functions: [this.AnswerQuestionFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
});
var student = new OpenAIChatAgent(
openAIClient: client,
name: "Student",
modelName: model,
systemMessage: @"You are a student. You answer math question from teacher.")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
return student;
}
return student;
}
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
{
var teacher2Student = Transition.Create(teacher, student);
var student2Teacher = Transition.Create(student, teacher);
var teacher2Admin = Transition.Create(teacher, admin);
var admin2Teacher = Transition.Create(admin, teacher);
var workflow = new Graph(
[
teacher2Student,
student2Teacher,
teacher2Admin,
admin2Teacher,
]);
var group = new GroupChat(
workflow: workflow,
members: [
admin,
teacher,
student,
],
admin: groupAdmin);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
var groupChatManager = new GroupChatManager(group);
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
.Count()
.Should().BeGreaterThanOrEqualTo(this.round);
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
// check if there's terminate chat message from admin
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
.Count()
.Should().Be(1);
}
}

View File

@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// KernelFunctionMiddlewareTests.cs
using System.ClientModel;
using AutoGen.Core;

View File

@ -1,87 +1,87 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallTemplateEncodingTests.cs
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
using Xunit; // Needed for Fact and Assert
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionCallTemplateEncodingTests
{
public class FunctionCallTemplateEncodingTests
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
{
[Fact]
public void FunctionDescription_Should_Encode_DoubleQuotes()
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a \"test\" function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
var template = new FunctionCallTemplate
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
[Fact]
public void ParameterDescription_Should_Encode_DoubleQuotes()
{
// Arrange
var functionContracts = new List<SourceGeneratorFunctionContract>
{
new SourceGeneratorFunctionContract
{
new SourceGeneratorFunctionContract
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
{
Name = "TestFunction",
Description = "This is a test function",
Parameters = new SourceGeneratorParameterContract[]
new SourceGeneratorParameterContract
{
new SourceGeneratorParameterContract
{
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
Name = "param1",
Description = "This is a \"parameter\" description",
Type = "string",
IsOptional = false
}
},
ReturnType = "void"
}
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
var template = new FunctionCallTemplate
{
NameSpace = "TestNamespace",
ClassName = "TestClass",
FunctionContracts = functionContracts
};
// Act
var result = template.TransformText();
// Act
var result = template.TransformText();
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
// Assert
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
}
}

View File

@ -10,122 +10,121 @@ using FluentAssertions;
using OpenAI.Chat;
using Xunit;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public class FunctionExample
{
public class FunctionExample
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
{
private readonly FunctionExamples functionExamples = new FunctionExamples();
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
WriteIndented = true,
};
[Fact]
public void Add_Test()
{
var args = new
{
WriteIndented = true,
a = 1,
b = 2,
};
[Fact]
public void Add_Test()
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
{
var args = new
{
var args = new
args = new double[] { 1, 2, 3 },
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
[Fact]
public async Task DictionaryToString_Test()
{
var args = new
{
xargs = new Dictionary<string, string>
{
a = 1,
b = 2,
};
{ "a", "1" },
{ "b", "2" },
},
};
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
[Fact]
public void Sum_Test()
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
var args = new
{
args = new double[] { 1, 2, 3 },
};
a = 1,
b = 2,
};
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task DictionaryToString_Test()
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
var args = new
{
xargs = new Dictionary<string, string>
{
{ "a", "1" },
{ "b", "2" },
},
};
a = 1,
b = 2,
};
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
}
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
[Fact]
public async Task TopLevelFunctionExampleAddTestAsync()
[Fact]
public void Query_Test()
{
var args = new
{
var example = new TopLevelStatementFunctionExample();
var args = new
{
a = 1,
b = 2,
};
query = "hello",
k = 3,
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[Fact]
public async Task FilescopeFunctionExampleAddTestAsync()
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
var example = new FilescopeNamespaceFunctionExample();
var args = new
{
a = 1,
b = 2,
};
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
}
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
[Fact]
public void Query_Test()
{
var args = new
{
query = "hello",
k = 3,
};
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
}
[UseReporter(typeof(DiffReporter))]
[UseApprovalSubdirectory("ApprovalTests")]
private void VerifyFunctionDefinition(ChatTool function)
{
var func = new
{
name = function.FunctionName,
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
};
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
}
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
{
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
var res = await func(str);
res.Should().BeEquivalentTo(expected);
}
}

View File

@ -4,67 +4,66 @@
using System.Text.Json;
using AutoGen.Core;
namespace AutoGen.SourceGenerator.Tests
namespace AutoGen.SourceGenerator.Tests;
public partial class FunctionExamples
{
public partial class FunctionExamples
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
/// <summary>
/// Add function
/// </summary>
/// <param name="a">a</param>
/// <param name="b">b</param>
[FunctionAttribute]
public int Add(int a, int b)
{
return a + b;
}
return a + b;
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// Sum function
/// </summary>
/// <param name="args">an array of double values</param>
[FunctionAttribute]
public double Sum(double[] args)
{
return args.Sum();
}
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
/// <summary>
/// DictionaryToString function
/// </summary>
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
[FunctionAttribute]
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
{
WriteIndented = true,
});
WriteIndented = true,
});
return Task.FromResult(res);
}
return Task.FromResult(res);
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
/// <summary>
/// query function
/// </summary>
/// <param name="query">query, required</param>
/// <param name="k">top k, optional, default value is 3</param>
/// <param name="thresold">thresold, optional, default value is 0.5</param>
[FunctionAttribute]
public string[] Query(string query, int k = 3, float thresold = 0.5f)
{
return Enumerable.Repeat(query, k).ToArray();
}
}

View File

@ -7,73 +7,72 @@ using System.Threading.Tasks;
using AutoGen.BasicSample;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public class BasicSampleTest
{
public class BasicSampleTest
private readonly ITestOutputHelper _output;
public BasicSampleTest(ITestOutputHelper output)
{
private readonly ITestOutputHelper _output;
_output = output;
Console.SetOut(new ConsoleWriter(_output));
}
public BasicSampleTest(ITestOutputHelper output)
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
_output = output;
Console.SetOut(new ConsoleWriter(_output));
this.output = output;
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentTestAsync()
public override void WriteLine(string? m)
{
await Example01_AssistantAgent.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task TwoAgentMathClassTestAsync()
{
await Example02_TwoAgent_MathChat.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task MistralClientAgent_TokenCount()
{
await Example14_MistralClientAgent_TokenCount.RunAsync();
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task DynamicGroupChatCalculateFibonacciAsync()
{
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task DalleAndGPT4VTestAsync()
{
await Example05_Dalle_And_GPT4V.RunAsync();
}
[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT4ImageMessage()
{
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
}
public class ConsoleWriter : StringWriter
{
private ITestOutputHelper output;
public ConsoleWriter(ITestOutputHelper output)
{
this.output = output;
}
public override void WriteLine(string? m)
{
output.WriteLine(m);
}
output.WriteLine(m);
}
}
}

View File

@ -3,18 +3,17 @@
using Xunit;
namespace AutoGen.Tests
{
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
namespace AutoGen.Tests;
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
public class GraphTests
{
[Fact]
public void GraphTest()
{
var graph1 = new Graph();
Assert.NotNull(graph1);
var graph2 = new Graph(null);
Assert.NotNull(graph2);
}
}

View File

@ -9,219 +9,218 @@ using FluentAssertions;
using Xunit;
using Xunit.Abstractions;
namespace AutoGen.Tests
namespace AutoGen.Tests;
public partial class SingleAgentTest
{
public partial class SingleAgentTest
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
private ITestOutputHelper _output;
public SingleAgentTest(ITestOutputHelper output)
{
_output = output;
}
_output = output;
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
{
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
return new AzureOpenAIConfig(endpoint, deployName, key);
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
private ILLMConfig CreateOpenAIGPT4VisionConfig()
{
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
return new OpenAIConfig(key, "gpt-4-vision-preview");
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
var llmConfig = new ConversableAgentConfig
{
Temperature = 0,
FunctionContracts = new[]
{
Temperature = 0,
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
config,
},
};
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig);
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
await EchoFunctionCallTestAsync(assistantAgent);
}
[Fact]
public async Task AssistantAgentDefaultReplyTestAsync()
{
var assistantAgent = new AssistantAgent(
llmConfig: null,
name: "assistant",
defaultReply: "hello world");
var reply = await assistantAgent.SendAsync("hi");
reply.GetContent().Should().Be("hello world");
reply.GetRole().Should().Be(Role.Assistant);
reply.From.Should().Be(assistantAgent.Name);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
{
var config = this.CreateAzureOpenAIGPT35TurboConfig();
var llmConfig = new ConversableAgentConfig
{
return $"[ECHO] {message}";
}
FunctionContracts = new[]
{
this.EchoAsyncFunctionContract,
},
ConfigList = new[]
{
config,
},
};
var assistantAgent = new AssistantAgent(
name: "assistant",
llmConfig: llmConfig,
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(EchoAsync), this.EchoAsyncWrapper },
});
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
await EchoFunctionCallExecutionTestAsync(assistantAgent);
}
/// <summary>
/// echo when asked.
/// </summary>
/// <param name="message">message to echo</param>
[FunctionAttribute]
public async Task<string> EchoAsync(string message)
{
return $"[ECHO] {message}";
}
/// <summary>
/// return the label name with hightest inference cost
/// </summary>
/// <param name="labelName"></param>
/// <returns></returns>
[FunctionAttribute]
public async Task<string> GetHighestLabel(string labelName, string color)
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
return $"[HIGHEST_LABEL] {labelName} {color}";
}
public async Task EchoFunctionCallTestAsync(IAgent agent)
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.From.Should().Be(agent.Name);
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
finalReply = reply;
}
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
reply.GetContent().Should().Be("[ECHO] Hello world");
reply.From.Should().Be(agent.Name);
reply.Should().BeOfType<ToolCallAggregateMessage>();
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
else
{
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
var helloWorld = new TextMessage(Role.User, "echo Hello world");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
var answer = "[ECHO] Hello world";
IMessage? finalReply = default;
await foreach (var reply in replyStream)
{
reply.From.Should().Be(agent.Name);
finalReply = reply;
}
if (finalReply is ToolCallAggregateMessage aggregateMessage)
{
var toolCallResultMessage = aggregateMessage.Message2;
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
toolCallResultMessage.From.Should().Be(agent.Name);
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
}
else
{
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
throw new Exception("unexpected message type");
}
}
public async Task UpperCaseTestAsync(IAgent agent)
{
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
var reply = await agent.SendAsync(chatHistory: new[] { message });
reply.GetContent().Should().Contain("ABCDE");
reply.From.Should().Be(agent.Name);
}
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
{
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
var option = new GenerateReplyOptions
{
Temperature = 0,
};
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
var answer = "HELLO WORLD";
TextMessage? finalReply = default;
await foreach (var reply in replyStream)
{
if (reply is TextMessageUpdate update)
{
update.From.Should().Be(agent.Name);
if (finalReply is null)
{
finalReply = new TextMessage(update);
}
else
{
finalReply.Update(update);
}
continue;
}
else if (reply is TextMessage textMessage)
{
finalReply = textMessage;
continue;
}
throw new Exception("unexpected message type");
}
finalReply!.Content.Should().Contain(answer);
finalReply!.Role.Should().Be(Role.Assistant);
finalReply!.From.Should().Be(agent.Name);
}
}

View File

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TwoAgentTest.cs
#pragma warning disable xUnit1013
using System;
using System.Collections.Generic;

View File

@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from .. import EVENT_LOGGER_NAME
from ..base import Response
from ..messages import (
ChatMessage,
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
)
from ._base_chat_agent import BaseChatAgent
@ -207,7 +211,14 @@ class AssistantAgent(BaseChatAgent):
)
self._model_context: List[LLMMessage] = []
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ResetMessage):
@ -215,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# Inner messages.
inner_messages: List[InnerMessage] = []
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
@ -227,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
# Detect handoff requests.
handoffs: List[Handoff] = []
@ -242,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Respond with a handoff message.
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
# Return the output messages to signal the handoff.
return Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
# Generate an inference result based on the current model context.
result = await self._model_client.create(
@ -255,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return StopMessage(content=result.content, source=self.name)
return Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
return TextMessage(content=result.content, source=self.name)
return Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken

Some files were not shown because too many files have changed in this diff Show More