mirror of https://github.com/microsoft/autogen.git
Compare commits
16 Commits
d32a6b1102
...
9028bb6d18
Author | SHA1 | Date |
---|---|---|
Mohammad Mazraeh | 9028bb6d18 | |
Mohammad Mazraeh | a6f0c7cc8a | |
Mohammad Mazraeh | 2b20a74ef0 | |
Mohammad Mazraeh | ba0be1f78f | |
Mohammad Mazraeh | ccf94bfdb5 | |
Mohammad Mazraeh | 07603e1054 | |
Mohammad Mazraeh | a6bf6eff49 | |
Mohammad Mazraeh | d98e8b2fdf | |
Mohammad Mazraeh | e83217ca1a | |
Mohammad Mazraeh | 9f0e747e93 | |
Rohan Thacker | 3c63f6f3ef | |
Xiaoyun Zhang | 6bea055b26 | |
Eric Zhu | 3d51ab76ae | |
Xiaoyun Zhang | e63fd17ed5 | |
Ryan Sweet | 51cd5b8d1f | |
Eric Zhu | 4a49844996 |
|
@ -193,10 +193,6 @@ csharp_using_directive_placement = outside_namespace:error
|
|||
csharp_prefer_static_local_function = true:warning
|
||||
csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning
|
||||
|
||||
# Header template
|
||||
file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
|
||||
dotnet_diagnostic.IDE0073.severity = error
|
||||
|
||||
# enable format error
|
||||
dotnet_diagnostic.IDE0055.severity = error
|
||||
|
||||
|
@ -221,13 +217,14 @@ dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
|
|||
|
||||
csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
|
||||
csharp_prefer_simple_using_statement = true:suggestion
|
||||
csharp_style_namespace_declarations = block_scoped:silent
|
||||
csharp_style_namespace_declarations = file_scoped:warning
|
||||
csharp_style_prefer_method_group_conversion = true:silent
|
||||
csharp_style_prefer_top_level_statements = true:silent
|
||||
csharp_style_prefer_primary_constructors = true:suggestion
|
||||
csharp_style_expression_bodied_lambdas = true:silent
|
||||
csharp_style_prefer_local_over_anonymous_function = true:suggestion
|
||||
dotnet_diagnostic.CA2016.severity = suggestion
|
||||
csharp_prefer_static_anonymous_function = true:suggestion
|
||||
|
||||
# disable check for generated code
|
||||
[*.generated.cs]
|
||||
|
@ -556,8 +553,8 @@ dotnet_diagnostic.IDE0060.severity = warning
|
|||
dotnet_diagnostic.IDE0062.severity = warning
|
||||
|
||||
# IDE0073: File header
|
||||
dotnet_diagnostic.IDE0073.severity = suggestion
|
||||
file_header_template = Copyright (c) Microsoft. All rights reserved.
|
||||
dotnet_diagnostic.IDE0073.severity = warning
|
||||
file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
|
||||
|
||||
# IDE1006: Required naming style
|
||||
dotnet_diagnostic.IDE1006.severity = warning
|
||||
|
@ -697,6 +694,7 @@ dotnet_style_prefer_compound_assignment = true:suggestion
|
|||
dotnet_style_prefer_simplified_interpolation = true:suggestion
|
||||
dotnet_style_prefer_collection_expression = when_types_loosely_match:suggestion
|
||||
dotnet_style_namespace_match_folder = true:suggestion
|
||||
dotnet_style_qualification_for_method = false:silent
|
||||
|
||||
[**/*.g.cs]
|
||||
generated_code = true
|
||||
|
|
|
@ -123,9 +123,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgent", "samples\Hello
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelClientHostingExtensions", "src\Microsoft.AutoGen\Extensions\AIModelClientHostingExtensions\AIModelClientHostingExtensions.csproj", "{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents.Tests", "Microsoft.AutoGen.Agents.Tests\Microsoft.AutoGen.Agents.Tests.csproj", "{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}"
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgentState", "samples\Hello\HelloAgentState\HelloAgentState.csproj", "{64EF61E7-00A6-4E5E-9808-62E10993A0E5}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
|
@ -337,6 +339,10 @@ Global
|
|||
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{64EF61E7-00A6-4E5E-9808-62E10993A0E5}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
|
@ -401,6 +407,7 @@ Global
|
|||
{A20B9894-F352-4338-872A-F215A241D43D} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
|
||||
{8F7560CF-EEBB-4333-A69F-838CA40FD85D} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
|
||||
{97550E87-48C6-4EBF-85E1-413ABAE9DBFD} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
|
||||
{CF4C92BD-28AE-4B8F-B173-601004AEC9BF} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
|
||||
{64EF61E7-00A6-4E5E-9808-62E10993A0E5} = {7EB336C2-7C0A-4BC8-80C6-A3173AB8DC45}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentBaseTests.cs
|
||||
|
||||
using FluentAssertions;
|
||||
using Google.Protobuf.Reflection;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents.Tests;
|
||||
|
||||
public class AgentBaseTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ItInvokeRightHandlerTestAsync()
|
||||
{
|
||||
var mockContext = new Mock<IAgentContext>();
|
||||
var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []));
|
||||
|
||||
await agent.HandleObject("hello world");
|
||||
await agent.HandleObject(42);
|
||||
|
||||
agent.ReceivedItems.Should().HaveCount(2);
|
||||
agent.ReceivedItems[0].Should().Be("hello world");
|
||||
agent.ReceivedItems[1].Should().Be(42);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The test agent is a simple agent that is used for testing purposes.
|
||||
/// </summary>
|
||||
public class TestAgent : AgentBase, IHandle<string>, IHandle<int>
|
||||
{
|
||||
public TestAgent(IAgentContext context, EventTypes eventTypes) : base(context, eventTypes)
|
||||
{
|
||||
}
|
||||
|
||||
public Task Handle(string item)
|
||||
{
|
||||
ReceivedItems.Add(item);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task Handle(int item)
|
||||
{
|
||||
ReceivedItems.Add(item);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public List<object> ReceivedItems { get; private set; } = [];
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFrameworks>$(TestTargetFrameworks)</TargetFrameworks>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<IsTestProject>True</IsTestProject>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\src\Microsoft.AutoGen\Agents\Microsoft.AutoGen.Agents.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentCodeSnippet.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.BasicSample.CodeSnippet;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// UserProxyAgentCodeSnippet.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.BasicSample.CodeSnippet;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Example06_UserProxyAgent.cs
|
||||
|
||||
using AutoGen.Core;
|
||||
using AutoGen.OpenAI;
|
||||
using AutoGen.OpenAI.Extension;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
//await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Connect_To_Azure_OpenAI.cs
|
||||
|
||||
#region using_statement
|
||||
using System.ClientModel;
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
||||
var app = await Microsoft.AutoGen.Runtime.Host.StartAsync(local: true);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
var builder = DistributedApplication.CreateBuilder(args);
|
||||
var backend = builder.AddProject<Projects.Backend>("backend");
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// HelloAIAgent.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Hello;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,11 +1,20 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
||||
// send a message to the agent
|
||||
// step 1: create in-memory agent runtime
|
||||
|
||||
// step 2: register HelloAgent to that agent runtime
|
||||
|
||||
// step 3: start the agent runtime
|
||||
|
||||
// step 4: send a message to the agent
|
||||
|
||||
// step 5: wait for the agent runtime to shutdown
|
||||
var app = await AgentsApp.PublishMessageAsync("HelloAgents", new NewMessageReceived
|
||||
{
|
||||
Message = "World"
|
||||
|
@ -18,10 +27,11 @@ namespace Hello
|
|||
[TopicSubscription("HelloAgents")]
|
||||
public class HelloAgent(
|
||||
IAgentContext context,
|
||||
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : ConsoleAgent(
|
||||
[FromKeyedServices("EventTypes")] EventTypes typeRegistry) : AgentBase(
|
||||
context,
|
||||
typeRegistry),
|
||||
ISayHello,
|
||||
IHandleConsole,
|
||||
IHandle<NewMessageReceived>,
|
||||
IHandle<ConversationClosed>
|
||||
{
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Microsoft.AutoGen.Runtime;
|
||||
var builder = WebApplication.CreateBuilder(args);
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Developer.cs
|
||||
|
||||
using DevTeam.Shared;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// DeveloperPrompts.cs
|
||||
|
||||
namespace DevTeam.Agents;
|
||||
public static class DeveloperSkills
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// DeveloperLead.cs
|
||||
|
||||
using DevTeam.Shared;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// DeveloperLeadPrompts.cs
|
||||
|
||||
namespace DevTeam.Agents;
|
||||
public static class DevLeadSkills
|
||||
{
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// PMPrompts.cs
|
||||
|
||||
namespace DevTeam.Agents;
|
||||
public static class PMSkills
|
||||
{
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ProductManager.cs
|
||||
|
||||
using DevTeam.Shared;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using DevTeam.Agents;
|
||||
using Microsoft.AutoGen.Agents;
|
||||
using Microsoft.AutoGen.Extensions.SemanticKernel;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
var builder = DistributedApplication.CreateBuilder(args);
|
||||
|
||||
builder.AddAzureProvisioning();
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AzureGenie.cs
|
||||
|
||||
using DevTeam.Backend;
|
||||
using DevTeam.Shared;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Hubber.cs
|
||||
|
||||
using System.Text.Json;
|
||||
using DevTeam;
|
||||
using DevTeam.Backend;
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
// TODO: Reimplement using ACA Sessions
|
||||
// using DevTeam.Events;
|
||||
// using Microsoft.AutoGen.Abstractions;
|
||||
// using Microsoft.AutoGen.Agents;
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Sandbox.cs
|
||||
|
||||
// namespace DevTeam.Backend;
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using Azure.Identity;
|
||||
using DevTeam.Backend;
|
||||
using DevTeam.Options;
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AzureService.cs
|
||||
|
||||
using System.Text;
|
||||
using Azure;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GithubAuthService.cs
|
||||
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Claims;
|
||||
using System.Security.Cryptography;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GithubService.cs
|
||||
|
||||
using System.Text;
|
||||
using Azure.Storage.Files.Shares;
|
||||
using DevTeam.Options;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GithubWebHookProcessor.cs
|
||||
|
||||
using System.Globalization;
|
||||
using DevTeam.Shared;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// EventExtensions.cs
|
||||
|
||||
using System.Globalization;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// DevPlan.cs
|
||||
|
||||
namespace DevTeam;
|
||||
public class DevLeadPlan
|
||||
{
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AzureOptions.cs
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
|
||||
namespace DevTeam.Options;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GithubOptions.cs
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
|
||||
namespace DevTeam.Options;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ParseExtensions.cs
|
||||
|
||||
namespace DevTeam;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatCompletionRequest.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentExtension.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GroupChat.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
|
|
@ -5,45 +5,44 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace AutoGen
|
||||
namespace AutoGen;
|
||||
|
||||
public static class LLMConfigAPI
|
||||
{
|
||||
public static class LLMConfigAPI
|
||||
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
|
||||
string apiKey,
|
||||
IEnumerable<string>? modelIDs = null)
|
||||
{
|
||||
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
|
||||
string apiKey,
|
||||
IEnumerable<string>? modelIDs = null)
|
||||
var models = modelIDs ?? new[]
|
||||
{
|
||||
var models = modelIDs ?? new[]
|
||||
{
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
};
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
};
|
||||
|
||||
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
|
||||
}
|
||||
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
|
||||
}
|
||||
|
||||
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
|
||||
string endpoint,
|
||||
string apiKey,
|
||||
IEnumerable<string> deploymentNames)
|
||||
{
|
||||
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
|
||||
}
|
||||
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
|
||||
string endpoint,
|
||||
string apiKey,
|
||||
IEnumerable<string> deploymentNames)
|
||||
{
|
||||
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get a list of LLMConfig objects from a JSON file.
|
||||
/// </summary>
|
||||
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
|
||||
string filePath,
|
||||
IEnumerable<string>? filterModels = null)
|
||||
{
|
||||
// Disable this API from documentation for now.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
/// <summary>
|
||||
/// Get a list of LLMConfig objects from a JSON file.
|
||||
/// </summary>
|
||||
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
|
||||
string filePath,
|
||||
IEnumerable<string>? filterModels = null)
|
||||
{
|
||||
// Disable this API from documentation for now.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// LMStudioConfig.cs
|
||||
|
||||
using System;
|
||||
using System.ClientModel;
|
||||
using OpenAI;
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentId.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
public partial class AgentId
|
||||
{
|
||||
public AgentId(string type, string key)
|
||||
{
|
||||
Type = type;
|
||||
Key = key;
|
||||
}
|
||||
}
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatHistoryItem.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
[Serializable]
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatState.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatUserType.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
public enum ChatUserType
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentBase.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
public interface IAgentBase
|
||||
{
|
||||
// Properties
|
||||
AgentId AgentId { get; }
|
||||
IAgentContext Context { get; }
|
||||
|
||||
// Methods
|
||||
Task CallHandler(CloudEvent item);
|
||||
Task<RpcResponse> HandleRequest(RpcRequest request);
|
||||
void ReceiveMessage(Message message);
|
||||
Task Store(AgentState state);
|
||||
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentContext.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
public interface IAgentContext
|
||||
{
|
||||
AgentId AgentId { get; }
|
||||
IAgentBase? AgentInstance { get; set; }
|
||||
DistributedContextPropagator DistributedContextPropagator { get; } // TODO: Remove this. An abstraction should not have a dependency on DistributedContextPropagator.
|
||||
ILogger Logger { get; } // TODO: Remove this. An abstraction should not have a dependency on ILogger.
|
||||
ValueTask Store(AgentState value);
|
||||
ValueTask<AgentState> Read(AgentId agentId);
|
||||
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
|
||||
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request);
|
||||
ValueTask PublishEventAsync(CloudEvent @event);
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentWorkerRuntime.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
public interface IAgentWorkerRuntime
|
||||
{
|
||||
ValueTask PublishEvent(CloudEvent evt);
|
|
@ -1,6 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IHandle.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
public interface IHandle<T>
|
||||
public interface IHandle
|
||||
{
|
||||
Task HandleObject(object item);
|
||||
}
|
||||
|
||||
public interface IHandle<T> : IHandle
|
||||
{
|
||||
Task Handle(T item);
|
||||
}
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// MessageExtensions.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// TopicSubscriptionAttribute.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Abstractions;
|
||||
|
||||
[AttributeUsage(AttributeTargets.All)]
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentBase.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Threading.Channels;
|
||||
|
@ -8,7 +12,7 @@ using Microsoft.Extensions.Logging;
|
|||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public abstract class AgentBase : IAgentBase
|
||||
public abstract class AgentBase : IAgentBase, IHandle
|
||||
{
|
||||
public static readonly ActivitySource s_source = new("AutoGen.Agent");
|
||||
public AgentId AgentId => _context.AgentId;
|
||||
|
@ -17,6 +21,8 @@ public abstract class AgentBase : IAgentBase
|
|||
|
||||
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
|
||||
private readonly IAgentContext _context;
|
||||
public string Route { get; set; } = "base";
|
||||
|
||||
protected internal ILogger Logger => _context.Logger;
|
||||
public IAgentContext Context => _context;
|
||||
protected readonly EventTypes EventTypes;
|
||||
|
@ -212,16 +218,59 @@ public abstract class AgentBase : IAgentBase
|
|||
public Task CallHandler(CloudEvent item)
|
||||
{
|
||||
// Only send the event to the handler if the agent type is handling that type
|
||||
if (EventTypes.EventsMap[GetType()].Contains(item.Type))
|
||||
// foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type
|
||||
foreach (var key in EventTypes.EventsMap.Keys)
|
||||
{
|
||||
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
|
||||
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
|
||||
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
|
||||
var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle)) ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
|
||||
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
|
||||
if (EventTypes.EventsMap[key].Contains(item.Type))
|
||||
{
|
||||
var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry);
|
||||
var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]);
|
||||
var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]);
|
||||
|
||||
MethodInfo methodInfo;
|
||||
try
|
||||
{
|
||||
// check that our target actually implements this interface, otherwise call the default static
|
||||
if (genericInterfaceType.IsAssignableFrom(this.GetType()))
|
||||
{
|
||||
methodInfo = genericInterfaceType.GetMethod(nameof(IHandle<object>.Handle), BindingFlags.Public | BindingFlags.Instance)
|
||||
?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}");
|
||||
return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The error here is we have registered for an event that we do not have code to listen to
|
||||
throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation.");
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.LogError(ex, $"Error invoking method {nameof(IHandle<object>.Handle)}");
|
||||
throw; // TODO: ?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
|
||||
public Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
|
||||
|
||||
public virtual Task HandleObject(object item)
|
||||
{
|
||||
// get all Handle<T> methods
|
||||
var handleTMethods = this.GetType().GetMethods().Where(m => m.Name == "Handle" && m.GetParameters().Length == 1).ToList();
|
||||
|
||||
// get the one that matches the type of the item
|
||||
var handleTMethod = handleTMethods.FirstOrDefault(m => m.GetParameters()[0].ParameterType == item.GetType());
|
||||
|
||||
// if we found one, invoke it
|
||||
if (handleTMethod != null)
|
||||
{
|
||||
return (Task)handleTMethod.Invoke(this, [item])!;
|
||||
}
|
||||
|
||||
// otherwise, complain
|
||||
throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentBaseExtensions.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentContext.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
@ -10,14 +13,14 @@ internal sealed class AgentContext(AgentId agentId, IAgentWorkerRuntime runtime,
|
|||
|
||||
public AgentId AgentId { get; } = agentId;
|
||||
public ILogger Logger { get; } = logger;
|
||||
public AgentBase? AgentInstance { get; set; }
|
||||
public IAgentBase? AgentInstance { get; set; }
|
||||
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
|
||||
public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response)
|
||||
{
|
||||
response.RequestId = request.RequestId;
|
||||
await _runtime.SendResponse(response);
|
||||
}
|
||||
public async ValueTask SendRequestAsync(AgentBase agent, RpcRequest request)
|
||||
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request)
|
||||
{
|
||||
await _runtime.SendRequest(agent, request).ConfigureAwait(false);
|
||||
}
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
using RpcAgentId = Microsoft.AutoGen.Abstractions.AgentId;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public sealed record class AgentId(string Type, string Key)
|
||||
{
|
||||
public static implicit operator RpcAgentId(AgentId agentId) => new()
|
||||
{
|
||||
Type = agentId.Type,
|
||||
Key = agentId.Key
|
||||
};
|
||||
|
||||
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Type, agentId.Key);
|
||||
public override string ToString() => $"{Type}/{Key}";
|
||||
}
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentWorker.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// InferenceAgent.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.AI;
|
||||
namespace Microsoft.AutoGen.Agents.Client;
|
||||
public abstract class InferenceAgent<T> : AgentBase where T : IMessage, new()
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SKAiAgent.cs
|
||||
|
||||
using System.Globalization;
|
||||
using System.Text;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.SemanticKernel;
|
||||
using Microsoft.SemanticKernel.Connectors.OpenAI;
|
||||
using Microsoft.SemanticKernel.Memory;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ConsoleAgent.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IHandleConsole.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public interface IHandleConsole : IHandle<Output>, IHandle<Input>
|
||||
{
|
||||
string Route { get; }
|
||||
AgentId AgentId { get; }
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
|
||||
async Task IHandle<Output>.Handle(Output item)
|
||||
{
|
||||
// Assuming item has a property `Message` that we want to write to the console
|
||||
Console.WriteLine(item.Message);
|
||||
await ProcessOutput(item.Message);
|
||||
|
||||
var evt = new OutputWritten
|
||||
{
|
||||
Route = "console"
|
||||
}.ToCloudEvent(AgentId.Key);
|
||||
await PublishEvent(evt);
|
||||
}
|
||||
async Task IHandle<Input>.Handle(Input item)
|
||||
{
|
||||
Console.WriteLine("Please enter input:");
|
||||
string content = Console.ReadLine() ?? string.Empty;
|
||||
|
||||
await ProcessInput(content);
|
||||
|
||||
var evt = new InputProcessed
|
||||
{
|
||||
Route = "console"
|
||||
}.ToCloudEvent(AgentId.Key);
|
||||
await PublishEvent(evt);
|
||||
}
|
||||
static Task ProcessOutput(string message)
|
||||
{
|
||||
// Implement your output processing logic here
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
static Task<string> ProcessInput(string message)
|
||||
{
|
||||
// Implement your input processing logic here
|
||||
return Task.FromResult(message);
|
||||
}
|
||||
}
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// FileAgent.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IOAgent.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// WebAPIAgent.cs
|
||||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// App.cs
|
||||
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GrpcAgentWorkerRuntime.cs
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using System.Diagnostics;
|
||||
using System.Reflection;
|
||||
|
@ -83,6 +86,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
|
|||
message.Response.RequestId = request.OriginalRequestId;
|
||||
request.Agent.ReceiveMessage(message);
|
||||
break;
|
||||
|
||||
case Message.MessageOneofCase.RegisterAgentTypeResponse:
|
||||
if (!message.RegisterAgentTypeResponse.Success)
|
||||
{
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// HostBuilderExtensions.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Reflection;
|
||||
|
@ -71,7 +74,52 @@ public static class HostBuilderExtensions
|
|||
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
|
||||
.Select(i => (GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "")).ToHashSet()))
|
||||
.ToDictionary(item => item.t, item => item.Item2);
|
||||
// if the assembly contains any interfaces of type IHandler, then add all the methods of the interface to the eventsMap
|
||||
var handlersMap = AppDomain.CurrentDomain.GetAssemblies()
|
||||
.SelectMany(assembly => assembly.GetTypes())
|
||||
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
|
||||
.Select(t => (t, t.GetMethods()
|
||||
.Where(m => m.Name == "Handle")
|
||||
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? "")).ToHashSet()))
|
||||
.ToDictionary(item => item.t, item => item.Item2);
|
||||
// get interfaces implemented by the agent and get the methods of the interface if they are named Handle
|
||||
var ifaceHandlersMap = AppDomain.CurrentDomain.GetAssemblies()
|
||||
.SelectMany(assembly => assembly.GetTypes())
|
||||
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase)) && !type.IsAbstract)
|
||||
.Select(t => t.GetInterfaces()
|
||||
.Select(i => (t, i, i.GetMethods()
|
||||
.Where(m => m.Name == "Handle")
|
||||
.Select(m => (GetMessageDescriptor(m.GetParameters().First().ParameterType)?.FullName ?? ""))
|
||||
//to dictionary of type t and paramter type of the method
|
||||
.ToDictionary(m => m, m => m).Keys.ToHashSet())).ToList());
|
||||
// for each item in ifaceHandlersMap, add the handlers to eventsMap with item as the key
|
||||
foreach (var item in ifaceHandlersMap)
|
||||
{
|
||||
foreach (var iface in item)
|
||||
{
|
||||
if (eventsMap.TryGetValue(iface.Item2, out var events))
|
||||
{
|
||||
events.UnionWith(iface.Item3);
|
||||
}
|
||||
else
|
||||
{
|
||||
eventsMap[iface.Item2] = iface.Item3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge the handlersMap into the eventsMap
|
||||
foreach (var item in handlersMap)
|
||||
{
|
||||
if (eventsMap.TryGetValue(item.Key, out var events))
|
||||
{
|
||||
events.UnionWith(item.Value);
|
||||
}
|
||||
else
|
||||
{
|
||||
eventsMap[item.Key] = item.Value;
|
||||
}
|
||||
}
|
||||
return new EventTypes(typeRegistry, types, eventsMap);
|
||||
});
|
||||
return new AgentApplicationBuilder(builder);
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents
|
||||
{
|
||||
public interface IAgentBase
|
||||
{
|
||||
// Properties
|
||||
AgentId AgentId { get; }
|
||||
IAgentContext Context { get; }
|
||||
|
||||
// Methods
|
||||
Task CallHandler(CloudEvent item);
|
||||
Task<RpcResponse> HandleRequest(RpcRequest request);
|
||||
void ReceiveMessage(Message message);
|
||||
Task Store(AgentState state);
|
||||
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
}
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
using System.Diagnostics;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public interface IAgentContext
|
||||
{
|
||||
AgentId AgentId { get; }
|
||||
AgentBase? AgentInstance { get; set; }
|
||||
DistributedContextPropagator DistributedContextPropagator { get; }
|
||||
ILogger Logger { get; }
|
||||
ValueTask Store(AgentState value);
|
||||
ValueTask<AgentState> Read(AgentId agentId);
|
||||
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
|
||||
ValueTask SendRequestAsync(AgentBase agent, RpcRequest request);
|
||||
ValueTask PublishEventAsync(CloudEvent @event);
|
||||
}
|
|
@ -12,7 +12,7 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="../Abstractions/Microsoft.AutoGen.Abstractions.csproj" />
|
||||
<ProjectReference Include="..\Abstractions\Microsoft.AutoGen.Abstractions.csproj" />
|
||||
<ProjectReference Include="..\Runtime\Microsoft.AutoGen.Runtime.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
|
|
|
@ -1,33 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AIModelClientHostingExtensions.cs
|
||||
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Extensions.Hosting
|
||||
{
|
||||
public static class AIModelClient
|
||||
{
|
||||
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
|
||||
{
|
||||
var pipeline = (ChatClientBuilder pipeline) => pipeline
|
||||
.UseLogging()
|
||||
.UseFunctionInvocation()
|
||||
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
|
||||
namespace Microsoft.Extensions.Hosting;
|
||||
|
||||
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
|
||||
{
|
||||
builder.AddOllamaChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
|
||||
{
|
||||
builder.AddOpenAIChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
|
||||
{
|
||||
builder.AddAzureChatClient(serviceName, pipeline);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
|
||||
}
|
||||
return builder;
|
||||
public static class AIModelClient
|
||||
{
|
||||
public static IHostApplicationBuilder AddChatCompletionService(this IHostApplicationBuilder builder, string serviceName)
|
||||
{
|
||||
var pipeline = (ChatClientBuilder pipeline) => pipeline
|
||||
.UseLogging()
|
||||
.UseFunctionInvocation()
|
||||
.UseOpenTelemetry(configure: c => c.EnableSensitiveData = true);
|
||||
|
||||
if (builder.Configuration[$"{serviceName}:ModelType"] == "ollama")
|
||||
{
|
||||
builder.AddOllamaChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "openai" || builder.Configuration[$"{serviceName}:ModelType"] == "azureopenai")
|
||||
{
|
||||
builder.AddOpenAIChatClient(serviceName, pipeline);
|
||||
}
|
||||
else if (builder.Configuration[$"{serviceName}:ModelType"] == "azureaiinference")
|
||||
{
|
||||
builder.AddAzureChatClient(serviceName, pipeline);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Did not find a valid model implementation for the given service name ${serviceName}, valid supported implemenation types are ollama, openai, azureopenai, azureaiinference");
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AIClientOptions.cs
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
|
||||
namespace Microsoft.Extensions.Hosting;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ServiceCollectionChatCompletionExtensions.cs
|
||||
|
||||
using System.ClientModel;
|
||||
using System.Data.Common;
|
||||
using Azure;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// CloudEventExtensions.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// QdrantOptions.cs
|
||||
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
|
||||
namespace Microsoft.AutoGen.Extensions.SemanticKernel;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// SemanticKernelHostingExtensions.cs
|
||||
|
||||
using System.Text.Json;
|
||||
using Azure.AI.OpenAI;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentWorkerHostingExtensions.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentWorkerRegistryGrain.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Runtime;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Host.cs
|
||||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentWorkerRegistryGrain.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Runtime;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IWorkerAgentGrain.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Runtime;
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IWorkerGateway.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// OrleansRuntimeHostingExtenions.cs
|
||||
|
||||
using System.Configuration;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// WorkerAgentGrain.cs
|
||||
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
namespace Microsoft.AutoGen.Runtime;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// WorkerGateway.cs
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// WorkerGatewayService.cs
|
||||
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// WorkerProcessConnection.cs
|
||||
|
||||
using System.Threading.Channels;
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Extensions.cs
|
||||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Diagnostics.HealthChecks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ChatRequestMessageTests.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
|
|
@ -14,206 +14,205 @@ using FluentAssertions;
|
|||
using OpenAI;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.OpenAI.Tests
|
||||
namespace AutoGen.OpenAI.Tests;
|
||||
|
||||
public partial class MathClassTest
|
||||
{
|
||||
public partial class MathClassTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
}
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
_output = output;
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
catch (Exception)
|
||||
{
|
||||
try
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
_output.WriteLine(message.FormatMessage());
|
||||
}
|
||||
|
||||
throw;
|
||||
_output.WriteLine(message.FormatMessage());
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
Question {question_index}:
|
||||
{question}
|
||||
|
||||
Student, please answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
The answer is {answer}
|
||||
teacher please check answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
{message}
|
||||
please update progress";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
{GroupChatExtension.TERMINATE}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
the number of resolved question is {correctAnswerCount}
|
||||
teacher, please create the next math question";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new AzureOpenAIClient(new Uri(endPoint), new ApiKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
chatClient: openaiClient.GetChatClient(deployName),
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
|
||||
var teacher = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
var teacher = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
You create math question and ask student to answer it.
|
||||
Then you check if the answer is correct.
|
||||
If the answer is wrong, you ask student to fix it")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return teacher;
|
||||
}
|
||||
return teacher;
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Student",
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
chatClient: client.GetChatClient(model),
|
||||
name: "Student",
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return student;
|
||||
}
|
||||
return student;
|
||||
}
|
||||
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,214 +13,213 @@ using Azure.AI.OpenAI;
|
|||
using FluentAssertions;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.OpenAI.V1.Tests
|
||||
namespace AutoGen.OpenAI.V1.Tests;
|
||||
|
||||
public partial class MathClassTest
|
||||
{
|
||||
public partial class MathClassTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
}
|
||||
|
||||
// as of 2024-05-20, aoai return 500 error when round > 1
|
||||
// I'm pretty sure that round > 5 was supported before
|
||||
// So this is probably some wield regression on aoai side
|
||||
// I'll keep this test case here for now, plus setting round to 1
|
||||
// so the test can still pass.
|
||||
// In the future, we should rewind this test case to round > 1 (previously was 5)
|
||||
private int round = 1;
|
||||
public MathClassTest(ITestOutputHelper output)
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
{
|
||||
try
|
||||
{
|
||||
_output = output;
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
|
||||
private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOptions? option, IAgent agent, CancellationToken ct)
|
||||
catch (Exception)
|
||||
{
|
||||
try
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = agent.GenerateReplyAsync(messages, option, ct).Result;
|
||||
|
||||
_output.WriteLine(reply.FormatMessage());
|
||||
return Task.FromResult(reply);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
_output.WriteLine("Request failed");
|
||||
_output.WriteLine($"agent name: {agent.Name}");
|
||||
foreach (var message in messages)
|
||||
if (message is IMessage<object> envelope)
|
||||
{
|
||||
if (message is IMessage<object> envelope)
|
||||
{
|
||||
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
|
||||
_output.WriteLine(json);
|
||||
}
|
||||
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
|
||||
_output.WriteLine(json);
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> CreateMathQuestion(string question, int question_index)
|
||||
{
|
||||
return $@"[MATH_QUESTION]
|
||||
Question {question_index}:
|
||||
{question}
|
||||
|
||||
Student, please answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerQuestion(string answer)
|
||||
{
|
||||
return $@"[MATH_ANSWER]
|
||||
The answer is {answer}
|
||||
teacher please check answer";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
[FunctionAttribute]
|
||||
public async Task<string> AnswerIsCorrect(string message)
|
||||
{
|
||||
return $@"[ANSWER_IS_CORRECT]
|
||||
{message}
|
||||
please update progress";
|
||||
}
|
||||
}
|
||||
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
[FunctionAttribute]
|
||||
public async Task<string> UpdateProgress(int correctAnswerCount)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
if (correctAnswerCount >= this.round)
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
{GroupChatExtension.TERMINATE}";
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
}
|
||||
else
|
||||
{
|
||||
return $@"[UPDATE_PROGRESS]
|
||||
the number of resolved question is {correctAnswerCount}
|
||||
teacher, please create the next math question";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task OpenAIAgentMathChatTestAsync()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(key));
|
||||
var teacher = await CreateTeacherAgentAsync(openaiClient, deployName);
|
||||
var student = await CreateStudentAssistantAgentAsync(openaiClient, deployName);
|
||||
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
var adminFunctionMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.UpdateProgressFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.UpdateProgressFunctionContract.Name, this.UpdateProgressWrapper },
|
||||
});
|
||||
var admin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "Admin",
|
||||
systemMessage: $@"You are admin. You update progress after each question is answered.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(adminFunctionMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
var groupAdmin = new OpenAIChatAgent(
|
||||
openAIClient: openaiClient,
|
||||
modelName: deployName,
|
||||
name: "GroupAdmin",
|
||||
systemMessage: "You are group admin. You manage the group chat.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(Print);
|
||||
await RunMathChatAsync(teacher, student, admin, groupAdmin);
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
private async Task<IAgent> CreateTeacherAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.CreateMathQuestionFunctionContract, this.AnswerIsCorrectFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.CreateMathQuestionFunctionContract.Name!, this.CreateMathQuestionWrapper },
|
||||
{ this.AnswerIsCorrectFunctionContract.Name!, this.AnswerIsCorrectWrapper },
|
||||
});
|
||||
|
||||
var teacher = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
var teacher = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Teacher",
|
||||
systemMessage: @"You are a preschool math teacher.
|
||||
You create math question and ask student to answer it.
|
||||
Then you check if the answer is correct.
|
||||
If the answer is wrong, you ask student to fix it",
|
||||
modelName: model)
|
||||
.RegisterMiddleware(Print)
|
||||
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
modelName: model)
|
||||
.RegisterMiddleware(Print)
|
||||
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
|
||||
return teacher;
|
||||
}
|
||||
return teacher;
|
||||
}
|
||||
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Student",
|
||||
modelName: model,
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
private async Task<IAgent> CreateStudentAssistantAgentAsync(OpenAIClient client, string model)
|
||||
{
|
||||
var functionCallMiddleware = new FunctionCallMiddleware(
|
||||
functions: [this.AnswerQuestionFunctionContract],
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ this.AnswerQuestionFunctionContract.Name!, this.AnswerQuestionWrapper },
|
||||
});
|
||||
var student = new OpenAIChatAgent(
|
||||
openAIClient: client,
|
||||
name: "Student",
|
||||
modelName: model,
|
||||
systemMessage: @"You are a student. You answer math question from teacher.")
|
||||
.RegisterMessageConnector()
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware)
|
||||
.RegisterMiddleware(Print);
|
||||
|
||||
return student;
|
||||
}
|
||||
return student;
|
||||
}
|
||||
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin, IAgent groupAdmin)
|
||||
{
|
||||
var teacher2Student = Transition.Create(teacher, student);
|
||||
var student2Teacher = Transition.Create(student, teacher);
|
||||
var teacher2Admin = Transition.Create(teacher, admin);
|
||||
var admin2Teacher = Transition.Create(admin, teacher);
|
||||
var workflow = new Graph(
|
||||
[
|
||||
teacher2Student,
|
||||
student2Teacher,
|
||||
teacher2Admin,
|
||||
admin2Teacher,
|
||||
]);
|
||||
var group = new GroupChat(
|
||||
workflow: workflow,
|
||||
members: [
|
||||
admin,
|
||||
teacher,
|
||||
student,
|
||||
],
|
||||
admin: groupAdmin);
|
||||
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
var groupChatManager = new GroupChatManager(group);
|
||||
var chatHistory = await admin.InitiateChatAsync(groupChatManager, "teacher, create question", maxRound: 50);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true)
|
||||
.Count()
|
||||
.Should().BeGreaterThanOrEqualTo(this.round);
|
||||
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
// check if there's terminate chat message from admin
|
||||
chatHistory.Where(msg => msg.From == admin.Name && msg.IsGroupChatTerminateMessage())
|
||||
.Count()
|
||||
.Should().Be(1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// KernelFunctionMiddlewareTests.cs
|
||||
|
||||
using System.ClientModel;
|
||||
using AutoGen.Core;
|
||||
|
|
|
@ -1,87 +1,87 @@
|
|||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// FunctionCallTemplateEncodingTests.cs
|
||||
|
||||
using AutoGen.SourceGenerator.Template; // Needed for FunctionCallTemplate
|
||||
using Xunit; // Needed for Fact and Assert
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public class FunctionCallTemplateEncodingTests
|
||||
{
|
||||
public class FunctionCallTemplateEncodingTests
|
||||
[Fact]
|
||||
public void FunctionDescription_Should_Encode_DoubleQuotes()
|
||||
{
|
||||
[Fact]
|
||||
public void FunctionDescription_Should_Encode_DoubleQuotes()
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
new SourceGeneratorFunctionContract
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
Name = "TestFunction",
|
||||
Description = "This is a \"test\" function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
{
|
||||
Name = "TestFunction",
|
||||
Description = "This is a \"test\" function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParameterDescription_Should_Encode_DoubleQuotes()
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"test\"\" function\"", result);
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParameterDescription_Should_Encode_DoubleQuotes()
|
||||
{
|
||||
// Arrange
|
||||
var functionContracts = new List<SourceGeneratorFunctionContract>
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
{
|
||||
new SourceGeneratorFunctionContract
|
||||
Name = "TestFunction",
|
||||
Description = "This is a test function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
{
|
||||
Name = "TestFunction",
|
||||
Description = "This is a test function",
|
||||
Parameters = new SourceGeneratorParameterContract[]
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
new SourceGeneratorParameterContract
|
||||
{
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
Name = "param1",
|
||||
Description = "This is a \"parameter\" description",
|
||||
Type = "string",
|
||||
IsOptional = false
|
||||
}
|
||||
},
|
||||
ReturnType = "void"
|
||||
}
|
||||
};
|
||||
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
var template = new FunctionCallTemplate
|
||||
{
|
||||
NameSpace = "TestNamespace",
|
||||
ClassName = "TestClass",
|
||||
FunctionContracts = functionContracts
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
// Act
|
||||
var result = template.TransformText();
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
// Assert
|
||||
Assert.Contains("Description = @\"This is a \"\"parameter\"\" description\"", result);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,122 +10,121 @@ using FluentAssertions;
|
|||
using OpenAI.Chat;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public class FunctionExample
|
||||
{
|
||||
public class FunctionExample
|
||||
private readonly FunctionExamples functionExamples = new FunctionExamples();
|
||||
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
|
||||
{
|
||||
private readonly FunctionExamples functionExamples = new FunctionExamples();
|
||||
private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions
|
||||
WriteIndented = true,
|
||||
};
|
||||
|
||||
[Fact]
|
||||
public void Add_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
WriteIndented = true,
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
[Fact]
|
||||
public void Add_Test()
|
||||
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
|
||||
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Sum_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
args = new double[] { 1, 2, 3 },
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
|
||||
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DictionaryToString_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
xargs = new Dictionary<string, string>
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
{ "a", "1" },
|
||||
{ "b", "2" },
|
||||
},
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.AddWrapper, args, 3);
|
||||
this.VerifyFunctionDefinition(functionExamples.AddFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
|
||||
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Sum_Test()
|
||||
[Fact]
|
||||
public async Task TopLevelFunctionExampleAddTestAsync()
|
||||
{
|
||||
var example = new TopLevelStatementFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
args = new double[] { 1, 2, 3 },
|
||||
};
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
this.VerifyFunction(functionExamples.SumWrapper, args, 6.0);
|
||||
this.VerifyFunctionDefinition(functionExamples.SumFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DictionaryToString_Test()
|
||||
[Fact]
|
||||
public async Task FilescopeFunctionExampleAddTestAsync()
|
||||
{
|
||||
var example = new FilescopeNamespaceFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
xargs = new Dictionary<string, string>
|
||||
{
|
||||
{ "a", "1" },
|
||||
{ "b", "2" },
|
||||
},
|
||||
};
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(functionExamples.DictionaryToStringAsyncWrapper, args, JsonSerializer.Serialize(args.xargs, jsonSerializerOptions));
|
||||
this.VerifyFunctionDefinition(functionExamples.DictionaryToStringAsyncFunctionContract.ToChatTool());
|
||||
}
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TopLevelFunctionExampleAddTestAsync()
|
||||
[Fact]
|
||||
public void Query_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
var example = new TopLevelStatementFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
query = "hello",
|
||||
k = 3,
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
|
||||
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task FilescopeFunctionExampleAddTestAsync()
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
private void VerifyFunctionDefinition(ChatTool function)
|
||||
{
|
||||
var func = new
|
||||
{
|
||||
var example = new FilescopeNamespaceFunctionExample();
|
||||
var args = new
|
||||
{
|
||||
a = 1,
|
||||
b = 2,
|
||||
};
|
||||
name = function.FunctionName,
|
||||
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
|
||||
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
|
||||
};
|
||||
|
||||
await this.VerifyAsyncFunction(example.AddWrapper, args, "3");
|
||||
}
|
||||
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Query_Test()
|
||||
{
|
||||
var args = new
|
||||
{
|
||||
query = "hello",
|
||||
k = 3,
|
||||
};
|
||||
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
|
||||
this.VerifyFunction(functionExamples.QueryWrapper, args, new[] { "hello", "hello", "hello" });
|
||||
this.VerifyFunctionDefinition(functionExamples.QueryFunctionContract.ToChatTool());
|
||||
}
|
||||
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
private void VerifyFunctionDefinition(ChatTool function)
|
||||
{
|
||||
var func = new
|
||||
{
|
||||
name = function.FunctionName,
|
||||
description = function.FunctionDescription.Replace(Environment.NewLine, ","),
|
||||
parameters = function.FunctionParameters.ToObjectFromJson<object>(options: jsonSerializerOptions),
|
||||
};
|
||||
|
||||
Approvals.Verify(JsonSerializer.Serialize(func, jsonSerializerOptions));
|
||||
}
|
||||
|
||||
private void VerifyFunction<T, U>(Func<string, T> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
|
||||
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = await func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
private async Task VerifyAsyncFunction<T, U>(Func<string, Task<T>> func, U args, T expected)
|
||||
{
|
||||
var str = JsonSerializer.Serialize(args, jsonSerializerOptions);
|
||||
var res = await func(str);
|
||||
res.Should().BeEquivalentTo(expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,67 +4,66 @@
|
|||
using System.Text.Json;
|
||||
using AutoGen.Core;
|
||||
|
||||
namespace AutoGen.SourceGenerator.Tests
|
||||
namespace AutoGen.SourceGenerator.Tests;
|
||||
|
||||
public partial class FunctionExamples
|
||||
{
|
||||
public partial class FunctionExamples
|
||||
/// <summary>
|
||||
/// Add function
|
||||
/// </summary>
|
||||
/// <param name="a">a</param>
|
||||
/// <param name="b">b</param>
|
||||
[FunctionAttribute]
|
||||
public int Add(int a, int b)
|
||||
{
|
||||
/// <summary>
|
||||
/// Add function
|
||||
/// </summary>
|
||||
/// <param name="a">a</param>
|
||||
/// <param name="b">b</param>
|
||||
[FunctionAttribute]
|
||||
public int Add(int a, int b)
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
return a + b;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add two numbers.
|
||||
/// </summary>
|
||||
/// <param name="a">The first number.</param>
|
||||
/// <param name="b">The second number.</param>
|
||||
[Function]
|
||||
public Task<string> AddAsync(int a, int b)
|
||||
{
|
||||
return Task.FromResult($"{a} + {b} = {a + b}");
|
||||
}
|
||||
/// <summary>
|
||||
/// Add two numbers.
|
||||
/// </summary>
|
||||
/// <param name="a">The first number.</param>
|
||||
/// <param name="b">The second number.</param>
|
||||
[Function]
|
||||
public Task<string> AddAsync(int a, int b)
|
||||
{
|
||||
return Task.FromResult($"{a} + {b} = {a + b}");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sum function
|
||||
/// </summary>
|
||||
/// <param name="args">an array of double values</param>
|
||||
[FunctionAttribute]
|
||||
public double Sum(double[] args)
|
||||
{
|
||||
return args.Sum();
|
||||
}
|
||||
/// <summary>
|
||||
/// Sum function
|
||||
/// </summary>
|
||||
/// <param name="args">an array of double values</param>
|
||||
[FunctionAttribute]
|
||||
public double Sum(double[] args)
|
||||
{
|
||||
return args.Sum();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DictionaryToString function
|
||||
/// </summary>
|
||||
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
|
||||
[FunctionAttribute]
|
||||
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
|
||||
/// <summary>
|
||||
/// DictionaryToString function
|
||||
/// </summary>
|
||||
/// <param name="xargs">an object of key-value pairs. key is string, value is string</param>
|
||||
[FunctionAttribute]
|
||||
public Task<string> DictionaryToStringAsync(Dictionary<string, string> xargs)
|
||||
{
|
||||
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
|
||||
{
|
||||
var res = JsonSerializer.Serialize(xargs, new JsonSerializerOptions
|
||||
{
|
||||
WriteIndented = true,
|
||||
});
|
||||
WriteIndented = true,
|
||||
});
|
||||
|
||||
return Task.FromResult(res);
|
||||
}
|
||||
return Task.FromResult(res);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// query function
|
||||
/// </summary>
|
||||
/// <param name="query">query, required</param>
|
||||
/// <param name="k">top k, optional, default value is 3</param>
|
||||
/// <param name="thresold">thresold, optional, default value is 0.5</param>
|
||||
[FunctionAttribute]
|
||||
public string[] Query(string query, int k = 3, float thresold = 0.5f)
|
||||
{
|
||||
return Enumerable.Repeat(query, k).ToArray();
|
||||
}
|
||||
/// <summary>
|
||||
/// query function
|
||||
/// </summary>
|
||||
/// <param name="query">query, required</param>
|
||||
/// <param name="k">top k, optional, default value is 3</param>
|
||||
/// <param name="thresold">thresold, optional, default value is 0.5</param>
|
||||
[FunctionAttribute]
|
||||
public string[] Query(string query, int k = 3, float thresold = 0.5f)
|
||||
{
|
||||
return Enumerable.Repeat(query, k).ToArray();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,73 +7,72 @@ using System.Threading.Tasks;
|
|||
using AutoGen.BasicSample;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
public class BasicSampleTest
|
||||
{
|
||||
public class BasicSampleTest
|
||||
private readonly ITestOutputHelper _output;
|
||||
|
||||
public BasicSampleTest(ITestOutputHelper output)
|
||||
{
|
||||
private readonly ITestOutputHelper _output;
|
||||
_output = output;
|
||||
Console.SetOut(new ConsoleWriter(_output));
|
||||
}
|
||||
|
||||
public BasicSampleTest(ITestOutputHelper output)
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentTestAsync()
|
||||
{
|
||||
await Example01_AssistantAgent.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task TwoAgentMathClassTestAsync()
|
||||
{
|
||||
await Example02_TwoAgent_MathChat.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task AgentFunctionCallTestAsync()
|
||||
{
|
||||
await Example03_Agent_FunctionCall.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("MISTRAL_API_KEY")]
|
||||
public async Task MistralClientAgent_TokenCount()
|
||||
{
|
||||
await Example14_MistralClientAgent_TokenCount.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task DynamicGroupChatCalculateFibonacciAsync()
|
||||
{
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task DalleAndGPT4VTestAsync()
|
||||
{
|
||||
await Example05_Dalle_And_GPT4V.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task GPT4ImageMessage()
|
||||
{
|
||||
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
|
||||
}
|
||||
|
||||
public class ConsoleWriter : StringWriter
|
||||
{
|
||||
private ITestOutputHelper output;
|
||||
public ConsoleWriter(ITestOutputHelper output)
|
||||
{
|
||||
_output = output;
|
||||
Console.SetOut(new ConsoleWriter(_output));
|
||||
this.output = output;
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentTestAsync()
|
||||
public override void WriteLine(string? m)
|
||||
{
|
||||
await Example01_AssistantAgent.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task TwoAgentMathClassTestAsync()
|
||||
{
|
||||
await Example02_TwoAgent_MathChat.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task AgentFunctionCallTestAsync()
|
||||
{
|
||||
await Example03_Agent_FunctionCall.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("MISTRAL_API_KEY")]
|
||||
public async Task MistralClientAgent_TokenCount()
|
||||
{
|
||||
await Example14_MistralClientAgent_TokenCount.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task DynamicGroupChatCalculateFibonacciAsync()
|
||||
{
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
|
||||
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunWorkflowAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task DalleAndGPT4VTestAsync()
|
||||
{
|
||||
await Example05_Dalle_And_GPT4V.RunAsync();
|
||||
}
|
||||
|
||||
[ApiKeyFact("OPENAI_API_KEY")]
|
||||
public async Task GPT4ImageMessage()
|
||||
{
|
||||
await Example15_GPT4V_BinaryDataImageMessage.RunAsync();
|
||||
}
|
||||
|
||||
public class ConsoleWriter : StringWriter
|
||||
{
|
||||
private ITestOutputHelper output;
|
||||
public ConsoleWriter(ITestOutputHelper output)
|
||||
{
|
||||
this.output = output;
|
||||
}
|
||||
|
||||
public override void WriteLine(string? m)
|
||||
{
|
||||
output.WriteLine(m);
|
||||
}
|
||||
output.WriteLine(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,18 +3,17 @@
|
|||
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
{
|
||||
public class GraphTests
|
||||
{
|
||||
[Fact]
|
||||
public void GraphTest()
|
||||
{
|
||||
var graph1 = new Graph();
|
||||
Assert.NotNull(graph1);
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
var graph2 = new Graph(null);
|
||||
Assert.NotNull(graph2);
|
||||
}
|
||||
public class GraphTests
|
||||
{
|
||||
[Fact]
|
||||
public void GraphTest()
|
||||
{
|
||||
var graph1 = new Graph();
|
||||
Assert.NotNull(graph1);
|
||||
|
||||
var graph2 = new Graph(null);
|
||||
Assert.NotNull(graph2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,219 +9,218 @@ using FluentAssertions;
|
|||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
namespace AutoGen.Tests;
|
||||
|
||||
public partial class SingleAgentTest
|
||||
{
|
||||
public partial class SingleAgentTest
|
||||
private ITestOutputHelper _output;
|
||||
public SingleAgentTest(ITestOutputHelper output)
|
||||
{
|
||||
private ITestOutputHelper _output;
|
||||
public SingleAgentTest(ITestOutputHelper output)
|
||||
{
|
||||
_output = output;
|
||||
}
|
||||
_output = output;
|
||||
}
|
||||
|
||||
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
return new AzureOpenAIConfig(endpoint, deployName, key);
|
||||
}
|
||||
private ILLMConfig CreateAzureOpenAIGPT35TurboConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
|
||||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new ArgumentException("AZURE_OPENAI_DEPLOY_NAME is not set");
|
||||
return new AzureOpenAIConfig(endpoint, deployName, key);
|
||||
}
|
||||
|
||||
private ILLMConfig CreateOpenAIGPT4VisionConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
|
||||
return new OpenAIConfig(key, "gpt-4-vision-preview");
|
||||
}
|
||||
private ILLMConfig CreateOpenAIGPT4VisionConfig()
|
||||
{
|
||||
var key = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentException("OPENAI_API_KEY is not set");
|
||||
return new OpenAIConfig(key, "gpt-4-vision-preview");
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
{
|
||||
Temperature = 0,
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
Temperature = 0,
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig);
|
||||
|
||||
await EchoFunctionCallTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AssistantAgentDefaultReplyTestAsync()
|
||||
{
|
||||
var assistantAgent = new AssistantAgent(
|
||||
llmConfig: null,
|
||||
name: "assistant",
|
||||
defaultReply: "hello world");
|
||||
|
||||
var reply = await assistantAgent.SendAsync("hi");
|
||||
|
||||
reply.GetContent().Should().Be("hello world");
|
||||
reply.GetRole().Should().Be(Role.Assistant);
|
||||
reply.From.Should().Be(assistantAgent.Name);
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig,
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ nameof(EchoAsync), this.EchoAsyncWrapper },
|
||||
});
|
||||
config,
|
||||
},
|
||||
};
|
||||
|
||||
await EchoFunctionCallExecutionTestAsync(assistantAgent);
|
||||
}
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig);
|
||||
|
||||
/// <summary>
|
||||
/// echo when asked.
|
||||
/// </summary>
|
||||
/// <param name="message">message to echo</param>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> EchoAsync(string message)
|
||||
await EchoFunctionCallTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AssistantAgentDefaultReplyTestAsync()
|
||||
{
|
||||
var assistantAgent = new AssistantAgent(
|
||||
llmConfig: null,
|
||||
name: "assistant",
|
||||
defaultReply: "hello world");
|
||||
|
||||
var reply = await assistantAgent.SendAsync("hi");
|
||||
|
||||
reply.GetContent().Should().Be("hello world");
|
||||
reply.GetRole().Should().Be(Role.Assistant);
|
||||
reply.From.Should().Be(assistantAgent.Name);
|
||||
}
|
||||
|
||||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
|
||||
public async Task AssistantAgentFunctionCallSelfExecutionTestAsync()
|
||||
{
|
||||
var config = this.CreateAzureOpenAIGPT35TurboConfig();
|
||||
var llmConfig = new ConversableAgentConfig
|
||||
{
|
||||
return $"[ECHO] {message}";
|
||||
}
|
||||
FunctionContracts = new[]
|
||||
{
|
||||
this.EchoAsyncFunctionContract,
|
||||
},
|
||||
ConfigList = new[]
|
||||
{
|
||||
config,
|
||||
},
|
||||
};
|
||||
var assistantAgent = new AssistantAgent(
|
||||
name: "assistant",
|
||||
llmConfig: llmConfig,
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
{
|
||||
{ nameof(EchoAsync), this.EchoAsyncWrapper },
|
||||
});
|
||||
|
||||
/// <summary>
|
||||
/// return the label name with hightest inference cost
|
||||
/// </summary>
|
||||
/// <param name="labelName"></param>
|
||||
/// <returns></returns>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> GetHighestLabel(string labelName, string color)
|
||||
await EchoFunctionCallExecutionTestAsync(assistantAgent);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// echo when asked.
|
||||
/// </summary>
|
||||
/// <param name="message">message to echo</param>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> EchoAsync(string message)
|
||||
{
|
||||
return $"[ECHO] {message}";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// return the label name with hightest inference cost
|
||||
/// </summary>
|
||||
/// <param name="labelName"></param>
|
||||
/// <returns></returns>
|
||||
[FunctionAttribute]
|
||||
public async Task<string> GetHighestLabel(string labelName, string color)
|
||||
{
|
||||
return $"[HIGHEST_LABEL] {labelName} {color}";
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallTestAsync(IAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.GetContent().Should().Be("[ECHO] Hello world");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
return $"[HIGHEST_LABEL] {labelName} {color}";
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallTestAsync(IAgent agent)
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
finalReply = reply;
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionTestAsync(IAgent agent)
|
||||
if (finalReply is ToolCallAggregateMessage aggregateMessage)
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { helloWorld });
|
||||
|
||||
reply.GetContent().Should().Be("[ECHO] Hello world");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
reply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||
var toolCallResultMessage = aggregateMessage.Message2;
|
||||
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
|
||||
toolCallResultMessage.From.Should().Be(agent.Name);
|
||||
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
|
||||
public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent)
|
||||
else
|
||||
{
|
||||
//var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says");
|
||||
var helloWorld = new TextMessage(Role.User, "echo Hello world");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
reply.From.Should().Be(agent.Name);
|
||||
finalReply = reply;
|
||||
}
|
||||
|
||||
if (finalReply is ToolCallAggregateMessage aggregateMessage)
|
||||
{
|
||||
var toolCallResultMessage = aggregateMessage.Message2;
|
||||
toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer);
|
||||
toolCallResultMessage.From.Should().Be(agent.Name);
|
||||
toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpperCaseTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message });
|
||||
|
||||
reply.GetContent().Should().Contain("ABCDE");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
}
|
||||
|
||||
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
|
||||
var answer = "HELLO WORLD";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
if (reply is TextMessageUpdate update)
|
||||
{
|
||||
update.From.Should().Be(agent.Name);
|
||||
|
||||
if (finalReply is null)
|
||||
{
|
||||
finalReply = new TextMessage(update);
|
||||
}
|
||||
else
|
||||
{
|
||||
finalReply.Update(update);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
else if (reply is TextMessage textMessage)
|
||||
{
|
||||
finalReply = textMessage;
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
|
||||
finalReply!.Content.Should().Contain(answer);
|
||||
finalReply!.Role.Should().Be(Role.Assistant);
|
||||
finalReply!.From.Should().Be(agent.Name);
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpperCaseTestAsync(IAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert abcde to upper case.");
|
||||
|
||||
var reply = await agent.SendAsync(chatHistory: new[] { message });
|
||||
|
||||
reply.GetContent().Should().Contain("ABCDE");
|
||||
reply.From.Should().Be(agent.Name);
|
||||
}
|
||||
|
||||
public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent)
|
||||
{
|
||||
var message = new TextMessage(Role.User, "Please convert 'hello world' to upper case");
|
||||
var option = new GenerateReplyOptions
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message }, option);
|
||||
var answer = "HELLO WORLD";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
{
|
||||
if (reply is TextMessageUpdate update)
|
||||
{
|
||||
update.From.Should().Be(agent.Name);
|
||||
|
||||
if (finalReply is null)
|
||||
{
|
||||
finalReply = new TextMessage(update);
|
||||
}
|
||||
else
|
||||
{
|
||||
finalReply.Update(update);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
else if (reply is TextMessage textMessage)
|
||||
{
|
||||
finalReply = textMessage;
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Exception("unexpected message type");
|
||||
}
|
||||
|
||||
finalReply!.Content.Should().Contain(answer);
|
||||
finalReply!.Role.Should().Be(Role.Assistant);
|
||||
finalReply!.From.Should().Be(agent.Name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// TwoAgentTest.cs
|
||||
|
||||
#pragma warning disable xUnit1013
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
|
|
@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
|
|||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
InnerMessage,
|
||||
ResetMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessages,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
@ -207,7 +211,14 @@ class AssistantAgent(BaseChatAgent):
|
|||
)
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
if self._handoffs:
|
||||
return [TextMessage, HandoffMessage, StopMessage]
|
||||
return [TextMessage, StopMessage]
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
if isinstance(msg, ResetMessage):
|
||||
|
@ -215,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
|
|||
else:
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Inner messages.
|
||||
inner_messages: List[InnerMessage] = []
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(
|
||||
|
@ -227,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
|
|||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
||||
# Add the tool call message to the output.
|
||||
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
|
||||
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
|
||||
|
||||
# Detect handoff requests.
|
||||
handoffs: List[Handoff] = []
|
||||
|
@ -242,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
|
|||
if len(handoffs) > 0:
|
||||
if len(handoffs) > 1:
|
||||
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
||||
# Respond with a handoff message.
|
||||
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
|
||||
# Return the output messages to signal the handoff.
|
||||
return Response(
|
||||
chat_message=HandoffMessage(
|
||||
content=handoffs[0].message, target=handoffs[0].target, source=self.name
|
||||
),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
|
@ -255,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
|
|||
# Detect stop request.
|
||||
request_stop = "terminate" in result.content.strip().lower()
|
||||
if request_stop:
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
return Response(
|
||||
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||
)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
return Response(
|
||||
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||
)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue