mirror of https://github.com/microsoft/autogen.git
Compare commits
7 Commits
5924522085
...
d232fb1775
Author | SHA1 | Date |
---|---|---|
David Luong | d232fb1775 | |
Eric Zhu | 8a8fcd8906 | |
Joel Klaverkamp | 23c14bc937 | |
Victor Dibia | 9f428880b3 | |
Xiaoyun Zhang | cc5d24ee30 | |
Xiaoyun Zhang | 7e328d0e97 | |
David Luong | c78e7a4ec3 |
|
@ -1069,7 +1069,7 @@ class OpenAIWrapper:
|
||||||
|
|
||||||
def _throttle_api_calls(self, idx: int) -> None:
|
def _throttle_api_calls(self, idx: int) -> None:
|
||||||
"""Rate limit api calls."""
|
"""Rate limit api calls."""
|
||||||
if self._rate_limiters[idx]:
|
if idx < len(self._rate_limiters) and self._rate_limiters[idx]:
|
||||||
limiter = self._rate_limiters[idx]
|
limiter = self._rate_limiters[idx]
|
||||||
|
|
||||||
assert limiter is not None
|
assert limiter is not None
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = "0.2.37"
|
__version__ = "0.2.38"
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Create_Ollama_Agent_With_Tool.cs
|
||||||
|
|
||||||
|
using AutoGen.Core;
|
||||||
|
using AutoGen.Ollama.Extension;
|
||||||
|
using FluentAssertions;
|
||||||
|
|
||||||
|
namespace AutoGen.Ollama.Sample;
|
||||||
|
|
||||||
|
#region WeatherFunction
|
||||||
|
public partial class WeatherFunction
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Gets the weather based on the location and the unit
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="location"></param>
|
||||||
|
/// <param name="unit"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
[Function]
|
||||||
|
public async Task<string> GetWeather(string location, string unit)
|
||||||
|
{
|
||||||
|
// dummy implementation
|
||||||
|
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
public class Create_Ollama_Agent_With_Tool
|
||||||
|
{
|
||||||
|
public static async Task RunAsync()
|
||||||
|
{
|
||||||
|
#region define_tool
|
||||||
|
var tool = new Tool()
|
||||||
|
{
|
||||||
|
Function = new Function
|
||||||
|
{
|
||||||
|
Name = "get_current_weather",
|
||||||
|
Description = "Get the current weather for a location",
|
||||||
|
Parameters = new Parameters
|
||||||
|
{
|
||||||
|
Properties = new Dictionary<string, Properties>
|
||||||
|
{
|
||||||
|
{
|
||||||
|
"location",
|
||||||
|
new Properties
|
||||||
|
{
|
||||||
|
Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"format", new Properties
|
||||||
|
{
|
||||||
|
Type = "string",
|
||||||
|
Description =
|
||||||
|
"The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
|
||||||
|
Enum = new List<string> {"celsius", "fahrenheit"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Required = new List<string> { "location", "format" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
var weatherFunction = new WeatherFunction();
|
||||||
|
var functionMiddleware = new FunctionCallMiddleware(
|
||||||
|
functions: [
|
||||||
|
weatherFunction.GetWeatherFunctionContract,
|
||||||
|
],
|
||||||
|
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||||
|
{
|
||||||
|
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
|
||||||
|
});
|
||||||
|
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
#region create_ollama_agent_llama3.1
|
||||||
|
|
||||||
|
var agent = new OllamaAgent(
|
||||||
|
new HttpClient { BaseAddress = new Uri("http://localhost:11434") },
|
||||||
|
"MyAgent",
|
||||||
|
"llama3.1",
|
||||||
|
tools: [tool]);
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
// TODO cannot stream
|
||||||
|
#region register_middleware
|
||||||
|
var agentWithConnector = agent
|
||||||
|
.RegisterMessageConnector()
|
||||||
|
.RegisterPrintMessage()
|
||||||
|
.RegisterStreamingMiddleware(functionMiddleware);
|
||||||
|
#endregion register_middleware
|
||||||
|
|
||||||
|
#region single_turn
|
||||||
|
var question = new TextMessage(Role.Assistant,
|
||||||
|
"What is the weather like in San Francisco?",
|
||||||
|
from: "user");
|
||||||
|
var functionCallReply = await agentWithConnector.SendAsync(question);
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
#region Single_turn_verify_reply
|
||||||
|
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
|
||||||
|
#endregion Single_turn_verify_reply
|
||||||
|
|
||||||
|
#region Multi_turn
|
||||||
|
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
|
||||||
|
#endregion Multi_turn
|
||||||
|
|
||||||
|
#region Multi_turn_verify_reply
|
||||||
|
finalReply.Should().BeOfType<TextMessage>();
|
||||||
|
#endregion Multi_turn_verify_reply
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,4 +3,4 @@
|
||||||
|
|
||||||
using AutoGen.Ollama.Sample;
|
using AutoGen.Ollama.Sample;
|
||||||
|
|
||||||
await Chat_With_LLaVA.RunAsync();
|
await Create_Ollama_Agent_With_Tool.RunAsync();
|
||||||
|
|
|
@ -24,16 +24,23 @@ public class OllamaAgent : IStreamingAgent
|
||||||
private readonly string _modelName;
|
private readonly string _modelName;
|
||||||
private readonly string _systemMessage;
|
private readonly string _systemMessage;
|
||||||
private readonly OllamaReplyOptions? _replyOptions;
|
private readonly OllamaReplyOptions? _replyOptions;
|
||||||
|
private readonly Tool[]? _tools;
|
||||||
|
|
||||||
public OllamaAgent(HttpClient httpClient, string name, string modelName,
|
public OllamaAgent(HttpClient httpClient, string name, string modelName,
|
||||||
string systemMessage = "You are a helpful AI assistant",
|
string systemMessage = "You are a helpful AI assistant",
|
||||||
OllamaReplyOptions? replyOptions = null)
|
OllamaReplyOptions? replyOptions = null, Tool[]? tools = null)
|
||||||
{
|
{
|
||||||
Name = name;
|
Name = name;
|
||||||
_httpClient = httpClient;
|
_httpClient = httpClient;
|
||||||
_modelName = modelName;
|
_modelName = modelName;
|
||||||
_systemMessage = systemMessage;
|
_systemMessage = systemMessage;
|
||||||
_replyOptions = replyOptions;
|
_replyOptions = replyOptions;
|
||||||
|
_tools = tools;
|
||||||
|
|
||||||
|
if (_httpClient.BaseAddress == null)
|
||||||
|
{
|
||||||
|
throw new InvalidOperationException($"Please add the base address to httpClient");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<IMessage> GenerateReplyAsync(
|
public async Task<IMessage> GenerateReplyAsync(
|
||||||
|
@ -97,7 +104,8 @@ public class OllamaAgent : IStreamingAgent
|
||||||
var request = new ChatRequest
|
var request = new ChatRequest
|
||||||
{
|
{
|
||||||
Model = _modelName,
|
Model = _modelName,
|
||||||
Messages = await BuildChatHistory(messages)
|
Messages = await BuildChatHistory(messages),
|
||||||
|
Tools = _tools
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options is OllamaReplyOptions replyOptions)
|
if (options is OllamaReplyOptions replyOptions)
|
||||||
|
|
|
@ -50,4 +50,11 @@ public class ChatRequest
|
||||||
[JsonPropertyName("keep_alive")]
|
[JsonPropertyName("keep_alive")]
|
||||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||||
public string? KeepAlive { get; set; }
|
public string? KeepAlive { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tools for the model to use. Not all models currently support tools.
|
||||||
|
/// Requires stream to be set to false
|
||||||
|
/// </summary>
|
||||||
|
[JsonPropertyName("tools")]
|
||||||
|
public IEnumerable<Tool>? Tools { get; set; }
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ public class Message
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
public Message(string role, string value)
|
public Message(string role, string? value = null)
|
||||||
{
|
{
|
||||||
Role = role;
|
Role = role;
|
||||||
Value = value;
|
Value = value;
|
||||||
|
@ -27,11 +27,34 @@ public class Message
|
||||||
/// the content of the message
|
/// the content of the message
|
||||||
/// </summary>
|
/// </summary>
|
||||||
[JsonPropertyName("content")]
|
[JsonPropertyName("content")]
|
||||||
public string Value { get; set; } = string.Empty;
|
public string? Value { get; set; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// (optional): a list of images to include in the message (for multimodal models such as llava)
|
/// (optional): a list of images to include in the message (for multimodal models such as llava)
|
||||||
/// </summary>
|
/// </summary>
|
||||||
[JsonPropertyName("images")]
|
[JsonPropertyName("images")]
|
||||||
public IList<string>? Images { get; set; }
|
public IList<string>? Images { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// A list of tools the model wants to use. Not all models currently support tools.
|
||||||
|
/// Tool call is not supported while streaming.
|
||||||
|
/// </summary>
|
||||||
|
[JsonPropertyName("tool_calls")]
|
||||||
|
public IEnumerable<ToolCall>? ToolCalls { get; set; }
|
||||||
|
|
||||||
|
public class ToolCall
|
||||||
|
{
|
||||||
|
[JsonPropertyName("function")]
|
||||||
|
public Function? Function { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Function
|
||||||
|
{
|
||||||
|
[JsonPropertyName("name")]
|
||||||
|
public string? Name { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("arguments")]
|
||||||
|
public Dictionary<string, string>? Arguments { get; set; }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Tools.cs
|
||||||
|
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
|
||||||
|
namespace AutoGen.Ollama;
|
||||||
|
|
||||||
|
public class Tool
|
||||||
|
{
|
||||||
|
[JsonPropertyName("type")]
|
||||||
|
public string? Type { get; set; } = "function";
|
||||||
|
|
||||||
|
[JsonPropertyName("function")]
|
||||||
|
public Function? Function { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Function
|
||||||
|
{
|
||||||
|
[JsonPropertyName("name")]
|
||||||
|
public string? Name { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("description")]
|
||||||
|
public string? Description { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("parameters")]
|
||||||
|
public Parameters? Parameters { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Parameters
|
||||||
|
{
|
||||||
|
[JsonPropertyName("type")]
|
||||||
|
public string? Type { get; set; } = "object";
|
||||||
|
|
||||||
|
[JsonPropertyName("properties")]
|
||||||
|
public Dictionary<string, Properties>? Properties { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("required")]
|
||||||
|
public IEnumerable<string>? Required { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Properties
|
||||||
|
{
|
||||||
|
[JsonPropertyName("type")]
|
||||||
|
public string? Type { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("description")]
|
||||||
|
public string? Description { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("enum")]
|
||||||
|
public IEnumerable<string>? Enum { get; set; }
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ using System.Collections.Generic;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
using System.Net.Http;
|
using System.Net.Http;
|
||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
|
using System.Text.Json;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using AutoGen.Core;
|
using AutoGen.Core;
|
||||||
|
@ -24,6 +25,7 @@ public class OllamaMessageConnector : IStreamingMiddleware
|
||||||
|
|
||||||
return reply switch
|
return reply switch
|
||||||
{
|
{
|
||||||
|
IMessage<ChatResponse> { Content.Message.ToolCalls: not null } messageEnvelope when messageEnvelope.Content.Message.ToolCalls.Any() => ProcessToolCalls(messageEnvelope, agent),
|
||||||
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From),
|
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From),
|
||||||
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"),
|
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"),
|
||||||
_ => reply
|
_ => reply
|
||||||
|
@ -73,20 +75,21 @@ public class OllamaMessageConnector : IStreamingMiddleware
|
||||||
{
|
{
|
||||||
return messages.SelectMany(m =>
|
return messages.SelectMany(m =>
|
||||||
{
|
{
|
||||||
if (m is IMessage<Message> messageEnvelope)
|
if (m is IMessage<Message>)
|
||||||
{
|
{
|
||||||
return [m];
|
return [m];
|
||||||
}
|
}
|
||||||
else
|
|
||||||
|
return m switch
|
||||||
{
|
{
|
||||||
return m switch
|
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
|
||||||
{
|
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
|
||||||
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
|
ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
|
||||||
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
|
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
|
||||||
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
|
AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
|
||||||
_ => [m],
|
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
|
||||||
};
|
_ => [m],
|
||||||
}
|
};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,4 +186,64 @@ public class OllamaMessageConnector : IStreamingMiddleware
|
||||||
return [MessageEnvelope.Create(message, agent.Name)];
|
return [MessageEnvelope.Create(message, agent.Name)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private IMessage ProcessToolCalls(IMessage<ChatResponse> messageEnvelope, IAgent agent)
|
||||||
|
{
|
||||||
|
var toolCalls = new List<ToolCall>();
|
||||||
|
foreach (var messageToolCall in messageEnvelope.Content.Message?.ToolCalls!)
|
||||||
|
{
|
||||||
|
toolCalls.Add(new ToolCall(
|
||||||
|
messageToolCall.Function?.Name ?? string.Empty,
|
||||||
|
JsonSerializer.Serialize(messageToolCall.Function?.Arguments)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ToolCallMessage(toolCalls, agent.Name) { Content = messageEnvelope.Content.Message.Value };
|
||||||
|
}
|
||||||
|
|
||||||
|
private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
|
||||||
|
{
|
||||||
|
var chatMessage = new Message(toolCallMessage.From ?? string.Empty, toolCallMessage.GetContent())
|
||||||
|
{
|
||||||
|
ToolCalls = toolCallMessage.ToolCalls.Select(t => new Message.ToolCall
|
||||||
|
{
|
||||||
|
Function = new Message.Function
|
||||||
|
{
|
||||||
|
Name = t.FunctionName,
|
||||||
|
Arguments = JsonSerializer.Deserialize<Dictionary<string, string>>(t.FunctionArguments),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
|
||||||
|
}
|
||||||
|
|
||||||
|
private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
|
||||||
|
{
|
||||||
|
foreach (var toolCall in toolCallResultMessage.ToolCalls)
|
||||||
|
{
|
||||||
|
if (!string.IsNullOrEmpty(toolCall.Result))
|
||||||
|
{
|
||||||
|
return [MessageEnvelope.Create(new Message("tool", toolCall.Result), toolCallResultMessage.From)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new InvalidOperationException("Expected to have at least one tool call result");
|
||||||
|
}
|
||||||
|
|
||||||
|
private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage, IAgent agent)
|
||||||
|
{
|
||||||
|
if (toolCallAggregateMessage.From is { } from && from != agent.Name)
|
||||||
|
{
|
||||||
|
var contents = toolCallAggregateMessage.Message2.ToolCalls.Select(t => t.Result);
|
||||||
|
var messages =
|
||||||
|
contents.Select(c => new Message("assistant", c ?? throw new ArgumentNullException(nameof(c))));
|
||||||
|
|
||||||
|
return messages.Select(m => new MessageEnvelope<Message>(m, from: from));
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolCallMessage = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent);
|
||||||
|
var toolCallResult = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2);
|
||||||
|
|
||||||
|
return toolCallMessage.Concat(toolCallResult);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,7 +170,7 @@ public class AnthropicClientAgentTest
|
||||||
)
|
)
|
||||||
.RegisterMessageConnector();
|
.RegisterMessageConnector();
|
||||||
|
|
||||||
var weatherFunctionArgumets = """
|
var weatherFunctionArguments = """
|
||||||
{
|
{
|
||||||
"city": "Philadelphia",
|
"city": "Philadelphia",
|
||||||
"date": "6/14/2024"
|
"date": "6/14/2024"
|
||||||
|
@ -178,8 +178,8 @@ public class AnthropicClientAgentTest
|
||||||
""";
|
""";
|
||||||
|
|
||||||
var function = new AnthropicTestFunctionCalls();
|
var function = new AnthropicTestFunctionCalls();
|
||||||
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets);
|
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
|
||||||
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets)
|
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
|
||||||
{
|
{
|
||||||
ToolCallId = "get_weather",
|
ToolCallId = "get_weather",
|
||||||
Result = functionCallResult,
|
Result = functionCallResult,
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<ProjectReference Include="..\..\src\AutoGen.Ollama\AutoGen.Ollama.csproj" />
|
<ProjectReference Include="..\..\src\AutoGen.Ollama\AutoGen.Ollama.csproj" />
|
||||||
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
|
||||||
|
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|
|
@ -6,6 +6,7 @@ using AutoGen.Core;
|
||||||
using AutoGen.Ollama.Extension;
|
using AutoGen.Ollama.Extension;
|
||||||
using AutoGen.Tests;
|
using AutoGen.Tests;
|
||||||
using FluentAssertions;
|
using FluentAssertions;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
namespace AutoGen.Ollama.Tests;
|
namespace AutoGen.Ollama.Tests;
|
||||||
|
|
||||||
|
@ -49,7 +50,7 @@ public class OllamaAgentTests
|
||||||
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
|
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
|
||||||
result.From.Should().Be(ollamaAgent.Name);
|
result.From.Should().Be(ollamaAgent.Name);
|
||||||
|
|
||||||
string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value;
|
string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value ?? string.Empty;
|
||||||
bool isValidJson = IsValidJsonMessage(jsonContent);
|
bool isValidJson = IsValidJsonMessage(jsonContent);
|
||||||
isValidJson.Should().BeTrue();
|
isValidJson.Should().BeTrue();
|
||||||
}
|
}
|
||||||
|
@ -195,6 +196,66 @@ public class OllamaAgentTests
|
||||||
update.TotalDuration.Should().BeGreaterThan(0);
|
update.TotalDuration.Should().BeGreaterThan(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task GenerateReplyAsync_ReturnsValidToolMessage()
|
||||||
|
{
|
||||||
|
var host = @" http://localhost:11434";
|
||||||
|
var modelName = "llama3.1";
|
||||||
|
|
||||||
|
var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]);
|
||||||
|
var message = new Message("user", "What is the weather today?");
|
||||||
|
var messages = new IMessage[] { MessageEnvelope.Create(message, from: modelName) };
|
||||||
|
|
||||||
|
var result = await ollamaAgent.GenerateReplyAsync(messages);
|
||||||
|
|
||||||
|
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
|
||||||
|
var chatResponse = ((MessageEnvelope<ChatResponse>)result).Content;
|
||||||
|
chatResponse.Message.Should().BeOfType<Message>();
|
||||||
|
chatResponse.Message.Should().NotBeNull();
|
||||||
|
var toolCall = chatResponse.Message!.ToolCalls!.First();
|
||||||
|
toolCall.Function.Should().NotBeNull();
|
||||||
|
toolCall.Function!.Name.Should().Be("get_current_weather");
|
||||||
|
toolCall.Function!.Arguments.Should().ContainKey("location");
|
||||||
|
toolCall.Function!.Arguments!["location"].Should().Be("San Francisco, CA");
|
||||||
|
toolCall.Function!.Arguments!.Should().ContainKey("format");
|
||||||
|
toolCall.Function!.Arguments!["format"].Should().BeOneOf("celsius", "fahrenheit");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task OllamaAgentFunctionCallMessageTest()
|
||||||
|
{
|
||||||
|
var host = @" http://localhost:11434";
|
||||||
|
var modelName = "llama3.1";
|
||||||
|
|
||||||
|
var weatherFunctionArguments = """
|
||||||
|
{
|
||||||
|
"city": "Philadelphia",
|
||||||
|
"date": "6/14/2024"
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
var function = new OllamaTestFunctionCalls();
|
||||||
|
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
|
||||||
|
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
|
||||||
|
{
|
||||||
|
ToolCallId = "get_weather",
|
||||||
|
Result = functionCallResult,
|
||||||
|
};
|
||||||
|
|
||||||
|
var ollamaAgent = BuildOllamaAgent(host, modelName, [OllamaTestUtils.WeatherTool]).RegisterMessageConnector();
|
||||||
|
IMessage[] chatHistory = [
|
||||||
|
new TextMessage(Role.User, "what's the weather in Philadelphia?"),
|
||||||
|
new ToolCallMessage([toolCall], from: "assistant"),
|
||||||
|
new ToolCallResultMessage([toolCall], from: "user"),
|
||||||
|
];
|
||||||
|
|
||||||
|
var reply = await ollamaAgent.SendAsync(chatHistory: chatHistory);
|
||||||
|
|
||||||
|
reply.Should().BeOfType<TextMessage>();
|
||||||
|
reply.GetContent().Should().Contain("Philadelphia");
|
||||||
|
reply.GetContent().Should().Contain("sunny");
|
||||||
|
}
|
||||||
|
|
||||||
private static bool IsValidJsonMessage(string input)
|
private static bool IsValidJsonMessage(string input)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
|
@ -213,12 +274,12 @@ public class OllamaAgentTests
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static OllamaAgent BuildOllamaAgent(string host, string modelName)
|
private static OllamaAgent BuildOllamaAgent(string host, string modelName, Tool[]? tools = null)
|
||||||
{
|
{
|
||||||
var httpClient = new HttpClient
|
var httpClient = new HttpClient
|
||||||
{
|
{
|
||||||
BaseAddress = new Uri(host)
|
BaseAddress = new Uri(host)
|
||||||
};
|
};
|
||||||
return new OllamaAgent(httpClient, "TestAgent", modelName);
|
return new OllamaAgent(httpClient, "TestAgent", modelName, tools: tools);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// OllamaTestFunctionCalls.cs
|
||||||
|
|
||||||
|
using System.Text.Json;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
using AutoGen.Core;
|
||||||
|
|
||||||
|
namespace AutoGen.Ollama.Tests;
|
||||||
|
|
||||||
|
public partial class OllamaTestFunctionCalls
|
||||||
|
{
|
||||||
|
private class GetWeatherSchema
|
||||||
|
{
|
||||||
|
[JsonPropertyName("city")]
|
||||||
|
public string? City { get; set; }
|
||||||
|
|
||||||
|
[JsonPropertyName("date")]
|
||||||
|
public string? Date { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Get weather report
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="city">city</param>
|
||||||
|
/// <param name="date">date</param>
|
||||||
|
[Function]
|
||||||
|
public async Task<string> WeatherReport(string city, string date)
|
||||||
|
{
|
||||||
|
return $"Weather report for {city} on {date} is sunny";
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task<string> GetWeatherReportWrapper(string arguments)
|
||||||
|
{
|
||||||
|
var schema = JsonSerializer.Deserialize<GetWeatherSchema>(
|
||||||
|
arguments,
|
||||||
|
new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
|
||||||
|
|
||||||
|
return WeatherReport(schema?.City ?? string.Empty, schema?.Date ?? string.Empty);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// OllamaTestUtils.cs
|
||||||
|
|
||||||
|
namespace AutoGen.Ollama.Tests;
|
||||||
|
|
||||||
|
public static class OllamaTestUtils
|
||||||
|
{
|
||||||
|
public static Tool WeatherTool => new()
|
||||||
|
{
|
||||||
|
Function = new Function
|
||||||
|
{
|
||||||
|
Name = "get_current_weather",
|
||||||
|
Description = "Get the current weather for a location",
|
||||||
|
Parameters = new Parameters
|
||||||
|
{
|
||||||
|
Properties = new Dictionary<string, Properties>
|
||||||
|
{
|
||||||
|
{
|
||||||
|
"location",
|
||||||
|
new Properties
|
||||||
|
{
|
||||||
|
Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"format", new Properties
|
||||||
|
{
|
||||||
|
Type = "string",
|
||||||
|
Description =
|
||||||
|
"The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
|
||||||
|
Enum = new List<string> {"celsius", "fahrenheit"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Required = new List<string> { "location", "format" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
|
@ -45,7 +45,7 @@ pip install autogen-agentchat~=0.2
|
||||||
import os
|
import os
|
||||||
from autogen import AssistantAgent, UserProxyAgent
|
from autogen import AssistantAgent, UserProxyAgent
|
||||||
|
|
||||||
llm_config = {"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY"]}
|
llm_config = { "config_list": [{ "model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY") }] }
|
||||||
assistant = AssistantAgent("assistant", llm_config=llm_config)
|
assistant = AssistantAgent("assistant", llm_config=llm_config)
|
||||||
user_proxy = UserProxyAgent("user_proxy", code_execution_config=False)
|
user_proxy = UserProxyAgent("user_proxy", code_execution_config=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue