diff --git a/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs b/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs index c52509876f..193f9dd2b6 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs @@ -19,7 +19,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent private readonly ConcurrentDictionary _agentTypes = new(); private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new(); private readonly ConcurrentDictionary _pendingRequests = new(); - private readonly Channel _outboundMessagesChannel = Channel.CreateBounded(new BoundedChannelOptions(1024) + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024) { AllowSynchronousContinuations = true, SingleReader = true, @@ -138,30 +138,34 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent var outboundMessages = _outboundMessagesChannel.Reader; while (!_shutdownCts.IsCancellationRequested) { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; try { await outboundMessages.WaitToReadAsync().ConfigureAwait(false); // Read the next message if we don't already have an unsent message // waiting to be sent. - if (!outboundMessages.TryRead(out var message)) + if (!outboundMessages.TryRead(out item)) { break; } while (!_shutdownCts.IsCancellationRequested) { - await channel.RequestStream.WriteAsync(message, _shutdownCts.Token).ConfigureAwait(false); + await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); break; } } catch (OperationCanceledException) { // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); break; } catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) { + item.WriteCompletionSource?.TrySetException(ex); _logger.LogError(ex, "Error writing to channel."); channel = RecreateChannel(channel); continue; @@ -169,9 +173,15 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent catch { // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); break; } } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } } private IAgentBase GetOrActivateAgent(AgentId agentId) @@ -213,7 +223,8 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent //StateType = state?.Name, //Events = { events } } - }).ConfigureAwait(false); + }, + _shutdownCts.Token).ConfigureAwait(false); } } @@ -229,17 +240,36 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent var requestId = Guid.NewGuid().ToString(); _pendingRequests[requestId] = (agent, request.RequestId); request.RequestId = requestId; - await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false); + try + { + await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false); + } + catch (Exception exception) + { + if (_pendingRequests.TryRemove(requestId, out _)) + { + agent.ReceiveMessage(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = exception.Message } }); + } + } } public async ValueTask PublishEvent(CloudEvent @event) { - await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false); + try + { + await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false); + } + catch (Exception exception) + { + _logger.LogWarning(exception, "Failed to publish event '{Event}'.", @event); + } } - private async Task WriteChannelAsync(Message message) + private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default) { - await _outboundMessagesChannel.Writer.WriteAsync(message).ConfigureAwait(false); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false); + await tcs.Task.WaitAsync(cancellationToken); } private AsyncDuplexStreamingCall GetChannel() @@ -269,7 +299,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent if (_channel is null || _channel == channel) { _channel?.Dispose(); - _channel = _client.OpenChannel(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); } } }