mirror of https://github.com/microsoft/autogen.git
Merge branch 'main' into add-chainlit-to-distributed-group-chat
This commit is contained in:
commit
d32a6b1102
|
@ -118,4 +118,4 @@ message ReadmeRequested {
|
|||
</ItemGroup>
|
||||
```
|
||||
|
||||
You can send messages using the [```Microsoft.AutoGen.Agents.AgentClient``` class](autogen/dotnet/src/Microsoft.AutoGen/Agents/AgentClient.cs). Messages are wrapped in [the CloudEvents specification](https://cloudevents.io) and sent to the event bus.
|
||||
You can send messages using the [```Microsoft.AutoGen.Agents.AgentWorker``` class](autogen/dotnet/src/Microsoft.AutoGen/Agents/AgentWorker.cs). Messages are wrapped in [the CloudEvents specification](https://cloudevents.io) and sent to the event bus.
|
||||
|
|
|
@ -117,7 +117,7 @@ message ReadmeRequested {
|
|||
</ItemGroup>
|
||||
```
|
||||
|
||||
You can send messages using the [```Microsoft.AutoGen.Agents``` class](autogen/dotnet/src/Microsoft.AutoGen/Agents/AgentClient.cs). Messages are wrapped in [the CloudEvents specification](https://cloudevents.io) and sent to the event bus.
|
||||
You can send messages using the [```Microsoft.AutoGen.Agents.AgentWorker``` class](autogen/dotnet/src/Microsoft.AutoGen/Agents/AgentWorker.cs). Messages are wrapped in [the CloudEvents specification](https://cloudevents.io) and sent to the event bus.
|
||||
|
||||
### Managing State
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ builder.AddAgentWorker(builder.Configuration["AGENT_HOST"]!)
|
|||
//.AddAgent<Sandbox>(nameof(Sandbox))
|
||||
.AddAgent<Hubber>(nameof(Hubber));
|
||||
|
||||
builder.Services.AddSingleton<AgentClient>();
|
||||
builder.Services.AddSingleton<AgentWorker>();
|
||||
builder.Services.AddSingleton<WebhookEventProcessor, GithubWebHookProcessor>();
|
||||
builder.Services.AddSingleton<GithubAuthService>();
|
||||
builder.Services.AddSingleton<IManageAzure, AzureService>();
|
||||
|
|
|
@ -10,10 +10,10 @@ using Octokit.Webhooks.Models;
|
|||
|
||||
namespace DevTeam.Backend;
|
||||
|
||||
public sealed class GithubWebHookProcessor(ILogger<GithubWebHookProcessor> logger, AgentClient client) : WebhookEventProcessor
|
||||
public sealed class GithubWebHookProcessor(ILogger<GithubWebHookProcessor> logger, AgentWorker client) : WebhookEventProcessor
|
||||
{
|
||||
private readonly ILogger<GithubWebHookProcessor> _logger = logger;
|
||||
private readonly AgentClient _client = client;
|
||||
private readonly AgentWorker _client = client;
|
||||
|
||||
protected override async Task ProcessIssuesWebhookAsync(WebhookHeaders headers, IssuesEvent issuesEvent, IssuesAction action)
|
||||
{
|
||||
|
|
|
@ -8,18 +8,17 @@ using Microsoft.Extensions.Logging;
|
|||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public abstract class AgentBase
|
||||
public abstract class AgentBase : IAgentBase
|
||||
{
|
||||
public static readonly ActivitySource s_source = new("AutoGen.Agent");
|
||||
public AgentId AgentId => _context.AgentId;
|
||||
private readonly object _lock = new();
|
||||
private readonly Dictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];
|
||||
|
||||
private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
|
||||
private readonly IAgentContext _context;
|
||||
|
||||
protected internal AgentId AgentId => _context.AgentId;
|
||||
protected internal ILogger Logger => _context.Logger;
|
||||
protected internal IAgentContext Context => _context;
|
||||
public IAgentContext Context => _context;
|
||||
protected readonly EventTypes EventTypes;
|
||||
|
||||
protected AgentBase(IAgentContext context, EventTypes eventTypes)
|
||||
|
@ -54,7 +53,7 @@ public abstract class AgentBase
|
|||
}
|
||||
}
|
||||
|
||||
internal void ReceiveMessage(Message message) => _mailbox.Writer.TryWrite(message);
|
||||
public void ReceiveMessage(Message message) => _mailbox.Writer.TryWrite(message);
|
||||
|
||||
private async Task RunMessagePump()
|
||||
{
|
||||
|
@ -79,7 +78,7 @@ public abstract class AgentBase
|
|||
}
|
||||
}
|
||||
|
||||
private async Task HandleRpcMessage(Message msg)
|
||||
protected internal async Task HandleRpcMessage(Message msg)
|
||||
{
|
||||
switch (msg.MessageCase)
|
||||
{
|
||||
|
@ -108,12 +107,12 @@ public abstract class AgentBase
|
|||
break;
|
||||
}
|
||||
}
|
||||
protected async Task Store(AgentState state)
|
||||
public async Task Store(AgentState state)
|
||||
{
|
||||
await _context.Store(state).ConfigureAwait(false);
|
||||
return;
|
||||
}
|
||||
protected async Task<T> Read<T>(AgentId agentId) where T : IMessage, new()
|
||||
public async Task<T> Read<T>(AgentId agentId) where T : IMessage, new()
|
||||
{
|
||||
var agentstate = await _context.Read(agentId).ConfigureAwait(false);
|
||||
return agentstate.FromAgentState<T>();
|
||||
|
@ -132,7 +131,6 @@ public abstract class AgentBase
|
|||
|
||||
completion.SetResult(response);
|
||||
}
|
||||
|
||||
private async Task OnRequestCore(RpcRequest request)
|
||||
{
|
||||
RpcResponse response;
|
||||
|
@ -193,7 +191,7 @@ public abstract class AgentBase
|
|||
return await completion.Task.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
protected async ValueTask PublishEvent(CloudEvent item)
|
||||
public async ValueTask PublishEvent(CloudEvent item)
|
||||
{
|
||||
var activity = s_source.StartActivity($"PublishEvent '{item.Type}'", ActivityKind.Client, Activity.Current?.Context ?? default);
|
||||
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");
|
||||
|
@ -225,5 +223,5 @@ public abstract class AgentBase
|
|||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
protected virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
|
||||
public virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
|
||||
}
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
using System.Diagnostics;
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
public sealed class AgentClient(ILogger<AgentClient> logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator,
|
||||
[FromKeyedServices("EventTypes")] EventTypes eventTypes)
|
||||
: AgentBase(new ClientContext(logger, runtime, distributedContextPropagator), eventTypes)
|
||||
{
|
||||
public async ValueTask PublishEventAsync(CloudEvent evt) => await PublishEvent(evt);
|
||||
public async ValueTask<RpcResponse> SendRequestAsync(AgentId target, string method, Dictionary<string, string> parameters) => await RequestAsync(target, method, parameters);
|
||||
public async ValueTask PublishEventAsync(string topic, IMessage evt)
|
||||
{
|
||||
await PublishEventAsync(evt.ToCloudEvent(topic)).ConfigureAwait(false);
|
||||
}
|
||||
private sealed class ClientContext(ILogger<AgentClient> logger, AgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator) : IAgentContext
|
||||
{
|
||||
public AgentId AgentId { get; } = new AgentId("client", Guid.NewGuid().ToString());
|
||||
public AgentBase? AgentInstance { get; set; }
|
||||
public ILogger Logger { get; } = logger;
|
||||
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
|
||||
public async ValueTask PublishEventAsync(CloudEvent @event)
|
||||
{
|
||||
await runtime.PublishEvent(@event).ConfigureAwait(false);
|
||||
}
|
||||
public async ValueTask SendRequestAsync(AgentBase agent, RpcRequest request)
|
||||
{
|
||||
await runtime.SendRequest(AgentInstance!, request).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response)
|
||||
{
|
||||
await runtime.SendResponse(response).ConfigureAwait(false);
|
||||
}
|
||||
public ValueTask Store(AgentState value)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
public ValueTask<AgentState> Read(AgentId agentId)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,9 +4,9 @@ using Microsoft.Extensions.Logging;
|
|||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
internal sealed class AgentContext(AgentId agentId, AgentWorkerRuntime runtime, ILogger<AgentBase> logger, DistributedContextPropagator distributedContextPropagator) : IAgentContext
|
||||
internal sealed class AgentContext(AgentId agentId, IAgentWorkerRuntime runtime, ILogger<AgentBase> logger, DistributedContextPropagator distributedContextPropagator) : IAgentContext
|
||||
{
|
||||
private readonly AgentWorkerRuntime _runtime = runtime;
|
||||
private readonly IAgentWorkerRuntime _runtime = runtime;
|
||||
|
||||
public AgentId AgentId { get; } = agentId;
|
||||
public ILogger Logger { get; } = logger;
|
||||
|
@ -19,18 +19,18 @@ internal sealed class AgentContext(AgentId agentId, AgentWorkerRuntime runtime,
|
|||
}
|
||||
public async ValueTask SendRequestAsync(AgentBase agent, RpcRequest request)
|
||||
{
|
||||
await _runtime.SendRequest(agent, request);
|
||||
await _runtime.SendRequest(agent, request).ConfigureAwait(false);
|
||||
}
|
||||
public async ValueTask PublishEventAsync(CloudEvent @event)
|
||||
{
|
||||
await _runtime.PublishEvent(@event);
|
||||
await _runtime.PublishEvent(@event).ConfigureAwait(false);
|
||||
}
|
||||
public async ValueTask Store(AgentState value)
|
||||
{
|
||||
await _runtime.Store(value);
|
||||
await _runtime.Store(value).ConfigureAwait(false);
|
||||
}
|
||||
public async ValueTask<AgentState> Read(AgentId agentId)
|
||||
public ValueTask<AgentState> Read(AgentId agentId)
|
||||
{
|
||||
return await _runtime.Read(agentId);
|
||||
return _runtime.Read(agentId);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
using System.Diagnostics;
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
public sealed class AgentWorker(IAgentWorkerRuntime runtime, DistributedContextPropagator distributedContextPropagator,
|
||||
[FromKeyedServices("EventTypes")] EventTypes eventTypes, ILogger<AgentBase> logger)
|
||||
: AgentBase(new AgentContext(new AgentId("client", Guid.NewGuid().ToString()), runtime, logger, distributedContextPropagator), eventTypes)
|
||||
{
|
||||
public async ValueTask PublishEventAsync(CloudEvent evt) => await PublishEvent(evt);
|
||||
|
||||
public async ValueTask PublishEventAsync(string topic, IMessage evt)
|
||||
{
|
||||
await PublishEventAsync(evt.ToCloudEvent(topic)).ConfigureAwait(false);
|
||||
}
|
||||
}
|
|
@ -44,7 +44,7 @@ public static class AgentsApp
|
|||
{
|
||||
await StartAsync(builder, agents, local);
|
||||
}
|
||||
var client = Host.Services.GetRequiredService<AgentClient>() ?? throw new InvalidOperationException("Host not started");
|
||||
var client = Host.Services.GetRequiredService<AgentWorker>() ?? throw new InvalidOperationException("Host not started");
|
||||
await client.PublishEventAsync(topic, message).ConfigureAwait(false);
|
||||
return Host;
|
||||
}
|
||||
|
|
|
@ -10,12 +10,12 @@ using Microsoft.Extensions.Logging;
|
|||
|
||||
namespace Microsoft.AutoGen.Agents;
|
||||
|
||||
public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWorkerRuntime
|
||||
public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgentWorkerRuntime
|
||||
{
|
||||
private readonly object _channelLock = new();
|
||||
private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
|
||||
private readonly ConcurrentDictionary<(string Type, string Key), AgentBase> _agents = new();
|
||||
private readonly ConcurrentDictionary<string, (AgentBase Agent, string OriginalRequestId)> _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new();
|
||||
private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingRequests = new();
|
||||
private readonly Channel<Message> _outboundMessagesChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(1024)
|
||||
{
|
||||
AllowSynchronousContinuations = true,
|
||||
|
@ -26,19 +26,19 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
private readonly AgentRpc.AgentRpcClient _client;
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes;
|
||||
private readonly ILogger<AgentWorkerRuntime> _logger;
|
||||
private readonly ILogger<GrpcAgentWorkerRuntime> _logger;
|
||||
private readonly DistributedContextPropagator _distributedContextPropagator;
|
||||
private readonly CancellationTokenSource _shutdownCts;
|
||||
private AsyncDuplexStreamingCall<Message, Message>? _channel;
|
||||
private Task? _readTask;
|
||||
private Task? _writeTask;
|
||||
|
||||
public AgentWorkerRuntime(
|
||||
public GrpcAgentWorkerRuntime(
|
||||
AgentRpc.AgentRpcClient client,
|
||||
IHostApplicationLifetime hostApplicationLifetime,
|
||||
IServiceProvider serviceProvider,
|
||||
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes,
|
||||
ILogger<AgentWorkerRuntime> logger,
|
||||
ILogger<GrpcAgentWorkerRuntime> logger,
|
||||
DistributedContextPropagator distributedContextPropagator)
|
||||
{
|
||||
_client = client;
|
||||
|
@ -83,6 +83,13 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
message.Response.RequestId = request.OriginalRequestId;
|
||||
request.Agent.ReceiveMessage(message);
|
||||
break;
|
||||
case Message.MessageOneofCase.RegisterAgentTypeResponse:
|
||||
if (!message.RegisterAgentTypeResponse.Success)
|
||||
{
|
||||
throw new InvalidOperationException($"Failed to register agent: '{message.RegisterAgentTypeResponse.Error}'.");
|
||||
}
|
||||
break;
|
||||
|
||||
case Message.MessageOneofCase.CloudEvent:
|
||||
|
||||
// HACK: Send the message to an instance of each agent type
|
||||
|
@ -163,7 +170,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
}
|
||||
}
|
||||
|
||||
private AgentBase GetOrActivateAgent(AgentId agentId)
|
||||
private IAgentBase GetOrActivateAgent(AgentId agentId)
|
||||
{
|
||||
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
|
||||
{
|
||||
|
@ -197,6 +204,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
RegisterAgentTypeRequest = new RegisterAgentTypeRequest
|
||||
{
|
||||
Type = type,
|
||||
RequestId = Guid.NewGuid().ToString(),
|
||||
//TopicTypes = { topicTypes },
|
||||
//StateType = state?.Name,
|
||||
//Events = { events }
|
||||
|
@ -211,7 +219,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
await WriteChannelAsync(new Message { Response = response }).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
public async ValueTask SendRequest(AgentBase agent, RpcRequest request)
|
||||
public async ValueTask SendRequest(IAgentBase agent, RpcRequest request)
|
||||
{
|
||||
_logger.LogInformation("[{AgentId}] Sending request '{Request}'.", agent.AgentId, request);
|
||||
var requestId = Guid.NewGuid().ToString();
|
||||
|
@ -322,10 +330,6 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable, IAgentWork
|
|||
_channel?.Dispose();
|
||||
}
|
||||
}
|
||||
public ValueTask SendRequest(RpcRequest request)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
public ValueTask Store(AgentState value)
|
||||
{
|
||||
var agentId = value.AgentId ?? throw new InvalidOperationException("AgentId is required when saving AgentState.");
|
|
@ -49,9 +49,9 @@ public static class HostBuilderExtensions
|
|||
});
|
||||
});
|
||||
builder.Services.TryAddSingleton(DistributedContextPropagator.Current);
|
||||
builder.Services.AddSingleton<AgentClient>();
|
||||
builder.Services.AddSingleton<AgentWorkerRuntime>();
|
||||
builder.Services.AddSingleton<IHostedService>(sp => sp.GetRequiredService<AgentWorkerRuntime>());
|
||||
builder.Services.AddSingleton<IAgentWorkerRuntime, GrpcAgentWorkerRuntime>();
|
||||
builder.Services.AddSingleton<IHostedService>(sp => (IHostedService)sp.GetRequiredService<IAgentWorkerRuntime>());
|
||||
builder.Services.AddSingleton<AgentWorker>();
|
||||
builder.Services.AddKeyedSingleton("EventTypes", (sp, key) =>
|
||||
{
|
||||
var interfaceType = typeof(IMessage);
|
||||
|
@ -111,7 +111,7 @@ public sealed class AgentTypes(Dictionary<string, Type> types)
|
|||
.SelectMany(assembly => assembly.GetTypes())
|
||||
.Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(AgentBase))
|
||||
&& !type.IsAbstract
|
||||
&& !type.Name.Equals("AgentClient"))
|
||||
&& !type.Name.Equals("AgentWorker"))
|
||||
.ToDictionary(type => type.Name, type => type);
|
||||
|
||||
return new AgentTypes(agents);
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Abstractions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Agents
|
||||
{
|
||||
public interface IAgentBase
|
||||
{
|
||||
// Properties
|
||||
string AgentId { get; }
|
||||
ILogger Logger { get; }
|
||||
AgentId AgentId { get; }
|
||||
IAgentContext Context { get; }
|
||||
|
||||
// Methods
|
||||
Task CallHandler(CloudEvent item);
|
||||
Task<RpcResponse> HandleRequest(RpcRequest request);
|
||||
Task Start();
|
||||
Task ReceiveMessage(Message message);
|
||||
void ReceiveMessage(Message message);
|
||||
Task Store(AgentState state);
|
||||
Task<T> Read<T>(AgentId agentId);
|
||||
Task PublishEvent(CloudEvent item);
|
||||
Task<T> Read<T>(AgentId agentId) where T : IMessage, new();
|
||||
ValueTask PublishEvent(CloudEvent item);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ namespace Microsoft.AutoGen.Agents;
|
|||
public interface IAgentWorkerRuntime
|
||||
{
|
||||
ValueTask PublishEvent(CloudEvent evt);
|
||||
ValueTask SendRequest(RpcRequest request);
|
||||
ValueTask SendRequest(IAgentBase agent, RpcRequest request);
|
||||
ValueTask SendResponse(RpcResponse response);
|
||||
ValueTask Store(AgentState value);
|
||||
ValueTask<AgentState> Read(AgentId agentId);
|
||||
}
|
||||
|
|
|
@ -141,8 +141,22 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
|
|||
{
|
||||
connection.AddSupportedType(msg.Type);
|
||||
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
|
||||
var success = false;
|
||||
var error = String.Empty;
|
||||
|
||||
await _gatewayRegistry.RegisterAgentType(msg.Type, _reference);
|
||||
try
|
||||
{
|
||||
await _gatewayRegistry.RegisterAgentType(msg.Type, _reference);
|
||||
success = true;
|
||||
}
|
||||
catch (InvalidOperationException exception)
|
||||
{
|
||||
error = $"Error registering agent type '{msg.Type}'.";
|
||||
_logger.LogWarning(exception, error);
|
||||
}
|
||||
var request_id = msg.RequestId;
|
||||
var response = new RegisterAgentTypeResponse { RequestId = request_id, Success = success, Error = error };
|
||||
await connection.SendMessage(new Message { RegisterAgentTypeResponse = response });
|
||||
}
|
||||
|
||||
private async ValueTask DispatchEventAsync(CloudEvent evt)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from ._assistant_agent import AssistantAgent
|
||||
from ._assistant_agent import AssistantAgent, Handoff
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
|
@ -7,6 +7,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
|||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"AssistantAgent",
|
||||
"Handoff",
|
||||
"CodeExecutorAgent",
|
||||
"CodingAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, List, Sequence
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall
|
||||
|
@ -15,11 +15,13 @@ from autogen_core.components.models import (
|
|||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
ResetMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
|
@ -31,6 +33,9 @@ event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
|||
class ToolCallEvent(BaseModel):
|
||||
"""A tool call event."""
|
||||
|
||||
source: str
|
||||
"""The source of the event."""
|
||||
|
||||
tool_calls: List[FunctionCall]
|
||||
"""The tool call message."""
|
||||
|
||||
|
@ -40,12 +45,58 @@ class ToolCallEvent(BaseModel):
|
|||
class ToolCallResultEvent(BaseModel):
|
||||
"""A tool call result event."""
|
||||
|
||||
source: str
|
||||
"""The source of the event."""
|
||||
|
||||
tool_call_results: List[FunctionExecutionResult]
|
||||
"""The tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Handoff(BaseModel):
|
||||
"""Handoff configuration for :class:`AssistantAgent`."""
|
||||
|
||||
target: str
|
||||
"""The name of the target agent to handoff to."""
|
||||
|
||||
description: str = Field(default=None)
|
||||
"""The description of the handoff such as the condition under which it should happen and the target agent's ability.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
name: str = Field(default=None)
|
||||
"""The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
|
||||
|
||||
message: str = Field(default=None)
|
||||
"""The message to the target agent.
|
||||
If not provided, it is generated from the target agent's name."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if values.get("description") is None:
|
||||
values["description"] = f"Handoff to {values['target']}."
|
||||
if values.get("name") is None:
|
||||
values["name"] = f"transfer_to_{values['target']}".lower()
|
||||
else:
|
||||
name = values["name"]
|
||||
if not isinstance(name, str):
|
||||
raise ValueError(f"Handoff name must be a string: {values['name']}")
|
||||
# Check if name is a valid identifier.
|
||||
if not name.isidentifier():
|
||||
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
|
||||
if values.get("message") is None:
|
||||
values["message"] = (
|
||||
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def handoff_tool(self) -> Tool:
|
||||
"""Create a handoff tool from this handoff configuration."""
|
||||
return FunctionTool(lambda: self.message, name=self.name, description=self.description)
|
||||
|
||||
|
||||
class AssistantAgent(BaseChatAgent):
|
||||
"""An agent that provides assistance with tool use.
|
||||
|
||||
|
@ -55,8 +106,52 @@ class AssistantAgent(BaseChatAgent):
|
|||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client to use for inference.
|
||||
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
|
||||
handoffs (List[Handoff | str] | None, optional): The handoff configurations for the agent, allowing it to transfer to other agents by responding with a HandoffMessage.
|
||||
If a handoff is a string, it should represent the target agent's name.
|
||||
description (str, optional): The description of the agent.
|
||||
system_message (str, optional): The system message for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If tool names are not unique.
|
||||
ValueError: If handoff names are not unique.
|
||||
ValueError: If handoff names are not unique from tool names.
|
||||
|
||||
Examples:
|
||||
|
||||
The following example demonstrates how to create an assistant agent with
|
||||
a model client and generate a response to a simple task.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client)
|
||||
|
||||
await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
|
||||
|
||||
|
||||
The following example demonstrates how to create an assistant agent with
|
||||
a model client and a tool, and generate a response to a simple task using the tool.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
|
||||
|
||||
async def get_current_time() -> str:
|
||||
return "The current time is 12:00 PM."
|
||||
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
|
||||
|
||||
await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -65,6 +160,7 @@ class AssistantAgent(BaseChatAgent):
|
|||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
|
||||
handoffs: List[Handoff | str] | None = None,
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
|
||||
):
|
||||
|
@ -84,33 +180,74 @@ class AssistantAgent(BaseChatAgent):
|
|||
self._tools.append(FunctionTool(tool, description=description))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
# Check if tool names are unique.
|
||||
tool_names = [tool.name for tool in self._tools]
|
||||
if len(tool_names) != len(set(tool_names)):
|
||||
raise ValueError(f"Tool names must be unique: {tool_names}")
|
||||
# Handoff tools.
|
||||
self._handoff_tools: List[Tool] = []
|
||||
self._handoffs: Dict[str, Handoff] = {}
|
||||
if handoffs is not None:
|
||||
for handoff in handoffs:
|
||||
if isinstance(handoff, str):
|
||||
handoff = Handoff(target=handoff)
|
||||
if isinstance(handoff, Handoff):
|
||||
self._handoff_tools.append(handoff.handoff_tool)
|
||||
self._handoffs[handoff.name] = handoff
|
||||
else:
|
||||
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
|
||||
# Check if handoff tool names are unique.
|
||||
handoff_tool_names = [tool.name for tool in self._handoff_tools]
|
||||
if len(handoff_tool_names) != len(set(handoff_tool_names)):
|
||||
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
|
||||
# Check if handoff tool names not in tool names.
|
||||
if any(name in tool_names for name in handoff_tool_names):
|
||||
raise ValueError(
|
||||
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
|
||||
)
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
# TODO: add special handling for handoff messages
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
if isinstance(msg, ResetMessage):
|
||||
self._model_context.clear()
|
||||
else:
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
|
||||
result = await self._model_client.create(
|
||||
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
# Add the response to the model context.
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content))
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
|
||||
# Detect handoff requests.
|
||||
handoffs: List[Handoff] = []
|
||||
for call in result.content:
|
||||
if call.name in self._handoffs:
|
||||
handoffs.append(self._handoffs[call.name])
|
||||
if len(handoffs) > 0:
|
||||
if len(handoffs) > 1:
|
||||
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
||||
# Respond with a handoff message.
|
||||
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
self._model_context, tools=self._tools, cancellation_token=cancellation_token
|
||||
self._model_context, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
|
||||
)
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
|
@ -127,9 +264,9 @@ class AssistantAgent(BaseChatAgent):
|
|||
) -> FunctionExecutionResult:
|
||||
"""Execute a tool call and return the result."""
|
||||
try:
|
||||
if not self._tools:
|
||||
if not self._tools + self._handoff_tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._tools if t.name == tool_call.name), None)
|
||||
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
|
|
|
@ -35,11 +35,21 @@ class StopMessage(BaseMessage):
|
|||
class HandoffMessage(BaseMessage):
|
||||
"""A message requesting handoff of a conversation to another agent."""
|
||||
|
||||
target: str
|
||||
"""The name of the target agent to handoff to."""
|
||||
|
||||
content: str
|
||||
"""The agent name to handoff the conversation to."""
|
||||
"""The handoff message to the target agent."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
|
||||
class ResetMessage(BaseMessage):
|
||||
"""A message requesting reset of the recipient's state in the current conversation."""
|
||||
|
||||
content: str
|
||||
"""The content for the reset message."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
|
||||
"""A message used by agents in a team."""
|
||||
|
||||
|
||||
|
@ -49,5 +59,6 @@ __all__ = [
|
|||
"MultiModalMessage",
|
||||
"StopMessage",
|
||||
"HandoffMessage",
|
||||
"ResetMessage",
|
||||
"ChatMessage",
|
||||
]
|
||||
|
|
|
@ -48,7 +48,7 @@ class MaxMessageTermination(TerminationCondition):
|
|||
self._message_count += len(messages)
|
||||
if self._message_count >= self._max_messages:
|
||||
return StopMessage(
|
||||
content=f"Maximal number of messages {self._max_messages} reached, current message count: {self._message_count}",
|
||||
content=f"Maximum number of messages {self._max_messages} reached, current message count: {self._message_count}",
|
||||
source="MaxMessageTermination",
|
||||
)
|
||||
return None
|
||||
|
|
|
@ -28,7 +28,12 @@ class BaseGroupChat(Team, ABC):
|
|||
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
group_chat_manager_class: type[BaseGroupChatManager],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
):
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
if len(participants) != len(set(participant.name for participant in participants)):
|
||||
|
@ -36,6 +41,7 @@ class BaseGroupChat(Team, ABC):
|
|||
self._participants = participants
|
||||
self._team_id = str(uuid.uuid4())
|
||||
self._base_group_chat_manager_class = group_chat_manager_class
|
||||
self._termination_condition = termination_condition
|
||||
|
||||
@abstractmethod
|
||||
def _create_group_chat_manager_factory(
|
||||
|
@ -109,7 +115,7 @@ class BaseGroupChat(Team, ABC):
|
|||
group_topic_type=group_topic_type,
|
||||
participant_topic_types=participant_topic_types,
|
||||
participant_descriptions=participant_descriptions,
|
||||
termination_condition=termination_condition,
|
||||
termination_condition=termination_condition or self._termination_condition,
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the group chat manager.
|
||||
|
|
|
@ -82,8 +82,12 @@ class RoundRobinGroupChat(BaseGroupChat):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent]):
|
||||
super().__init__(participants, group_chat_manager_class=RoundRobinGroupChatManager)
|
||||
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):
|
||||
super().__init__(
|
||||
participants,
|
||||
termination_condition=termination_condition,
|
||||
group_chat_manager_class=RoundRobinGroupChatManager,
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
|
|
|
@ -140,7 +140,8 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||
+ re.escape(name.replace("_", r"\_"))
|
||||
+ r")(?=\W)"
|
||||
)
|
||||
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
|
||||
# Pad the message to help with matching
|
||||
count = len(re.findall(regex, f" {message_content} "))
|
||||
if count > 0:
|
||||
mentions[name] = count
|
||||
return mentions
|
||||
|
@ -184,6 +185,7 @@ class SelectorGroupChat(BaseGroupChat):
|
|||
participants: List[ChatAgent],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
|
@ -194,7 +196,9 @@ Read the above conversation. Then select the next role from {participants} to pl
|
|||
""",
|
||||
allow_repeated_speaker: bool = False,
|
||||
):
|
||||
super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager)
|
||||
super().__init__(
|
||||
participants, termination_condition=termination_condition, group_chat_manager_class=SelectorGroupChatManager
|
||||
)
|
||||
# Validate the participants.
|
||||
if len(participants) < 2:
|
||||
raise ValueError("At least two participants are required for SelectorGroupChat.")
|
||||
|
|
|
@ -37,7 +37,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Select a speaker from the participants based on handoff message."""
|
||||
if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage):
|
||||
self._current_speaker = thread[-1].agent_message.content
|
||||
self._current_speaker = thread[-1].agent_message.target
|
||||
if self._current_speaker not in self._participant_topic_types:
|
||||
raise ValueError("The selected speaker in the handoff message is not a participant.")
|
||||
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id))
|
||||
|
@ -47,10 +47,45 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||
|
||||
|
||||
class Swarm(BaseGroupChat):
|
||||
"""(Experimental) A group chat that selects the next speaker based on handoff message only."""
|
||||
"""A group chat team that selects the next speaker based on handoff message only.
|
||||
|
||||
def __init__(self, participants: List[ChatAgent]):
|
||||
super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
|
||||
The first participant in the list of participants is the initial speaker.
|
||||
The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage` message
|
||||
sent by the current speaker. If no handoff message is sent, the current speaker
|
||||
continues to be the speaker.
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import Swarm
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["Bob"],
|
||||
system_message="You are Alice and you only answer questions about yourself.",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
|
||||
)
|
||||
|
||||
team = Swarm([agent1, agent2])
|
||||
await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None):
|
||||
super().__init__(
|
||||
participants, termination_condition=termination_condition, group_chat_manager_class=SwarmGroupChatManager
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.messages import StopMessage, TextMessage
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
|
@ -14,6 +18,10 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
|
||||
|
||||
|
||||
class _MockChatCompletion:
|
||||
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
|
||||
|
@ -107,3 +115,51 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert isinstance(result.messages[2], StopMessage)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
handoff = Handoff(target="agent2")
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=handoff.name,
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
tool_use_agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
handoffs=[handoff],
|
||||
)
|
||||
response = await tool_use_agent.on_messages(
|
||||
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
|
||||
)
|
||||
assert isinstance(response, HandoffMessage)
|
||||
assert response.target == "agent2"
|
||||
|
|
|
@ -10,6 +10,7 @@ from autogen_agentchat.agents import (
|
|||
AssistantAgent,
|
||||
BaseChatAgent,
|
||||
CodeExecutorAgent,
|
||||
Handoff,
|
||||
)
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import (
|
||||
|
@ -415,11 +416,11 @@ class _HandOffAgent(BaseChatAgent):
|
|||
self._next_agent = next_agent
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
return HandoffMessage(content=self._next_agent, source=self.name)
|
||||
return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm() -> None:
|
||||
async def test_swarm_handoff() -> None:
|
||||
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
|
@ -428,8 +429,81 @@ async def test_swarm() -> None:
|
|||
result = await team.run("task", termination_condition=MaxMessageTermination(6))
|
||||
assert len(result.messages) == 6
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "third_agent"
|
||||
assert result.messages[2].content == "first_agent"
|
||||
assert result.messages[3].content == "second_agent"
|
||||
assert result.messages[4].content == "third_agent"
|
||||
assert result.messages[5].content == "first_agent"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
assert result.messages[2].content == "Transferred to first_agent."
|
||||
assert result.messages[3].content == "Transferred to second_agent."
|
||||
assert result.messages[4].content == "Transferred to third_agent."
|
||||
assert result.messages[5].content == "Transferred to first_agent."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="handoff_to_agent2",
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
|
||||
agnet1 = AssistantAgent(
|
||||
"agent1",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
|
||||
)
|
||||
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
||||
team = Swarm([agnet1, agent2])
|
||||
result = await team.run("task", termination_condition=StopMessageTermination())
|
||||
assert len(result.messages) == 5
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "handoff to agent2"
|
||||
assert result.messages[2].content == "Transferred to agent1."
|
||||
assert result.messages[3].content == "Hello"
|
||||
assert result.messages[4].content == "TERMINATE"
|
||||
|
|
|
@ -37,18 +37,18 @@
|
|||
"text": [
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------- \n",
|
||||
"\u001b[91m[2024-10-23T12:15:51.582079]:\u001b[0m\n",
|
||||
"\u001b[91m[2024-10-29T15:48:06.329810]:\u001b[0m\n",
|
||||
"\n",
|
||||
"What is the weather in New York?\n",
|
||||
"--------------------------------------------------------------------------- \n",
|
||||
"\u001b[91m[2024-10-23T12:15:52.745820], writing_agent:\u001b[0m\n",
|
||||
"\u001b[91m[2024-10-29T15:48:08.085839], weather_agent:\u001b[0m\n",
|
||||
"\n",
|
||||
"The weather in New York is currently 73 degrees and sunny. TERMINATE\n",
|
||||
"The weather in New York is 73 degrees and sunny.\n",
|
||||
"--------------------------------------------------------------------------- \n",
|
||||
"\u001b[91m[2024-10-23T12:15:52.746210], Termination:\u001b[0m\n",
|
||||
"\u001b[91m[2024-10-29T15:48:08.086180], Termination:\u001b[0m\n",
|
||||
"\n",
|
||||
"Maximal number of messages 1 reached, current message count: 1\n",
|
||||
" TaskResult(messages=[TextMessage(source='user', content='What is the weather in New York?'), StopMessage(source='writing_agent', content='The weather in New York is currently 73 degrees and sunny. TERMINATE')])\n"
|
||||
"Maximum number of messages 2 reached, current message count: 2\n",
|
||||
" TaskResult(messages=[TextMessage(source='user', content='What is the weather in New York?'), TextMessage(source='weather_agent', content='The weather in New York is 73 degrees and sunny.')])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -56,13 +56,13 @@
|
|||
"import logging\n",
|
||||
"\n",
|
||||
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
||||
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
|
||||
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
||||
"from autogen_agentchat.task import MaxMessageTermination\n",
|
||||
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||
"from autogen_core.components.tools import FunctionTool\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"# set up logging. You can define your own logger\n",
|
||||
"logger = logging.getLogger(EVENT_LOGGER_NAME)\n",
|
||||
"logger.addHandler(ConsoleLogHandler())\n",
|
||||
"logger.setLevel(logging.INFO)\n",
|
||||
|
@ -73,22 +73,18 @@
|
|||
" return f\"The weather in {city} is 73 degrees and Sunny.\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# wrap the tool for use with the agent\n",
|
||||
"get_weather_tool = FunctionTool(get_weather, description=\"Get the weather for a city\")\n",
|
||||
"\n",
|
||||
"# define an agent\n",
|
||||
"weather_agent = ToolUseAssistantAgent(\n",
|
||||
" name=\"writing_agent\",\n",
|
||||
"weather_agent = AssistantAgent(\n",
|
||||
" name=\"weather_agent\",\n",
|
||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-2024-08-06\"),\n",
|
||||
" registered_tools=[get_weather_tool],\n",
|
||||
" tools=[get_weather],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# add the agent to a team\n",
|
||||
"agent_team = RoundRobinGroupChat([weather_agent])\n",
|
||||
"agent_team = RoundRobinGroupChat([weather_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
|
||||
"# Note: if running in a Python file directly you'll need to use asyncio.run(agent_team.run(...)) instead of await agent_team.run(...)\n",
|
||||
"result = await agent_team.run(\n",
|
||||
" task=\"What is the weather in New York?\",\n",
|
||||
" termination_condition=MaxMessageTermination(max_messages=1),\n",
|
||||
")\n",
|
||||
"print(\"\\n\", result)"
|
||||
]
|
||||
|
@ -97,7 +93,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code snippet above introduces two high level concepts in AgentChat: `Agent` and `Team`. An Agent helps us define what actions are taken when a message is received. Specifically, we use the `ToolUseAssistantAgent` preset - an agent that can be given a function that it can then use to address tasks. A Team helps us define the rules for how agents interact with each other. In the `RoundRobinGroupChat` team, agents receive messages in a sequential round-robin fashion. "
|
||||
"The code snippet above introduces two high level concepts in AgentChat: `Agent` and `Team`. An Agent helps us define what actions are taken when a message is received. Specifically, we use the `AssistantAgent` preset - an agent that can be given tools (functions) that it can then use to address tasks. A Team helps us define the rules for how agents interact with each other. In the `RoundRobinGroupChat` team, agents receive messages in a sequential round-robin fashion. "
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -113,7 +109,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "agnext",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -127,7 +123,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.6"
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -136,6 +136,76 @@
|
|||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Local within a Virtual Environment\n",
|
||||
"\n",
|
||||
"If you want the code to run within a virtual environment created as part of the application’s setup, you can specify a directory for the newly created environment and pass its context to {py:class}`~autogen_core.components.code_executor.LocalCommandLineCodeExecutor`. This setup allows the executor to use the specified virtual environment consistently throughout the application's lifetime, ensuring isolated dependencies and a controlled runtime environment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"CommandLineCodeResult(exit_code=0, output='', code_file='/Users/gziz/Dev/autogen/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/coding/tmp_code_d2a7db48799db3cc785156a11a38822a45c19f3956f02ec69b92e4169ecbf2ca.bash')"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import venv\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"from autogen_core.base import CancellationToken\n",
|
||||
"from autogen_core.components.code_executor import CodeBlock, LocalCommandLineCodeExecutor\n",
|
||||
"\n",
|
||||
"work_dir = Path(\"coding\")\n",
|
||||
"work_dir.mkdir(exist_ok=True)\n",
|
||||
"\n",
|
||||
"venv_dir = work_dir / \".venv\"\n",
|
||||
"venv_builder = venv.EnvBuilder(with_pip=True)\n",
|
||||
"venv_builder.create(venv_dir)\n",
|
||||
"venv_context = venv_builder.ensure_directories(venv_dir)\n",
|
||||
"\n",
|
||||
"local_executor = LocalCommandLineCodeExecutor(work_dir=work_dir, virtual_env_context=venv_context)\n",
|
||||
"await local_executor.execute_code_blocks(\n",
|
||||
" code_blocks=[\n",
|
||||
" CodeBlock(language=\"bash\", code=\"pip install matplotlib\"),\n",
|
||||
" ],\n",
|
||||
" cancellation_token=CancellationToken(),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As we can see, the code has executed successfully, and the installation has been isolated to the newly created virtual environment, without affecting our global environment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -154,7 +224,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -32,7 +32,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient, UserMessage\n",
|
||||
"from autogen_core.components.models import UserMessage\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"# Create an OpenAI model client.\n",
|
||||
"model_client = OpenAIChatCompletionClient(\n",
|
||||
|
@ -73,6 +74,24 @@
|
|||
"print(response.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"RequestUsage(prompt_tokens=15, completion_tokens=7)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Print the response token usage\n",
|
||||
"print(response.usage)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -85,7 +104,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -93,24 +112,26 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Streamed responses:\n",
|
||||
"In a secluded valley where the sun painted the sky with hues of gold, a solitary dragon named Bremora stood guard. Her emerald scales shimmered with an ancient light as she watched over the village below. Unlike her fiery kin, Bremora had no desire for destruction; her soul was bound by a promise to protect.\n",
|
||||
"In the heart of an ancient forest, beneath the shadow of snow-capped peaks, a dragon named Elara lived secretly for centuries. Elara was unlike any dragon from the old tales; her scales shimmered with a deep emerald hue, each scale engraved with symbols of lost wisdom. The villagers in the nearby valley spoke of mysterious lights dancing across the night sky, but none dared venture close enough to solve the enigma.\n",
|
||||
"\n",
|
||||
"Generations ago, a wise elder had befriended Bremora, offering her companionship instead of fear. In gratitude, she vowed to shield the village from calamity. Years passed, and children grew up believing in the legends of a watchful dragon who brought them prosperity and peace.\n",
|
||||
"One cold winter's eve, a young girl named Lira, brimming with curiosity and armed with the innocence of youth, wandered into Elara’s domain. Instead of fire and fury, she found warmth and a gentle gaze. The dragon shared stories of a world long forgotten and in return, Lira gifted her simple stories of human life, rich in laughter and scent of earth.\n",
|
||||
"\n",
|
||||
"One summer, an ominous storm threatened the valley, with ravenous winds and torrents of rain. Bremora rose into the tempest, her mighty wings defying the chaos. She channeled her breath—not of fire, but of warmth and tranquility—calming the storm and saving her cherished valley.\n",
|
||||
"\n",
|
||||
"When dawn broke and the village emerged unscathed, the people looked to the sky. There, Bremora soared gracefully, a guardian spirit woven into their lives, silently promising her eternal vigilance.\n",
|
||||
"From that night on, the villagers noticed subtle changes—the crops grew taller, and the air seemed sweeter. Elara had infused the valley with ancient magic, a guardian of balance, watching quietly as her new friend thrived under the stars. And so, Lira and Elara’s bond marked the beginning of a timeless friendship that spun tales of hope whispered through the leaves of the ever-verdant forest.\n",
|
||||
"\n",
|
||||
"------------\n",
|
||||
"\n",
|
||||
"The complete response:\n",
|
||||
"In a secluded valley where the sun painted the sky with hues of gold, a solitary dragon named Bremora stood guard. Her emerald scales shimmered with an ancient light as she watched over the village below. Unlike her fiery kin, Bremora had no desire for destruction; her soul was bound by a promise to protect.\n",
|
||||
"In the heart of an ancient forest, beneath the shadow of snow-capped peaks, a dragon named Elara lived secretly for centuries. Elara was unlike any dragon from the old tales; her scales shimmered with a deep emerald hue, each scale engraved with symbols of lost wisdom. The villagers in the nearby valley spoke of mysterious lights dancing across the night sky, but none dared venture close enough to solve the enigma.\n",
|
||||
"\n",
|
||||
"Generations ago, a wise elder had befriended Bremora, offering her companionship instead of fear. In gratitude, she vowed to shield the village from calamity. Years passed, and children grew up believing in the legends of a watchful dragon who brought them prosperity and peace.\n",
|
||||
"One cold winter's eve, a young girl named Lira, brimming with curiosity and armed with the innocence of youth, wandered into Elara’s domain. Instead of fire and fury, she found warmth and a gentle gaze. The dragon shared stories of a world long forgotten and in return, Lira gifted her simple stories of human life, rich in laughter and scent of earth.\n",
|
||||
"\n",
|
||||
"One summer, an ominous storm threatened the valley, with ravenous winds and torrents of rain. Bremora rose into the tempest, her mighty wings defying the chaos. She channeled her breath—not of fire, but of warmth and tranquility—calming the storm and saving her cherished valley.\n",
|
||||
"From that night on, the villagers noticed subtle changes—the crops grew taller, and the air seemed sweeter. Elara had infused the valley with ancient magic, a guardian of balance, watching quietly as her new friend thrived under the stars. And so, Lira and Elara’s bond marked the beginning of a timeless friendship that spun tales of hope whispered through the leaves of the ever-verdant forest.\n",
|
||||
"\n",
|
||||
"When dawn broke and the village emerged unscathed, the people looked to the sky. There, Bremora soared gracefully, a guardian spirit woven into their lives, silently promising her eternal vigilance.\n"
|
||||
"\n",
|
||||
"------------\n",
|
||||
"\n",
|
||||
"The token usage was:\n",
|
||||
"RequestUsage(prompt_tokens=0, completion_tokens=0)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -132,7 +153,10 @@
|
|||
" # The last response is a CreateResult object with the complete message.\n",
|
||||
" print(\"\\n\\n------------\\n\")\n",
|
||||
" print(\"The complete response:\", flush=True)\n",
|
||||
" print(response.content, flush=True)"
|
||||
" print(response.content, flush=True)\n",
|
||||
" print(\"\\n\\n------------\\n\")\n",
|
||||
" print(\"The token usage was:\", flush=True)\n",
|
||||
" print(response.usage, flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -142,7 +166,86 @@
|
|||
"```{note}\n",
|
||||
"The last response in the streaming response is always the final response\n",
|
||||
"of the type {py:class}`~autogen_core.components.models.CreateResult`.\n",
|
||||
"```"
|
||||
"```\n",
|
||||
"\n",
|
||||
"**NB the default usage response is to return zero values**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### A Note on Token usage counts with streaming example\n",
|
||||
"Comparing usage returns in the above Non Streaming `model_client.create(messages=messages)` vs streaming `model_client.create_stream(messages=messages)` we see differences.\n",
|
||||
"The non streaming response by default returns valid prompt and completion token usage counts. \n",
|
||||
"The streamed response by default returns zero values.\n",
|
||||
"\n",
|
||||
"as documented in the OPENAI API Reference an additional parameter `stream_options` can be specified to return valid usage counts. see [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)\n",
|
||||
"\n",
|
||||
"Only set this when you using streaming ie , using `create_stream` \n",
|
||||
"\n",
|
||||
"to enable this in `create_stream` set `extra_create_args={\"stream_options\": {\"include_usage\": True}},`\n",
|
||||
"\n",
|
||||
"- **Note whilst other API's like LiteLLM also support this, it is not always guarenteed that it is fully supported or correct**\n",
|
||||
"\n",
|
||||
"#### Streaming example with token usage\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Streamed responses:\n",
|
||||
"In a lush, emerald valley hidden by towering peaks, there lived a dragon named Ember. Unlike others of her kind, Ember cherished solitude over treasure, and the songs of the stream over the roar of flames. One misty dawn, a young shepherd stumbled into her sanctuary, lost and frightened. \n",
|
||||
"\n",
|
||||
"Instead of fury, he was met with kindness as Ember extended a wing, guiding him back to safety. In gratitude, the shepherd visited yearly, bringing tales of his world beyond the mountains. Over time, a friendship blossomed, binding man and dragon in shared stories and laughter.\n",
|
||||
"\n",
|
||||
"As the years passed, the legend of Ember the gentle-hearted spread far and wide, forever changing the way dragons were seen in the hearts of many.\n",
|
||||
"\n",
|
||||
"------------\n",
|
||||
"\n",
|
||||
"The complete response:\n",
|
||||
"In a lush, emerald valley hidden by towering peaks, there lived a dragon named Ember. Unlike others of her kind, Ember cherished solitude over treasure, and the songs of the stream over the roar of flames. One misty dawn, a young shepherd stumbled into her sanctuary, lost and frightened. \n",
|
||||
"\n",
|
||||
"Instead of fury, he was met with kindness as Ember extended a wing, guiding him back to safety. In gratitude, the shepherd visited yearly, bringing tales of his world beyond the mountains. Over time, a friendship blossomed, binding man and dragon in shared stories and laughter.\n",
|
||||
"\n",
|
||||
"As the years passed, the legend of Ember the gentle-hearted spread far and wide, forever changing the way dragons were seen in the hearts of many.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"------------\n",
|
||||
"\n",
|
||||
"The token usage was:\n",
|
||||
"RequestUsage(prompt_tokens=17, completion_tokens=146)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" UserMessage(content=\"Write a very short story about a dragon.\", source=\"user\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Create a stream.\n",
|
||||
"stream = model_client.create_stream(messages=messages, extra_create_args={\"stream_options\": {\"include_usage\": True}})\n",
|
||||
"\n",
|
||||
"# Iterate over the stream and print the responses.\n",
|
||||
"print(\"Streamed responses:\")\n",
|
||||
"async for response in stream: # type: ignore\n",
|
||||
" if isinstance(response, str):\n",
|
||||
" # A partial response is a string.\n",
|
||||
" print(response, flush=True, end=\"\")\n",
|
||||
" else:\n",
|
||||
" # The last response is a CreateResult object with the complete message.\n",
|
||||
" print(\"\\n\\n------------\\n\")\n",
|
||||
" print(\"The complete response:\", flush=True)\n",
|
||||
" print(response.content, flush=True)\n",
|
||||
" print(\"\\n\\n------------\\n\")\n",
|
||||
" print(\"The token usage was:\", flush=True)\n",
|
||||
" print(response.usage, flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -233,7 +336,8 @@
|
|||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core.base import MessageContext\n",
|
||||
"from autogen_core.components import RoutedAgent, message_handler\n",
|
||||
"from autogen_core.components.models import ChatCompletionClient, OpenAIChatCompletionClient, SystemMessage, UserMessage\n",
|
||||
"from autogen_core.components.models import ChatCompletionClient, SystemMessage, UserMessage\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
|
@ -500,7 +604,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "autogen_core",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -514,7 +618,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -3,12 +3,14 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Callable, ClassVar, List, Sequence, Union
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
@ -54,6 +56,36 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
|
|||
directory is the current directory ".".
|
||||
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
|
||||
functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions".
|
||||
virtual_env_context (Optional[SimpleNamespace], optional): The virtual environment context. Defaults to None.
|
||||
|
||||
Example:
|
||||
|
||||
How to use `LocalCommandLineCodeExecutor` with a virtual environment different from the one used to run the autogen application:
|
||||
Set up a virtual environment using the `venv` module, and pass its context to the initializer of `LocalCommandLineCodeExecutor`. This way, the executor will run code within the new environment.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import venv
|
||||
from pathlib import Path
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.code_executor import CodeBlock, LocalCommandLineCodeExecutor
|
||||
|
||||
work_dir = Path("coding")
|
||||
work_dir.mkdir(exist_ok=True)
|
||||
|
||||
venv_dir = work_dir / ".venv"
|
||||
venv_builder = venv.EnvBuilder(with_pip=True)
|
||||
venv_builder.create(venv_dir)
|
||||
venv_context = venv_builder.ensure_directories(venv_dir)
|
||||
|
||||
local_executor = LocalCommandLineCodeExecutor(work_dir=work_dir, virtual_env_context=venv_context)
|
||||
await local_executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="bash", code="pip install matplotlib"),
|
||||
],
|
||||
cancellation_token=CancellationToken(),
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
|
@ -86,6 +118,7 @@ $functions"""
|
|||
]
|
||||
] = [],
|
||||
functions_module: str = "functions",
|
||||
virtual_env_context: Optional[SimpleNamespace] = None,
|
||||
):
|
||||
if timeout < 1:
|
||||
raise ValueError("Timeout must be greater than or equal to 1.")
|
||||
|
@ -110,6 +143,8 @@ $functions"""
|
|||
else:
|
||||
self._setup_functions_complete = True
|
||||
|
||||
self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context
|
||||
|
||||
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
|
||||
"""(Experimental) Format the functions for a prompt.
|
||||
|
||||
|
@ -164,9 +199,14 @@ $functions"""
|
|||
cmd_args = ["-m", "pip", "install"]
|
||||
cmd_args.extend(required_packages)
|
||||
|
||||
if self._virtual_env_context:
|
||||
py_executable = self._virtual_env_context.env_exe
|
||||
else:
|
||||
py_executable = sys.executable
|
||||
|
||||
task = asyncio.create_task(
|
||||
asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
py_executable,
|
||||
*cmd_args,
|
||||
cwd=self._work_dir,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
|
@ -253,7 +293,17 @@ $functions"""
|
|||
f.write(code)
|
||||
file_names.append(written_file)
|
||||
|
||||
program = sys.executable if lang.startswith("python") else lang_to_cmd(lang)
|
||||
env = os.environ.copy()
|
||||
|
||||
if self._virtual_env_context:
|
||||
virtual_env_exe_abs_path = os.path.abspath(self._virtual_env_context.env_exe)
|
||||
virtual_env_bin_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
|
||||
env["PATH"] = f"{virtual_env_bin_abs_path}{os.pathsep}{env['PATH']}"
|
||||
|
||||
program = virtual_env_exe_abs_path if lang.startswith("python") else lang_to_cmd(lang)
|
||||
else:
|
||||
program = sys.executable if lang.startswith("python") else lang_to_cmd(lang)
|
||||
|
||||
# Wrap in a task to make it cancellable
|
||||
task = asyncio.create_task(
|
||||
asyncio.create_subprocess_exec(
|
||||
|
@ -262,6 +312,7 @@ $functions"""
|
|||
cwd=self._work_dir,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
)
|
||||
cancellation_token.link_future(task)
|
||||
|
|
|
@ -39,6 +39,7 @@ from openai.types.chat import (
|
|||
completion_create_params,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Unpack
|
||||
|
@ -555,6 +556,31 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
"""
|
||||
Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
|
||||
|
||||
Args:
|
||||
messages (Sequence[LLMMessage]): A sequence of messages to be processed.
|
||||
tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
|
||||
json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
|
||||
extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Default to `{}`.
|
||||
cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
|
||||
|
||||
Yields:
|
||||
AsyncGenerator[Union[str, CreateResult], None]: A generator yielding the completion results as they are produced.
|
||||
|
||||
In streaming, the default behaviour is not return token usage counts. See: [OpenAI API reference for possible args](https://platform.openai.com/docs/api-reference/chat/create).
|
||||
However `extra_create_args={"stream_options": {"include_usage": True}}` will (if supported by the accessed API)
|
||||
return a final chunk with usage set to a RequestUsage object having prompt and completion token counts,
|
||||
all preceding chunks will have usage as None. See: [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options).
|
||||
|
||||
Other examples of OPENAI supported arguments that can be included in `extra_create_args`:
|
||||
- `temperature` (float): Controls the randomness of the output. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic.
|
||||
- `max_tokens` (int): The maximum number of tokens to generate in the completion.
|
||||
- `top_p` (float): An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
|
||||
- `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
|
||||
- `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
|
||||
"""
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
|
@ -601,7 +627,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(stream_future)
|
||||
stream = await stream_future
|
||||
|
||||
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
|
||||
chunk = None
|
||||
stop_reason = None
|
||||
maybe_model = None
|
||||
content_deltas: List[str] = []
|
||||
|
@ -614,8 +641,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(chunk_future)
|
||||
chunk = await chunk_future
|
||||
choice = chunk.choices[0]
|
||||
stop_reason = choice.finish_reason
|
||||
|
||||
# to process usage chunk in streaming situations
|
||||
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...)
|
||||
# However the different api's
|
||||
# OPENAI api usage chunk produces no choices so need to check if there is a choice
|
||||
# liteLLM api usage chunk does produce choices
|
||||
choice = (
|
||||
chunk.choices[0]
|
||||
if len(chunk.choices) > 0
|
||||
else choice
|
||||
if chunk.usage is not None and stop_reason is not None
|
||||
else cast(ChunkChoice, None)
|
||||
)
|
||||
|
||||
# for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
|
||||
# set the stop_reason for the usage chunk to the prior stop_reason
|
||||
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
|
||||
maybe_model = chunk.model
|
||||
# First try get content
|
||||
if choice.delta.content is not None:
|
||||
|
@ -657,17 +699,21 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
model = maybe_model or create_args["model"]
|
||||
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
|
||||
|
||||
# TODO fix count token
|
||||
prompt_tokens = 0
|
||||
# prompt_tokens = count_token(messages, model=model)
|
||||
if chunk and chunk.usage:
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
else:
|
||||
prompt_tokens = 0
|
||||
|
||||
if stop_reason is None:
|
||||
raise ValueError("No stop reason found")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if len(content_deltas) > 1:
|
||||
content = "".join(content_deltas)
|
||||
completion_tokens = 0
|
||||
# completion_tokens = count_token(content, model=model)
|
||||
if chunk and chunk.usage:
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
completion_tokens = 0
|
||||
else:
|
||||
completion_tokens = 0
|
||||
# TODO: fix assumption that dict values were added in order and actually order by int index
|
||||
|
|
|
@ -2,8 +2,11 @@
|
|||
# Credit to original authors
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, TypeAlias
|
||||
|
||||
|
@ -143,3 +146,51 @@ print("hello world")
|
|||
assert "test.py" in result.code_file
|
||||
assert (temp_dir / Path("test.py")).resolve() == Path(result.code_file).resolve()
|
||||
assert (temp_dir / Path("test.py")).exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_executor_with_custom_venv() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
env_builder = venv.EnvBuilder(with_pip=True)
|
||||
env_builder.create(temp_dir)
|
||||
env_builder_context = env_builder.ensure_directories(temp_dir)
|
||||
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, virtual_env_context=env_builder_context)
|
||||
code_blocks = [
|
||||
# https://stackoverflow.com/questions/1871549/how-to-determine-if-python-is-running-inside-a-virtualenv
|
||||
CodeBlock(code="import sys; print(sys.prefix != sys.base_prefix)", language="python"),
|
||||
]
|
||||
cancellation_token = CancellationToken()
|
||||
result = await executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert result.output.strip() == "True"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_executor_with_custom_venv_in_local_relative_path() -> None:
|
||||
relative_folder_path = "tmp_dir"
|
||||
try:
|
||||
if not os.path.isdir(relative_folder_path):
|
||||
os.mkdir(relative_folder_path)
|
||||
|
||||
env_path = os.path.join(relative_folder_path, ".venv")
|
||||
env_builder = venv.EnvBuilder(with_pip=True)
|
||||
env_builder.create(env_path)
|
||||
env_builder_context = env_builder.ensure_directories(env_path)
|
||||
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=relative_folder_path, virtual_env_context=env_builder_context)
|
||||
code_blocks = [
|
||||
CodeBlock(code="import sys; print(sys.executable)", language="python"),
|
||||
]
|
||||
cancellation_token = CancellationToken()
|
||||
result = await executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Check if the expected venv has been used
|
||||
bin_path = os.path.abspath(env_builder_context.bin_path)
|
||||
assert Path(result.output.strip()).parent.samefile(bin_path)
|
||||
finally:
|
||||
if os.path.isdir(relative_folder_path):
|
||||
shutil.rmtree(relative_folder_path)
|
||||
|
|
|
@ -60,6 +60,7 @@ from openai.types.chat import (
|
|||
completion_create_params,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Unpack
|
||||
|
@ -556,6 +557,31 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
"""
|
||||
Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
|
||||
|
||||
Args:
|
||||
messages (Sequence[LLMMessage]): A sequence of messages to be processed.
|
||||
tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
|
||||
json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
|
||||
extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Default to `{}`.
|
||||
cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
|
||||
|
||||
Yields:
|
||||
AsyncGenerator[Union[str, CreateResult], None]: A generator yielding the completion results as they are produced.
|
||||
|
||||
In streaming, the default behaviour is not return token usage counts. See: [OpenAI API reference for possible args](https://platform.openai.com/docs/api-reference/chat/create).
|
||||
However `extra_create_args={"stream_options": {"include_usage": True}}` will (if supported by the accessed API)
|
||||
return a final chunk with usage set to a RequestUsage object having prompt and completion token counts,
|
||||
all preceding chunks will have usage as None. See: [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options).
|
||||
|
||||
Other examples of OPENAI supported arguments that can be included in `extra_create_args`:
|
||||
- `temperature` (float): Controls the randomness of the output. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic.
|
||||
- `max_tokens` (int): The maximum number of tokens to generate in the completion.
|
||||
- `top_p` (float): An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
|
||||
- `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
|
||||
- `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
|
||||
"""
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
|
@ -602,7 +628,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(stream_future)
|
||||
stream = await stream_future
|
||||
|
||||
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
|
||||
chunk = None
|
||||
stop_reason = None
|
||||
maybe_model = None
|
||||
content_deltas: List[str] = []
|
||||
|
@ -615,8 +642,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.link_future(chunk_future)
|
||||
chunk = await chunk_future
|
||||
choice = chunk.choices[0]
|
||||
stop_reason = choice.finish_reason
|
||||
|
||||
# to process usage chunk in streaming situations
|
||||
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...)
|
||||
# However the different api's
|
||||
# OPENAI api usage chunk produces no choices so need to check if there is a choice
|
||||
# liteLLM api usage chunk does produce choices
|
||||
choice = (
|
||||
chunk.choices[0]
|
||||
if len(chunk.choices) > 0
|
||||
else choice
|
||||
if chunk.usage is not None and stop_reason is not None
|
||||
else cast(ChunkChoice, None)
|
||||
)
|
||||
|
||||
# for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
|
||||
# set the stop_reason for the usage chunk to the prior stop_reason
|
||||
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
|
||||
maybe_model = chunk.model
|
||||
# First try get content
|
||||
if choice.delta.content is not None:
|
||||
|
@ -658,17 +700,21 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||
model = maybe_model or create_args["model"]
|
||||
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
|
||||
|
||||
# TODO fix count token
|
||||
prompt_tokens = 0
|
||||
# prompt_tokens = count_token(messages, model=model)
|
||||
if chunk and chunk.usage:
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
else:
|
||||
prompt_tokens = 0
|
||||
|
||||
if stop_reason is None:
|
||||
raise ValueError("No stop reason found")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if len(content_deltas) > 1:
|
||||
content = "".join(content_deltas)
|
||||
completion_tokens = 0
|
||||
# completion_tokens = count_token(content, model=model)
|
||||
if chunk and chunk.usage:
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
else:
|
||||
completion_tokens = 0
|
||||
else:
|
||||
completion_tokens = 0
|
||||
# TODO: fix assumption that dict values were added in order and actually order by int index
|
||||
|
|
|
@ -11,6 +11,7 @@ from autogen_core.components.models import (
|
|||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -24,28 +25,83 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceD
|
|||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MockChunkDefinition(BaseModel):
|
||||
# defining elements for diffentiating mocking chunks
|
||||
chunk_choice: ChunkChoice
|
||||
usage: CompletionUsage | None
|
||||
|
||||
|
||||
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
model = resolve_model(kwargs.get("model", "gpt-4o"))
|
||||
chunks = ["Hello", " Another Hello", " Yet Another Hello"]
|
||||
for chunk in chunks:
|
||||
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
|
||||
|
||||
# The openai api implementations (OpenAI and Litellm) stream chunks of tokens
|
||||
# with content as string, and then at the end a token with stop set and finally if
|
||||
# usage requested with `"stream_options": {"include_usage": True}` a chunk with the usage data
|
||||
mock_chunks = [
|
||||
# generate the list of mock chunk content
|
||||
MockChunkDefinition(
|
||||
chunk_choice=ChunkChoice(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
content=mock_chunk_content,
|
||||
role="assistant",
|
||||
),
|
||||
),
|
||||
usage=None,
|
||||
)
|
||||
for mock_chunk_content in mock_chunks_content
|
||||
] + [
|
||||
# generate the stop chunk
|
||||
MockChunkDefinition(
|
||||
chunk_choice=ChunkChoice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
content=None,
|
||||
role="assistant",
|
||||
),
|
||||
),
|
||||
usage=None,
|
||||
)
|
||||
]
|
||||
# generate the usage chunk if configured
|
||||
if kwargs.get("stream_options", {}).get("include_usage") is True:
|
||||
mock_chunks = mock_chunks + [
|
||||
# ---- API differences
|
||||
# OPENAI API does NOT create a choice
|
||||
# LITELLM (proxy) DOES create a choice
|
||||
# Not simulating all the API options, just implementing the LITELLM variant
|
||||
MockChunkDefinition(
|
||||
chunk_choice=ChunkChoice(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
content=None,
|
||||
role="assistant",
|
||||
),
|
||||
),
|
||||
usage=CompletionUsage(prompt_tokens=3, completion_tokens=3, total_tokens=6),
|
||||
)
|
||||
]
|
||||
elif kwargs.get("stream_options", {}).get("include_usage") is False:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
for mock_chunk in mock_chunks:
|
||||
await asyncio.sleep(0.1)
|
||||
yield ChatCompletionChunk(
|
||||
id="id",
|
||||
choices=[
|
||||
ChunkChoice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
content=chunk,
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
choices=[mock_chunk.chunk_choice],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
usage=mock_chunk.usage,
|
||||
)
|
||||
|
||||
|
||||
|
@ -95,17 +151,64 @@ async def test_openai_chat_completion_client_create(monkeypatch: pytest.MonkeyPa
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_openai_chat_completion_client_create_stream_with_usage(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
|
||||
async for chunk in client.create_stream(
|
||||
messages=[UserMessage(content="Hello", source="user")],
|
||||
# include_usage not the default of the OPENAI API and must be explicitly set
|
||||
extra_create_args={"stream_options": {"include_usage": True}},
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert chunks[0] == "Hello"
|
||||
assert chunks[1] == " Another Hello"
|
||||
assert chunks[2] == " Yet Another Hello"
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
||||
assert chunks[-1].usage == RequestUsage(prompt_tokens=3, completion_tokens=3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion_client_create_stream_no_usage_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in client.create_stream(
|
||||
messages=[UserMessage(content="Hello", source="user")],
|
||||
# include_usage not the default of the OPENAI APIis ,
|
||||
# it can be explicitly set
|
||||
# or just not declared which is the default
|
||||
# extra_create_args={"stream_options": {"include_usage": False}},
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert chunks[0] == "Hello"
|
||||
assert chunks[1] == " Another Hello"
|
||||
assert chunks[2] == " Yet Another Hello"
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
||||
assert chunks[-1].usage == RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion_client_create_stream_no_usage_explicit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
|
||||
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="api_key")
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in client.create_stream(
|
||||
messages=[UserMessage(content="Hello", source="user")],
|
||||
# include_usage is not the default of the OPENAI API ,
|
||||
# it can be explicitly set
|
||||
# or just not declared which is the default
|
||||
extra_create_args={"stream_options": {"include_usage": False}},
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert chunks[0] == "Hello"
|
||||
assert chunks[1] == " Another Hello"
|
||||
assert chunks[2] == " Yet Another Hello"
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
||||
assert chunks[-1].usage == RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Reference in New Issue