Wait for acknowledgment when sending message to gRPC channel (#4034)

This commit is contained in:
Reuben Bond 2024-11-01 12:59:50 -07:00 committed by GitHub
parent c3b2597e12
commit a4901f3ba8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 39 additions and 9 deletions

View File

@ -19,7 +19,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
private readonly ConcurrentDictionary<string, Type> _agentTypes = new(); private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new(); private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new();
private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingRequests = new(); private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingRequests = new();
private readonly Channel<Message> _outboundMessagesChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(1024) private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
{ {
AllowSynchronousContinuations = true, AllowSynchronousContinuations = true,
SingleReader = true, SingleReader = true,
@ -138,30 +138,34 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
var outboundMessages = _outboundMessagesChannel.Reader; var outboundMessages = _outboundMessagesChannel.Reader;
while (!_shutdownCts.IsCancellationRequested) while (!_shutdownCts.IsCancellationRequested)
{ {
(Message Message, TaskCompletionSource WriteCompletionSource) item = default;
try try
{ {
await outboundMessages.WaitToReadAsync().ConfigureAwait(false); await outboundMessages.WaitToReadAsync().ConfigureAwait(false);
// Read the next message if we don't already have an unsent message // Read the next message if we don't already have an unsent message
// waiting to be sent. // waiting to be sent.
if (!outboundMessages.TryRead(out var message)) if (!outboundMessages.TryRead(out item))
{ {
break; break;
} }
while (!_shutdownCts.IsCancellationRequested) while (!_shutdownCts.IsCancellationRequested)
{ {
await channel.RequestStream.WriteAsync(message, _shutdownCts.Token).ConfigureAwait(false); await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false);
item.WriteCompletionSource.TrySetResult();
break; break;
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
// Time to shut down. // Time to shut down.
item.WriteCompletionSource?.TrySetCanceled();
break; break;
} }
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
{ {
item.WriteCompletionSource?.TrySetException(ex);
_logger.LogError(ex, "Error writing to channel."); _logger.LogError(ex, "Error writing to channel.");
channel = RecreateChannel(channel); channel = RecreateChannel(channel);
continue; continue;
@ -169,9 +173,15 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
catch catch
{ {
// Shutdown requested. // Shutdown requested.
item.WriteCompletionSource?.TrySetCanceled();
break; break;
} }
} }
while (outboundMessages.TryRead(out var item))
{
item.WriteCompletionSource.TrySetCanceled();
}
} }
private IAgentBase GetOrActivateAgent(AgentId agentId) private IAgentBase GetOrActivateAgent(AgentId agentId)
@ -213,7 +223,8 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
//StateType = state?.Name, //StateType = state?.Name,
//Events = { events } //Events = { events }
} }
}).ConfigureAwait(false); },
_shutdownCts.Token).ConfigureAwait(false);
} }
} }
@ -229,17 +240,36 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
var requestId = Guid.NewGuid().ToString(); var requestId = Guid.NewGuid().ToString();
_pendingRequests[requestId] = (agent, request.RequestId); _pendingRequests[requestId] = (agent, request.RequestId);
request.RequestId = requestId; request.RequestId = requestId;
await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false); try
{
await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false);
}
catch (Exception exception)
{
if (_pendingRequests.TryRemove(requestId, out _))
{
agent.ReceiveMessage(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = exception.Message } });
}
}
} }
public async ValueTask PublishEvent(CloudEvent @event) public async ValueTask PublishEvent(CloudEvent @event)
{ {
await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false); try
{
await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.LogWarning(exception, "Failed to publish event '{Event}'.", @event);
}
} }
private async Task WriteChannelAsync(Message message) private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default)
{ {
await _outboundMessagesChannel.Writer.WriteAsync(message).ConfigureAwait(false); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false);
await tcs.Task.WaitAsync(cancellationToken);
} }
private AsyncDuplexStreamingCall<Message, Message> GetChannel() private AsyncDuplexStreamingCall<Message, Message> GetChannel()
@ -269,7 +299,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
if (_channel is null || _channel == channel) if (_channel is null || _channel == channel)
{ {
_channel?.Dispose(); _channel?.Dispose();
_channel = _client.OpenChannel(); _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token);
} }
} }
} }