This commit is contained in:
David Luong 2024-09-25 08:29:02 -07:00 committed by GitHub
commit 5206f63c64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 428 additions and 21 deletions

View File

@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Create_Ollama_Agent_With_Tool.cs
using AutoGen.Core;
using AutoGen.Ollama.Extension;
using FluentAssertions;
namespace AutoGen.Ollama.Sample;
#region WeatherFunction
public partial class WeatherFunction
{
/// <summary>
/// Gets the weather based on the location and the unit
/// </summary>
/// <param name="location"></param>
/// <param name="unit"></param>
/// <returns></returns>
[Function]
public async Task<string> GetWeather(string location, string unit)
{
// dummy implementation
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
}
}
#endregion
public class Create_Ollama_Agent_With_Tool
{
public static async Task RunAsync()
{
#region define_tool
var tool = new Tool()
{
Function = new Function
{
Name = "get_current_weather",
Description = "Get the current weather for a location",
Parameters = new Parameters
{
Properties = new Dictionary<string, Properties>
{
{
"location",
new Properties
{
Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA"
}
},
{
"format", new Properties
{
Type = "string",
Description =
"The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
Enum = new List<string> {"celsius", "fahrenheit"}
}
}
},
Required = new List<string> { "location", "format" }
}
}
};
var weatherFunction = new WeatherFunction();
var functionMiddleware = new FunctionCallMiddleware(
functions: [
weatherFunction.GetWeatherFunctionContract,
],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
});
#endregion
#region create_ollama_agent_llama3.1
var agent = new OllamaAgent(
new HttpClient { BaseAddress = new Uri("http://localhost:11434") },
"MyAgent",
"llama3.1",
tools: [tool]);
#endregion
// TODO cannot stream
#region register_middleware
var agentWithConnector = agent
.RegisterMessageConnector()
.RegisterPrintMessage()
.RegisterStreamingMiddleware(functionMiddleware);
#endregion register_middleware
#region single_turn
var question = new TextMessage(Role.Assistant,
"What is the weather like in San Francisco?",
from: "user");
var functionCallReply = await agentWithConnector.SendAsync(question);
#endregion
#region Single_turn_verify_reply
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
#endregion Single_turn_verify_reply
#region Multi_turn
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
#endregion Multi_turn
#region Multi_turn_verify_reply
finalReply.Should().BeOfType<TextMessage>();
#endregion Multi_turn_verify_reply
}
}

View File

@ -3,4 +3,4 @@
using AutoGen.Ollama.Sample;
await Chat_With_LLaVA.RunAsync();
await Create_Ollama_Agent_With_Tool.RunAsync();

View File

@ -24,16 +24,23 @@ public class OllamaAgent : IStreamingAgent
private readonly string _modelName;
private readonly string _systemMessage;
private readonly OllamaReplyOptions? _replyOptions;
private readonly Tool[]? _tools;
public OllamaAgent(HttpClient httpClient, string name, string modelName,
string systemMessage = "You are a helpful AI assistant",
OllamaReplyOptions? replyOptions = null)
OllamaReplyOptions? replyOptions = null, Tool[]? tools = null)
{
Name = name;
_httpClient = httpClient;
_modelName = modelName;
_systemMessage = systemMessage;
_replyOptions = replyOptions;
_tools = tools;
if (_httpClient.BaseAddress == null)
{
throw new InvalidOperationException($"Please add the base address to httpClient");
}
}
public async Task<IMessage> GenerateReplyAsync(
@ -97,7 +104,8 @@ public class OllamaAgent : IStreamingAgent
var request = new ChatRequest
{
Model = _modelName,
Messages = await BuildChatHistory(messages)
Messages = await BuildChatHistory(messages),
Tools = _tools
};
if (options is OllamaReplyOptions replyOptions)

View File

@ -50,4 +50,11 @@ public class ChatRequest
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }
/// <summary>
/// Tools for the model to use. Not all models currently support tools.
/// Requires stream to be set to false
/// </summary>
[JsonPropertyName("tools")]
public IEnumerable<Tool>? Tools { get; set; }
}

View File

@ -12,7 +12,7 @@ public class Message
{
}
public Message(string role, string value)
public Message(string role, string? value = null)
{
Role = role;
Value = value;
@ -27,11 +27,34 @@ public class Message
/// the content of the message
/// </summary>
[JsonPropertyName("content")]
public string Value { get; set; } = string.Empty;
public string? Value { get; set; }
/// <summary>
/// (optional): a list of images to include in the message (for multimodal models such as llava)
/// </summary>
[JsonPropertyName("images")]
public IList<string>? Images { get; set; }
/// <summary>
/// A list of tools the model wants to use. Not all models currently support tools.
/// Tool call is not supported while streaming.
/// </summary>
[JsonPropertyName("tool_calls")]
public IEnumerable<ToolCall>? ToolCalls { get; set; }
public class ToolCall
{
[JsonPropertyName("function")]
public Function? Function { get; set; }
}
public class Function
{
[JsonPropertyName("name")]
public string? Name { get; set; }
[JsonPropertyName("arguments")]
public Dictionary<string, string>? Arguments { get; set; }
}
}

View File

@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Tools.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace AutoGen.Ollama;
public class Tool
{
[JsonPropertyName("type")]
public string? Type { get; set; } = "function";
[JsonPropertyName("function")]
public Function? Function { get; set; }
}
public class Function
{
[JsonPropertyName("name")]
public string? Name { get; set; }
[JsonPropertyName("description")]
public string? Description { get; set; }
[JsonPropertyName("parameters")]
public Parameters? Parameters { get; set; }
}
public class Parameters
{
[JsonPropertyName("type")]
public string? Type { get; set; } = "object";
[JsonPropertyName("properties")]
public Dictionary<string, Properties>? Properties { get; set; }
[JsonPropertyName("required")]
public IEnumerable<string>? Required { get; set; }
}
public class Properties
{
[JsonPropertyName("type")]
public string? Type { get; set; }
[JsonPropertyName("description")]
public string? Description { get; set; }
[JsonPropertyName("enum")]
public IEnumerable<string>? Enum { get; set; }
}

View File

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
@ -24,6 +25,7 @@ public class OllamaMessageConnector : IStreamingMiddleware
return reply switch
{
IMessage<ChatResponse> { Content.Message.ToolCalls: not null } messageEnvelope when messageEnvelope.Content.Message.ToolCalls.Any() => ProcessToolCalls(messageEnvelope, agent),
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From),
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"),
_ => reply
@ -73,20 +75,21 @@ public class OllamaMessageConnector : IStreamingMiddleware
{
return messages.SelectMany(m =>
{
if (m is IMessage<Message> messageEnvelope)
if (m is IMessage<Message>)
{
return [m];
}
else
return m switch
{
return m switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
_ => [m],
};
}
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
_ => [m],
};
});
}
@ -183,4 +186,64 @@ public class OllamaMessageConnector : IStreamingMiddleware
return [MessageEnvelope.Create(message, agent.Name)];
}
}
private IMessage ProcessToolCalls(IMessage<ChatResponse> messageEnvelope, IAgent agent)
{
var toolCalls = new List<ToolCall>();
foreach (var messageToolCall in messageEnvelope.Content.Message?.ToolCalls!)
{
toolCalls.Add(new ToolCall(
messageToolCall.Function?.Name ?? string.Empty,
JsonSerializer.Serialize(messageToolCall.Function?.Arguments)));
}
return new ToolCallMessage(toolCalls, agent.Name) { Content = messageEnvelope.Content.Message.Value };
}
private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
{
var chatMessage = new Message(toolCallMessage.From ?? string.Empty, toolCallMessage.GetContent())
{
ToolCalls = toolCallMessage.ToolCalls.Select(t => new Message.ToolCall
{
Function = new Message.Function
{
Name = t.FunctionName,
Arguments = JsonSerializer.Deserialize<Dictionary<string, string>>(t.FunctionArguments),
},
}),
};
return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
}
private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
{
foreach (var toolCall in toolCallResultMessage.ToolCalls)
{
if (!string.IsNullOrEmpty(toolCall.Result))
{
return [MessageEnvelope.Create(new Message("tool", toolCall.Result), toolCallResultMessage.From)];
}
}
throw new InvalidOperationException("Expected to have at least one tool call result");
}
private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage, IAgent agent)
{
if (toolCallAggregateMessage.From is { } from && from != agent.Name)
{
var contents = toolCallAggregateMessage.Message2.ToolCalls.Select(t => t.Result);
var messages =
contents.Select(c => new Message("assistant", c ?? throw new ArgumentNullException(nameof(c))));
return messages.Select(m => new MessageEnvelope<Message>(m, from: from));
}
var toolCallMessage = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent);
var toolCallResult = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2);
return toolCallMessage.Concat(toolCallResult);
}
}

View File

@ -170,7 +170,7 @@ public class AnthropicClientAgentTest
)
.RegisterMessageConnector();
var weatherFunctionArgumets = """
var weatherFunctionArguments = """
{
"city": "Philadelphia",
"date": "6/14/2024"
@ -178,8 +178,8 @@ public class AnthropicClientAgentTest
""";
var function = new AnthropicTestFunctionCalls();
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets);
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets)
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
{
ToolCallId = "get_weather",
Result = functionCallResult,

View File

@ -11,6 +11,7 @@
<ItemGroup>
<ProjectReference Include="..\..\src\AutoGen.Ollama\AutoGen.Ollama.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
<ItemGroup>

View File

@ -6,6 +6,7 @@ using AutoGen.Core;
using AutoGen.Ollama.Extension;
using AutoGen.Tests;
using FluentAssertions;
using Xunit;
namespace AutoGen.Ollama.Tests;
@ -49,7 +50,7 @@ public class OllamaAgentTests
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
result.From.Should().Be(ollamaAgent.Name);
string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value;
string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value ?? string.Empty;
bool isValidJson = IsValidJsonMessage(jsonContent);
isValidJson.Should().BeTrue();
}
@ -195,6 +196,66 @@ public class OllamaAgentTests
update.TotalDuration.Should().BeGreaterThan(0);
}
[Fact]
public async Task GenerateReplyAsync_ReturnsValidToolMessage()
{
var host = @" http://localhost:11434";
var modelName = "llama3.1";
var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]);
var message = new Message("user", "What is the weather today?");
var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) };
var result = await ollamaAgent.GenerateReplyAsync(messages);
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
var chatResponse = ((MessageEnvelope<ChatResponse>)result).Content;
chatResponse.Message.Should().BeOfType<Message>();
chatResponse.Message.Should().NotBeNull();
var toolCall = chatResponse.Message!.ToolCalls!.First();
toolCall.Function.Should().NotBeNull();
toolCall.Function!.Name.Should().Be("get_current_weather");
toolCall.Function!.Arguments.Should().ContainKey("location");
toolCall.Function!.Arguments!["location"].Should().Be("San Francisco, CA");
toolCall.Function!.Arguments!.Should().ContainKey("format");
toolCall.Function!.Arguments!["format"].Should().BeOneOf("celsius", "fahrenheit");
}
[Fact]
public async Task OllamaAgentFunctionCallMessageTest()
{
var host = @" http://localhost:11434";
var modelName = "llama3.1";
var weatherFunctionArguments = """
{
"city": "Philadelphia",
"date": "6/14/2024"
}
""";
var function = new OllamaTestFunctionCalls();
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
{
ToolCallId = "get_weather",
Result = functionCallResult,
};
var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]).RegisterMessageConnector();
IMessage[] chatHistory = [
new TextMessage(Role.User, "what's the weather in Philadelphia?"),
new ToolCallMessage([toolCall], from: "assistant"),
new ToolCallResultMessage([toolCall], from: "user"),
];
var reply = await ollamaAgent.SendAsync(chatHistory: chatHistory);
reply.Should().BeOfType<TextMessage>();
reply.GetContent().Should().Contain("Philadelphia");
reply.GetContent().Should().Contain("sunny");
}
private static bool IsValidJsonMessage(string input)
{
try
@ -213,12 +274,12 @@ public class OllamaAgentTests
}
}
private static OllamaAgent BuildOllamaAgent(string host, string modelName)
private static OllamaAgent BuildOllamaAgent(string host, string modelName, Tool[]? tools = null)
{
var httpClient = new HttpClient
{
BaseAddress = new Uri(host)
};
return new OllamaAgent(httpClient, "TestAgent", modelName);
return new OllamaAgent(httpClient, "TestAgent", modelName, tools: tools);
}
}

View File

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTestFunctionCalls.cs
using System.Text.Json;
using System.Text.Json.Serialization;
using AutoGen.Core;
namespace AutoGen.Ollama.Tests;
public partial class OllamaTestFunctionCalls
{
private class GetWeatherSchema
{
[JsonPropertyName("city")]
public string? City { get; set; }
[JsonPropertyName("date")]
public string? Date { get; set; }
}
/// <summary>
/// Get weather report
/// </summary>
/// <param name="city">city</param>
/// <param name="date">date</param>
[Function]
public async Task<string> WeatherReport(string city, string date)
{
return $"Weather report for {city} on {date} is sunny";
}
public Task<string> GetWeatherReportWrapper(string arguments)
{
var schema = JsonSerializer.Deserialize<GetWeatherSchema>(
arguments,
new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
return WeatherReport(schema?.City ?? string.Empty, schema?.Date ?? string.Empty);
}
}

View File

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTestUtils.cs
namespace AutoGen.Ollama.Tests;
public static class OllamaTestUtils
{
public static Tool WeatherTool => new()
{
Function = new Function
{
Name = "get_current_weather",
Description = "Get the current weather for a location",
Parameters = new Parameters
{
Properties = new Dictionary<string, Properties>
{
{
"location",
new Properties
{
Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA"
}
},
{
"format", new Properties
{
Type = "string",
Description =
"The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
Enum = new List<string> {"celsius", "fahrenheit"}
}
}
},
Required = new List<string> { "location", "format" }
}
}
};
}