[.Net] feature: Ollama integration (#2693)

* [.Net] feature: Ollama integration with

* [.Net] ollama agent improvements and reorganization

* added ollama fact logic

* [.Net] added ollama embeddings service

* [.Net] Ollama embeddings integration

* cleaned the agent and connector code

* [.Net] cleaned ollama agent tests

* [.Net] standardize api key fact ollama host variable

* [.Net] fixed solution issue

---------

Co-authored-by: Xiaoyun Zhang <bigmiao.zhang@gmail.com>
This commit is contained in:
Israel de la Cruz 2024-05-15 18:54:08 +02:00 committed by GitHub
parent 84577570ad
commit 1c3ae92d39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 953 additions and 0 deletions

View File

@ -37,6 +37,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Test
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama.Tests", "test\Autogen.Ollama.Tests\Autogen.Ollama.Tests.csproj", "{C24FDE63-952D-4F8E-A807-AF31D43AD675}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU Debug|Any CPU = Debug|Any CPU
@ -91,6 +95,14 @@ Global
{15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.Build.0 = Release|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.Build.0 = Debug|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.ActiveCfg = Release|Any CPU
{C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.Build.0 = Release|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -116,6 +128,8 @@ Global
{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection EndGlobalSection

View File

@ -0,0 +1,216 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaAgent.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
namespace Autogen.Ollama;
/// <summary>
/// An agent that can interact with ollama models.
/// </summary>
public class OllamaAgent : IStreamingAgent
{
private readonly HttpClient _httpClient;
public string Name { get; }
private readonly string _modelName;
private readonly string _systemMessage;
private readonly OllamaReplyOptions? _replyOptions;
public OllamaAgent(HttpClient httpClient, string name, string modelName,
string systemMessage = "You are a helpful AI assistant",
OllamaReplyOptions? replyOptions = null)
{
Name = name;
_httpClient = httpClient;
_modelName = modelName;
_systemMessage = systemMessage;
_replyOptions = replyOptions;
}
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default)
{
ChatRequest request = await BuildChatRequest(messages, options);
request.Stream = false;
using (HttpResponseMessage? response = await _httpClient
.SendAsync(BuildRequestMessage(request), HttpCompletionOption.ResponseContentRead, cancellation))
{
response.EnsureSuccessStatusCode();
Stream? streamResponse = await response.Content.ReadAsStreamAsync();
ChatResponse chatResponse = await JsonSerializer.DeserializeAsync<ChatResponse>(streamResponse, cancellationToken: cancellation)
?? throw new Exception("Failed to deserialize response");
var output = new MessageEnvelope<ChatResponse>(chatResponse, from: Name);
return output;
}
}
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
ChatRequest request = await BuildChatRequest(messages, options);
request.Stream = true;
HttpRequestMessage message = BuildRequestMessage(request);
using (HttpResponseMessage? response = await _httpClient.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
{
response.EnsureSuccessStatusCode();
using Stream? stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
using var reader = new StreamReader(stream);
while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
{
string? line = await reader.ReadLineAsync();
if (string.IsNullOrWhiteSpace(line)) continue;
ChatResponseUpdate? update = JsonSerializer.Deserialize<ChatResponseUpdate>(line);
if (update != null)
{
yield return new MessageEnvelope<ChatResponseUpdate>(update, from: Name);
}
if (update is { Done: false }) continue;
ChatResponse? chatMessage = JsonSerializer.Deserialize<ChatResponse>(line);
if (chatMessage == null) continue;
yield return new MessageEnvelope<ChatResponse>(chatMessage, from: Name);
}
}
}
private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
{
var request = new ChatRequest
{
Model = _modelName,
Messages = await BuildChatHistory(messages)
};
if (options is OllamaReplyOptions replyOptions)
{
BuildChatRequestOptions(replyOptions, request);
return request;
}
if (_replyOptions != null)
{
BuildChatRequestOptions(_replyOptions, request);
return request;
}
return request;
}
private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request)
{
request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null;
request.Template = replyOptions.Template;
request.KeepAlive = replyOptions.KeepAlive;
if (replyOptions.Temperature != null
|| replyOptions.MaxToken != null
|| replyOptions.StopSequence != null
|| replyOptions.Seed != null
|| replyOptions.MiroStat != null
|| replyOptions.MiroStatEta != null
|| replyOptions.MiroStatTau != null
|| replyOptions.NumCtx != null
|| replyOptions.NumGqa != null
|| replyOptions.NumGpu != null
|| replyOptions.NumThread != null
|| replyOptions.RepeatLastN != null
|| replyOptions.RepeatPenalty != null
|| replyOptions.TopK != null
|| replyOptions.TopP != null
|| replyOptions.TfsZ != null)
{
request.Options = new ModelReplyOptions
{
Temperature = replyOptions.Temperature,
NumPredict = replyOptions.MaxToken,
Stop = replyOptions.StopSequence?[0],
Seed = replyOptions.Seed,
MiroStat = replyOptions.MiroStat,
MiroStatEta = replyOptions.MiroStatEta,
MiroStatTau = replyOptions.MiroStatTau,
NumCtx = replyOptions.NumCtx,
NumGqa = replyOptions.NumGqa,
NumGpu = replyOptions.NumGpu,
NumThread = replyOptions.NumThread,
RepeatLastN = replyOptions.RepeatLastN,
RepeatPenalty = replyOptions.RepeatPenalty,
TopK = replyOptions.TopK,
TopP = replyOptions.TopP,
TfsZ = replyOptions.TfsZ
};
}
}
private async Task<List<Message>> BuildChatHistory(IEnumerable<IMessage> messages)
{
if (!messages.Any(m => m.IsSystemMessage()))
{
var systemMessage = new TextMessage(Role.System, _systemMessage, from: Name);
messages = new[] { systemMessage }.Concat(messages);
}
var collection = new List<Message>();
foreach (IMessage? message in messages)
{
Message item;
switch (message)
{
case TextMessage tm:
item = new Message { Role = tm.Role.ToString(), Value = tm.Content };
break;
case ImageMessage im:
string base64Image = await ImageUrlToBase64(im.Url!);
item = new Message { Role = im.Role.ToString(), Images = [base64Image] };
break;
case MultiModalMessage mm:
var textsGroupedByRole = mm.Content.OfType<TextMessage>().GroupBy(tm => tm.Role)
.ToDictionary(g => g.Key, g => string.Join(Environment.NewLine, g.Select(tm => tm.Content)));
string content = string.Join($"{Environment.NewLine}", textsGroupedByRole
.Select(g => $"{g.Key}{Environment.NewLine}:{g.Value}"));
IEnumerable<Task<string>> imagesConversionTasks = mm.Content
.OfType<ImageMessage>()
.Select(async im => await ImageUrlToBase64(im.Url!));
string[]? imagesBase64 = await Task.WhenAll(imagesConversionTasks);
item = new Message { Role = mm.Role.ToString(), Value = content, Images = imagesBase64 };
break;
default:
throw new NotSupportedException();
}
collection.Add(item);
}
return collection;
}
private static HttpRequestMessage BuildRequestMessage(ChatRequest request)
{
string serialized = JsonSerializer.Serialize(request);
return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint)
{
Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
};
}
private async Task<string> ImageUrlToBase64(string imageUrl)
{
if (string.IsNullOrWhiteSpace(imageUrl))
{
throw new ArgumentException("required parameter", nameof(imageUrl));
}
byte[] imageBytes = await _httpClient.GetByteArrayAsync(imageUrl);
return imageBytes != null
? Convert.ToBase64String(imageBytes)
: throw new InvalidOperationException("no image byte array");
}
}

View File

@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatRequest.cs
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
public class ChatRequest
{
/// <summary>
/// (required) the model name
/// </summary>
[JsonPropertyName("model")]
public string Model { get; set; } = string.Empty;
/// <summary>
/// the messages of the chat, this can be used to keep a chat memory
/// </summary>
[JsonPropertyName("messages")]
public IList<Message> Messages { get; set; } = Array.Empty<Message>();
/// <summary>
/// the format to return a response in. Currently, the only accepted value is json
/// </summary>
[JsonPropertyName("format")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Format { get; set; }
/// <summary>
/// additional model parameters listed in the documentation for the Modelfile such as temperature
/// </summary>
[JsonPropertyName("options")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ModelReplyOptions? Options { get; set; }
/// <summary>
/// the prompt template to use (overrides what is defined in the Modelfile)
/// </summary>
[JsonPropertyName("template")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Template { get; set; }
/// <summary>
/// if false the response will be returned as a single response object, rather than a stream of objects
/// </summary>
[JsonPropertyName("stream")]
public bool Stream { get; set; }
/// <summary>
/// controls how long the model will stay loaded into memory following the request (default: 5m)
/// </summary>
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }
}

View File

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatResponse.cs
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
public class ChatResponse : ChatResponseUpdate
{
/// <summary>
/// time spent generating the response
/// </summary>
[JsonPropertyName("total_duration")]
public long TotalDuration { get; set; }
/// <summary>
/// time spent in nanoseconds loading the model
/// </summary>
[JsonPropertyName("load_duration")]
public long LoadDuration { get; set; }
/// <summary>
/// number of tokens in the prompt
/// </summary>
[JsonPropertyName("prompt_eval_count")]
public int PromptEvalCount { get; set; }
/// <summary>
/// time spent in nanoseconds evaluating the prompt
/// </summary>
[JsonPropertyName("prompt_eval_duration")]
public long PromptEvalDuration { get; set; }
/// <summary>
/// number of tokens the response
/// </summary>
[JsonPropertyName("eval_count")]
public int EvalCount { get; set; }
/// <summary>
/// time in nanoseconds spent generating the response
/// </summary>
[JsonPropertyName("eval_duration")]
public long EvalDuration { get; set; }
}

View File

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatResponseUpdate.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
public class ChatResponseUpdate
{
[JsonPropertyName("model")]
public string Model { get; set; } = string.Empty;
[JsonPropertyName("created_at")]
public string CreatedAt { get; set; } = string.Empty;
[JsonPropertyName("message")]
public Message? Message { get; set; }
[JsonPropertyName("done")]
public bool Done { get; set; }
}
public class Message
{
/// <summary>
/// the role of the message, either system, user or assistant
/// </summary>
[JsonPropertyName("role")]
public string Role { get; set; } = string.Empty;
/// <summary>
/// the content of the message
/// </summary>
[JsonPropertyName("content")]
public string Value { get; set; } = string.Empty;
/// <summary>
/// (optional): a list of images to include in the message (for multimodal models such as llava)
/// </summary>
[JsonPropertyName("images")]
public IList<string>? Images { get; set; }
}

View File

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ModelReplyOptions.cs
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
//https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
public class ModelReplyOptions
{
/// <summary>
/// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
/// </summary>
[JsonPropertyName("mirostat")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? MiroStat { get; set; }
/// <summary>
/// Influences how quickly the algorithm responds to feedback from the generated text.
/// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1)
/// </summary>
[JsonPropertyName("mirostat_eta")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public float? MiroStatEta { get; set; }
/// <summary>
/// Controls the balance between coherence and diversity of the output.
/// A lower value will result in more focused and coherent text. (Default: 5.0)
/// </summary>
[JsonPropertyName("mirostat_tau")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public float? MiroStatTau { get; set; }
/// <summary>
/// Sets the size of the context window used to generate the next token. (Default: 2048)
/// </summary>
[JsonPropertyName("num_ctx")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? NumCtx { get; set; }
/// <summary>
/// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b
/// </summary>
[JsonPropertyName("num_gqa")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? NumGqa { get; set; }
/// <summary>
/// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.
/// </summary>
[JsonPropertyName("num_gpu")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? NumGpu { get; set; }
/// <summary>
/// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance.
/// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).
/// </summary>
[JsonPropertyName("num_thread")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? NumThread { get; set; }
/// <summary>
/// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
/// </summary>
[JsonPropertyName("repeat_last_n")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? RepeatLastN { get; set; }
/// <summary>
/// Sets how strongly to penalize repetitions.
/// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
/// </summary>
[JsonPropertyName("repeat_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public float? RepeatPenalty { get; set; }
/// <summary>
/// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
/// </summary>
[JsonPropertyName("temperature")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public float? Temperature { get; set; }
/// <summary>
/// Sets the random number seed to use for generation.
/// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)
/// </summary>
[JsonPropertyName("seed")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Seed { get; set; }
/// <summary>
/// Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.
/// Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
/// </summary>
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Stop { get; set; }
/// <summary>
/// Tail free sampling is used to reduce the impact of less probable tokens from the output.
/// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)
/// </summary>
[JsonPropertyName("tfs_z")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public float? TfsZ { get; set; }
/// <summary>
/// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
/// </summary>
[JsonPropertyName("num_predict")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? NumPredict { get; set; }
/// <summary>
/// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
/// </summary>
[JsonPropertyName("top_k")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? TopK { get; set; }
/// <summary>
/// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
/// </summary>
[JsonPropertyName("top_p")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? TopP { get; set; }
}

View File

@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaReplyOptions.cs
using AutoGen.Core;
namespace Autogen.Ollama;
public enum FormatType
{
None,
Json
}
public class OllamaReplyOptions : GenerateReplyOptions
{
/// <summary>
/// the format to return a response in. Currently, the only accepted value is json
/// </summary>
public FormatType Format { get; set; } = FormatType.None;
/// <summary>
/// the prompt template to use (overrides what is defined in the Modelfile)
/// </summary>
public string? Template { get; set; }
/// <summary>
/// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)
/// </summary>
public new float? Temperature { get; set; }
/// <summary>
/// controls how long the model will stay loaded into memory following the request (default: 5m)
/// </summary>
public string? KeepAlive { get; set; }
/// <summary>
/// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
/// </summary>
public int? MiroStat { get; set; }
/// <summary>
/// Influences how quickly the algorithm responds to feedback from the generated text.
/// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1)
/// </summary>
public float? MiroStatEta { get; set; }
/// <summary>
/// Controls the balance between coherence and diversity of the output.
/// A lower value will result in more focused and coherent text. (Default: 5.0)
/// </summary>
public float? MiroStatTau { get; set; }
/// <summary>
/// Sets the size of the context window used to generate the next token. (Default: 2048)
/// </summary>
public int? NumCtx { get; set; }
/// <summary>
/// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b
/// </summary>
public int? NumGqa { get; set; }
/// <summary>
/// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable.
/// </summary>
public int? NumGpu { get; set; }
/// <summary>
/// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance.
/// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).
/// </summary>
public int? NumThread { get; set; }
/// <summary>
/// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
/// </summary>
public int? RepeatLastN { get; set; }
/// <summary>
/// Sets how strongly to penalize repetitions.
/// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
/// </summary>
public float? RepeatPenalty { get; set; }
/// <summary>
/// Sets the random number seed to use for generation.
/// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0)
/// </summary>
public int? Seed { get; set; }
/// <summary>
/// Tail free sampling is used to reduce the impact of less probable tokens from the output.
/// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1)
/// </summary>
public float? TfsZ { get; set; }
/// <summary>
/// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
/// </summary>
public new int? MaxToken { get; set; }
/// <summary>
/// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
/// </summary>
public int? TopK { get; set; }
/// <summary>
/// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
/// </summary>
public int? TopP { get; set; }
}

View File

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ITextEmbeddingService.cs
using System.Threading;
using System.Threading.Tasks;
namespace Autogen.Ollama;
public interface ITextEmbeddingService
{
public Task<TextEmbeddingsResponse> GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken);
}

View File

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingService.cs
using System;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Autogen.Ollama;
public class OllamaTextEmbeddingService : ITextEmbeddingService
{
private readonly HttpClient _client;
public OllamaTextEmbeddingService(HttpClient client)
{
_client = client;
}
public async Task<TextEmbeddingsResponse> GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken = default)
{
using (HttpResponseMessage? response = await _client
.SendAsync(BuildPostRequest(request), HttpCompletionOption.ResponseContentRead, cancellationToken))
{
response.EnsureSuccessStatusCode();
Stream? streamResponse = await response.Content.ReadAsStreamAsync();
TextEmbeddingsResponse output = await JsonSerializer
.DeserializeAsync<TextEmbeddingsResponse>(streamResponse, cancellationToken: cancellationToken)
?? throw new Exception("Failed to deserialize response");
return output;
}
}
private static HttpRequestMessage BuildPostRequest(TextEmbeddingsRequest request)
{
string serialized = JsonSerializer.Serialize(request);
return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.EmbeddingsEndpoint)
{
Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
};
}
}

View File

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsRequest.cs
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
public class TextEmbeddingsRequest
{
/// <summary>
/// name of model to generate embeddings from
/// </summary>
[JsonPropertyName("model")]
public string Model { get; set; } = string.Empty;
/// <summary>
/// text to generate embeddings for
/// </summary>
[JsonPropertyName("prompt")]
public string Prompt { get; set; } = string.Empty;
/// <summary>
/// additional model parameters listed in the documentation for the Modelfile such as temperature
/// </summary>
[JsonPropertyName("options")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ModelReplyOptions? Options { get; set; }
/// <summary>
/// controls how long the model will stay loaded into memory following the request (default: 5m)
/// </summary>
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }
}

View File

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TextEmbeddingsResponse.cs
using System.Text.Json.Serialization;
namespace Autogen.Ollama;
public class TextEmbeddingsResponse
{
[JsonPropertyName("embedding")]
public double[]? Embedding { get; set; }
}

View File

@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaMessageConnector.cs
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
namespace Autogen.Ollama;
public class OllamaMessageConnector : IMiddleware, IStreamingMiddleware
{
public string Name => nameof(OllamaMessageConnector);
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,
CancellationToken cancellationToken = default)
{
IEnumerable<IMessage> messages = context.Messages;
IMessage reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken);
switch (reply)
{
case IMessage<ChatResponse> messageEnvelope:
Message? message = messageEnvelope.Content.Message;
return new TextMessage(Role.Assistant, message != null ? message.Value : "EMPTY_CONTENT", messageEnvelope.From);
default:
throw new NotSupportedException();
}
}
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (IStreamingMessage? update in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
{
switch (update)
{
case IMessage<ChatResponse> complete:
{
string? textContent = complete.Content.Message?.Value;
yield return new TextMessage(Role.Assistant, textContent!, complete.From);
break;
}
case IMessage<ChatResponseUpdate> updatedMessage:
{
string? textContent = updatedMessage.Content.Message?.Value;
yield return new TextMessageUpdate(Role.Assistant, textContent, updatedMessage.From);
break;
}
default:
throw new InvalidOperationException("Message type not supported.");
}
}
}
}

View File

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaConsts.cs
namespace Autogen.Ollama;
public class OllamaConsts
{
public const string JsonFormatType = "json";
public const string JsonMediaType = "application/json";
public const string ChatCompletionEndpoint = "/api/chat";
public const string EmbeddingsEndpoint = "/api/embeddings";
}

View File

@ -0,0 +1,33 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="FluentAssertions" Version="6.8.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.0"/>
<PackageReference Include="xunit" Version="2.4.2"/>
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="6.0.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Autogen.Ollama\Autogen.Ollama.csproj" />
<ProjectReference Include="..\AutoGen.Tests\AutoGen.Tests.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaAgentTests.cs
using System.Text.Json;
using AutoGen.Core;
using AutoGen.Tests;
using FluentAssertions;
namespace Autogen.Ollama.Tests;
public class OllamaAgentTests
{
[ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
public async Task GenerateReplyAsync_ReturnsValidMessage_WhenCalled()
{
string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
var messages = new IMessage[] { new TextMessage(Role.User, "Hello, how are you") };
IMessage result = await ollamaAgent.GenerateReplyAsync(messages);
result.Should().NotBeNull();
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
result.From.Should().Be(ollamaAgent.Name);
}
[ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
public async Task GenerateReplyAsync_ReturnsValidJsonMessageContent_WhenCalled()
{
string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
var messages = new IMessage[] { new TextMessage(Role.User, "Hello, how are you") };
IMessage result = await ollamaAgent.GenerateReplyAsync(messages, new OllamaReplyOptions
{
Format = FormatType.Json
});
result.Should().NotBeNull();
result.Should().BeOfType<MessageEnvelope<ChatResponse>>();
result.From.Should().Be(ollamaAgent.Name);
string jsonContent = ((MessageEnvelope<ChatResponse>)result).Content.Message!.Value;
bool isValidJson = IsValidJsonMessage(jsonContent);
isValidJson.Should().BeTrue();
}
[ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")]
public async Task GenerateStreamingReplyAsync_ReturnsValidMessages_WhenCalled()
{
string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME")
?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set.");
OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName);
var messages = new IMessage[] { new TextMessage(Role.User, "Hello how are you") };
IStreamingMessage? finalReply = default;
await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages))
{
message.Should().NotBeNull();
message.From.Should().Be(ollamaAgent.Name);
finalReply = message;
}
finalReply.Should().BeOfType<MessageEnvelope<ChatResponse>>();
}
private static bool IsValidJsonMessage(string input)
{
try
{
JsonDocument.Parse(input);
return true;
}
catch (JsonException)
{
return false;
}
catch (Exception ex)
{
Console.WriteLine("An unexpected exception occurred: " + ex.Message);
return false;
}
}
private static OllamaAgent BuildOllamaAgent(string host, string modelName)
{
var httpClient = new HttpClient
{
BaseAddress = new Uri(host)
};
return new OllamaAgent(httpClient, "TestAgent", modelName);
}
}

View File

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OllamaTextEmbeddingServiceTests.cs
using AutoGen.Tests;
using FluentAssertions;
namespace Autogen.Ollama.Tests;
public class OllamaTextEmbeddingServiceTests
{
[ApiKeyFact("OLLAMA_HOST", "OLLAMA_EMBEDDING_MODEL_NAME")]
public async Task GenerateAsync_ReturnsEmbeddings_WhenApiResponseIsSuccessful()
{
string host = Environment.GetEnvironmentVariable("OLLAMA_HOST")
?? throw new InvalidOperationException("OLLAMA_HOST is not set.");
string embeddingModelName = Environment.GetEnvironmentVariable("OLLAMA_EMBEDDING_MODEL_NAME")
?? throw new InvalidOperationException("OLLAMA_EMBEDDING_MODEL_NAME is not set.");
var httpClient = new HttpClient
{
BaseAddress = new Uri(host)
};
var request = new TextEmbeddingsRequest { Model = embeddingModelName, Prompt = "Llamas are members of the camelid family", };
var service = new OllamaTextEmbeddingService(httpClient);
TextEmbeddingsResponse response = await service.GenerateAsync(request);
response.Should().NotBeNull();
}
}