Squash changes (#2849)

This commit is contained in:
David Luong 2024-06-10 13:32:33 -04:00 committed by GitHub
parent a16b307dc0
commit d578d0dfd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 239 additions and 35 deletions

View File

@ -23,7 +23,8 @@ public sealed class AnthropicClient : IDisposable
private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = { new ContentBaseConverter() }
};
private static readonly JsonSerializerOptions JsonDeserializerOptions = new()

View File

@ -49,9 +49,15 @@ public class ChatMessage
public string Role { get; set; }
[JsonPropertyName("content")]
public string Content { get; set; }
public List<ContentBase> Content { get; set; }
/// <summary>
/// Creates a message whose content list holds a single text item.
/// </summary>
/// <param name="role">Value stored in <see cref="Role"/>.</param>
/// <param name="content">Text wrapped in one <see cref="TextContent"/> entry.</param>
public ChatMessage(string role, string content)
{
    var textPart = new TextContent { Text = content };

    Role = role;
    Content = new List<ContentBase> { textPart };
}
public ChatMessage(string role, List<ContentBase> content)
{
Role = role;
Content = content;

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@ -19,7 +20,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);
var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
return response is IMessage<ChatCompletionResponse> chatMessage
@ -31,7 +32,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
@ -53,60 +54,78 @@ public class AnthropicMessageConnector : IStreamingMiddleware
// Converts one streaming chat-completion chunk into a text update.
// Returns a TextMessageUpdate (assistant role, the delta text, from = agent.Name) when the
// chunk carries a delta with non-empty text; returns null otherwise.
// NOTE(review): the two `delta` declarations below look like the before/after sides of a
// single changed line in this diff view — confirm only one exists in the real file.
private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
Delta? delta = chatMessage.Content.Delta;
var delta = chatMessage.Content.Delta;
// A missing delta or an empty text yields null rather than an empty update.
return delta != null && !string.IsNullOrEmpty(delta.Text)
? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
: null;
}
// Maps each incoming message to the form handed to the agent:
//   - TextMessage       -> ProcessTextMessage
//   - ImageMessage      -> a "user" ChatMessage holding one ImageContent, enveloped with from = agent.Name
//   - MultiModalMessage -> ProcessMultiModalMessageAsync
//   - anything else     -> passed through unchanged
// NOTE(review): this span interleaves the removed synchronous ProcessMessage (SelectMany-based)
// with the added ProcessMessageAsync (foreach-based), as rendered by the diff view — confirm
// against the real file before editing.
private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany<IMessage, IMessage>(m =>
var processedMessages = new List<IMessage>();
foreach (var message in messages)
{
return m switch
var processedMessage = message switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
_ => [m],
// A standalone image is always sent with the "user" role.
ImageMessage imageMessage =>
new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
from: agent.Name),
MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
// Unrecognized message types are forwarded as-is.
_ => message,
};
});
processedMessages.Add(processedMessage);
}
return processedMessages;
}
/// <summary>
/// Turns a completed chat response into an assistant <see cref="TextMessage"/>.
/// </summary>
/// <param name="response">Response whose content must hold exactly one item.</param>
/// <param name="from">Agent whose name is recorded as the sender.</param>
/// <returns>An assistant text message; a null text becomes an empty string.</returns>
/// <exception cref="ArgumentNullException">The response content is null.</exception>
/// <exception cref="NotSupportedException">The response content does not hold exactly one item.</exception>
private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from)
{
    var parts = response.Content ?? throw new ArgumentNullException(nameof(response.Content));

    if (parts.Count != 1)
    {
        throw new NotSupportedException($"{nameof(response.Content)} != 1");
    }

    // The single item is cast directly; a non-text item surfaces as an InvalidCastException.
    var onlyPart = (TextContent)parts[0];
    var text = onlyPart.Text ?? string.Empty;

    return new TextMessage(Role.Assistant, text, from: from.Name);
}
private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
IEnumerable<ChatMessage> messages;
ChatMessage messages;
if (textMessage.From == agent.Name)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.From is null)
{
if (textMessage.Role == Role.User)
{
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}
else if (textMessage.Role == Role.Assistant)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.Role == Role.System)
{
messages = [new ChatMessage(
"system", textMessage.Content)];
messages = new ChatMessage(
"system", textMessage.Content);
}
else
{
@ -116,10 +135,61 @@ public class AnthropicMessageConnector : IStreamingMiddleware
else
{
// if from is not null, then the message is from user
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}
return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: textMessage.From));
return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
}
/// <summary>
/// Flattens a multi-modal message into one "user" <see cref="ChatMessage"/>.
/// </summary>
/// <param name="multiModalMessage">Message whose parts are converted in order.</param>
/// <param name="agent">Agent whose name is passed to the envelope.</param>
/// <returns>An envelope around the combined chat message.</returns>
/// <remarks>
/// Text parts with non-null content and image parts are converted; every other part is skipped.
/// </remarks>
private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
    var parts = new List<ContentBase>();

    foreach (var item in multiModalMessage.Content)
    {
        if (item is TextMessage text && text.GetContent() is not null)
        {
            parts.Add(new TextContent { Text = text.GetContent() });
        }
        else if (item is ImageMessage image)
        {
            parts.Add(new ImageContent() { Source = await ProcessImageSourceAsync(image) });
        }
    }

    return MessageEnvelope.Create(new ChatMessage("user", parts), agent.Name);
}
// One client shared by all image downloads: HttpClient is meant to be reused, and creating
// one per request can exhaust sockets under load.
private static readonly HttpClient ImageDownloadClient = new();

/// <summary>
/// Builds the base64 <see cref="ImageSource"/> for an image message.
/// </summary>
/// <param name="imageMessage">Image supplied either as in-memory data or as a URL.</param>
/// <param name="cancellationToken">Cancels the download when the image is supplied as a URL.</param>
/// <returns>
/// The in-memory data with its own media type when <c>Data</c> is set; otherwise the bytes
/// downloaded from <c>Url</c>.
/// </returns>
/// <exception cref="InvalidOperationException">Neither data nor a URL is provided.</exception>
/// <exception cref="HttpRequestException">The download fails or returns a non-success status.</exception>
private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage, CancellationToken cancellationToken = default)
{
    if (imageMessage.Data != null)
    {
        return new ImageSource
        {
            MediaType = imageMessage.Data.MediaType,
            Data = Convert.ToBase64String(imageMessage.Data.ToArray())
        };
    }

    if (imageMessage.Url is null)
    {
        throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
    }

    var uri = new Uri(imageMessage.Url);

    // Awaited instead of blocking on .Result: no thread is held during the download and
    // failures surface as HttpRequestException rather than a wrapping AggregateException.
    using var response = await ImageDownloadClient.GetAsync(uri, cancellationToken);
    if (!response.IsSuccessStatusCode)
    {
        throw new HttpRequestException($"Failed to download the image from {uri}");
    }

    // Prefer the media type the server declares so PNG/GIF/WebP downloads are not mislabeled.
    // Fall back to the previous fixed value when the header is absent or not an image type.
    var declaredType = response.Content.Headers.ContentType?.MediaType;
    var mediaType = declaredType is not null && declaredType.StartsWith("image/", StringComparison.OrdinalIgnoreCase)
        ? declaredType
        : "image/jpeg";

    return new ImageSource
    {
        MediaType = mediaType,
        Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
    };
}
}

View File

@ -49,7 +49,6 @@ public class MiddlewareStreamingAgent : IMiddlewareStreamAgent
// Forwards the streaming request unchanged to the wrapped agent.
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
    IEnumerable<IMessage> messages,
    GenerateReplyOptions? options = null,
    CancellationToken cancellationToken = default)
    => _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);

View File

@ -1,31 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClientAgentTest.cs
using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Core;
using AutoGen.Tests;
using Xunit.Abstractions;
using FluentAssertions;
namespace AutoGen.Anthropic;
namespace AutoGen.Anthropic.Tests;
public class AnthropicClientAgentTest
{
private readonly ITestOutputHelper _output;
public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output;
/// <summary>
/// Sends a lowercase user message to an agent instructed to upper-case its input and
/// checks that the reply contains the upper-cased text and is attributed to the agent.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentChatCompletionTestAsync()
{
    // Arrange
    var httpClient = new HttpClient();
    var anthropicClient = new AnthropicClient(httpClient, AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var upperCaseAgent = new AnthropicClientAgent(
            anthropicClient,
            name: "AnthropicAgent",
            AnthropicConstants.Claude3Haiku,
            systemMessage: "You are a helpful AI assistant that convert user message to upper case")
        .RegisterMessageConnector();
    var userMessage = new TextMessage(Role.User, "abcdefg");

    // Act
    var answer = await upperCaseAgent.SendAsync(chatHistory: new[] { userMessage });

    // Assert
    answer.GetContent().Should().Contain("ABCDEFG");
    answer.From.Should().Be(upperCaseAgent.Name);
}
// Streams a reply for a pre-built image ChatMessage (base64 PNG) and checks that every
// streamed item is a TextMessageUpdate attributed to the agent.
// NOTE(review): the three `singleAgentTest` lines below look like removed lines from the
// previous test body shown by the diff view — confirm they are absent from the real file.
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku).RegisterMessageConnector();
var singleAgentTest = new SingleAgentTest(_output);
await singleAgentTest.UpperCaseTestAsync(agent);
await singleAgentTest.UpperCaseStreamingTestAsync(agent);
// The image is passed as an already-encoded ChatMessage wrapped in an envelope.
var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png");
var imageMessage = new ChatMessage("user",
[new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = base64Image } }]);
var messages = new IMessage[] { MessageEnvelope.Create(imageMessage) };
// test streaming
foreach (var message in messages)
{
var reply = agent.GenerateStreamingReplyAsync([message]);
await foreach (var streamingMessage in reply)
{
streamingMessage.Should().BeOfType<TextMessageUpdate>();
streamingMessage.As<TextMessageUpdate>().From.Should().Be(agent.Name);
}
}
}
/// <summary>
/// Sends a multi-modal message (a question plus a PNG image) and checks that the agent
/// replies with a non-empty assistant <see cref="TextMessage"/> attributed to itself.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestMultiModalAsync()
{
    var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var agent = new AnthropicClientAgent(
            client,
            name: "AnthropicAgent",
            AnthropicConstants.Claude3Haiku)
        .RegisterMessageConnector();

    // Resolve the fixture from the test output directory rather than the current working
    // directory, so the test does not depend on where the runner was started.
    var image = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", "square.png");
    var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png");
    var imageMessage = new ImageMessage(Role.User, binaryData);
    var textMessage = new TextMessage(Role.User, "What's in this image?");
    var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);

    var reply = await agent.SendAsync(multiModalMessage);

    reply.Should().BeOfType<TextMessage>();
    reply.GetRole().Should().Be(Role.Assistant);
    reply.GetContent().Should().NotBeNullOrEmpty();
    reply.From.Should().Be(agent.Name);
}
/// <summary>
/// Sends a single <see cref="ImageMessage"/> built from a PNG file and checks that the agent
/// replies with a non-empty assistant <see cref="TextMessage"/> attributed to itself.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestImageMessageAsync()
{
    var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var agent = new AnthropicClientAgent(
            client,
            name: "AnthropicAgent",
            AnthropicConstants.Claude3Haiku,
            systemMessage: "You are a helpful AI assistant that is capable of determining what an image is. Tell me a brief description of the image."
        )
        .RegisterMessageConnector();

    // Resolve the fixture from the test output directory rather than the current working
    // directory, so the test does not depend on where the runner was started.
    var image = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", "square.png");
    var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png");
    var imageMessage = new ImageMessage(Role.User, binaryData);

    var reply = await agent.SendAsync(imageMessage);

    reply.Should().BeOfType<TextMessage>();
    reply.GetRole().Should().Be(Role.Assistant);
    reply.GetContent().Should().NotBeNullOrEmpty();
    reply.From.Should().Be(agent.Name);
}
}

View File

@ -7,7 +7,7 @@ using AutoGen.Tests;
using FluentAssertions;
using Xunit;
namespace AutoGen.Anthropic;
namespace AutoGen.Anthropic.Tests;
public class AnthropicClientTests
{
@ -73,6 +73,41 @@ public class AnthropicClientTests
results.First().streamingMessage!.Role.Should().Be("assistant");
}
/// <summary>
/// Sends a non-streaming request containing one base64 PNG image directly through the client
/// and checks that the response has exactly one text content item and reports output tokens.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicClientImageChatCompletionTestAsync()
{
    // Arrange
    var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var request = new ChatCompletionRequest
    {
        Model = AnthropicConstants.Claude3Haiku,
        Stream = false,
        MaxTokens = 100,
        SystemMessage = "You are a LLM that is suppose to describe the content of the image. Give me a description of the provided image.",
    };

    var encodedImage = await AnthropicTestUtils.Base64FromImageAsync("square.png");
    var imagePart = new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = encodedImage } };
    request.Messages = new List<ChatMessage>
    {
        new("user", [imagePart])
    };

    // Act
    var response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);

    // Assert
    Assert.NotNull(response);
    Assert.NotNull(response.Content);
    Assert.NotEmpty(response.Content);
    response.Content.Count.Should().Be(1);

    var firstPart = response.Content.First();
    firstPart.Should().BeOfType<TextContent>();
    var textContent = (TextContent)firstPart;
    Assert.Equal("text", textContent.Type);

    Assert.NotNull(response.Usage);
    response.Usage.OutputTokens.Should().BeGreaterThan(0);
}
private sealed class Person
{
[JsonPropertyName("name")]

View File

@ -1,10 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicTestUtils.cs
namespace AutoGen.Anthropic;
namespace AutoGen.Anthropic.Tests;
/// <summary>
/// Shared helpers for the Anthropic tests.
/// </summary>
public static class AnthropicTestUtils
{
    /// <summary>
    /// The value of the ANTHROPIC_API_KEY environment variable.
    /// </summary>
    /// <exception cref="InvalidOperationException">The environment variable is not set.</exception>
    public static string ApiKey => Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
                                   // A specific exception type instead of the base Exception;
                                   // existing catch (Exception) handlers still catch it.
                                   throw new InvalidOperationException("Please set ANTHROPIC_API_KEY environment variable.");

    /// <summary>
    /// Reads a file from the "images" folder under the application base directory and
    /// returns its bytes encoded as base64.
    /// </summary>
    /// <param name="imageName">File name inside the "images" folder.</param>
    /// <returns>The base64-encoded file content.</returns>
    public static async Task<string> Base64FromImageAsync(string imageName)
    {
        var imagePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName);
        var bytes = await File.ReadAllBytesAsync(imagePath);
        return Convert.ToBase64String(bytes);
    }
}

View File

@ -13,4 +13,10 @@
<ProjectReference Include="..\..\src\AutoGen.Anthropic\AutoGen.Anthropic.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
</ItemGroup>
<ItemGroup>
<None Update="images\square.png">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

View File

@ -0,0 +1 @@
square.png filter=lfs diff=lfs merge=lfs -text

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8341030e5b93aab2c55dcd40ffa26ced8e42cc15736a8348176ffd155ad2d937
size 8167