Add connection poisoning to aws-smithy-client (#2445)

* Add Connection Poisoning to aws-smithy-client

* Fix doc links

* Remove required tokio dependency from aws-smithy-client

* Remove external type exposed

* Rename, re-add tokio dependency

* Change IP to 127.0.0.1 to attempt to fix windows

* Add dns::Name to external types

* Remove non_exhaustive not needed

* Add client target to changelog
This commit is contained in:
Russell Cohen 2023-03-14 16:08:35 -04:00 committed by GitHub
parent b2c5eaa328
commit 61934da044
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1289 additions and 103 deletions

View File

@ -287,3 +287,28 @@ message = "The modules in generated client crates have been reorganized. See the
references = ["smithy-rs#2448"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[aws-sdk-rust]]
message = """Reconnect on transient errors.
If a transient error (timeout, 500, 502, 503) is encountered, the connection will be evicted from the pool and will not
be reused. This is enabled by default for all AWS services. It can be disabled by setting `RetryConfig::with_reconnect_mode`.
Although there is no API breakage from this change, it alters the client behavior in a way that may cause breakage for customers.
"""
references = ["aws-sdk-rust#160", "smithy-rs#2445"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "rcoh"
[[smithy-rs]]
message = """Reconnect on transient errors.
Note: **this behavior is disabled by default for generic clients**. It can be enabled with
`aws_smithy_client::Builder::reconnect_on_transient_errors`.
If a transient error (timeout, 500, 502, 503) is encountered, the connection will be evicted from the pool and will not
be reused.
"""
references = ["aws-sdk-rust#160", "smithy-rs#2445"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client" }
author = "rcoh"

View File

@ -30,7 +30,7 @@ aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" }
aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" }
hyper = { version = "0.14.12", default-features = false }
time = { version = "0.3.4", features = ["parsing"] }
tokio = { version = "1.8.4", features = ["sync"] }
tokio = { version = "1.13.1", features = ["sync"] }
tracing = { version = "0.1" }
# implementation detail of SSO credential caching

View File

@ -208,6 +208,7 @@ private class AwsFluentClientExtensions(types: Types) {
};
let mut builder = builder
.middleware(#{DynMiddleware}::new(#{Middleware}::new()))
.reconnect_mode(retry_config.reconnect_mode())
.retry_config(retry_config.into())
.operation_timeout_config(timeout_config.into());
builder.set_sleep_impl(sleep_impl);
@ -257,6 +258,7 @@ private fun renderCustomizableOperationSendMethod(
"combined_generics_decl" to combinedGenerics.declaration(),
"handle_generics_bounds" to handleGenerics.bounds(),
"SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig),
"SdkError" to RuntimeType.sdkError(runtimeConfig),
"ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig),
"ParseHttpResponse" to RuntimeType.parseHttpResponse(runtimeConfig),
)
@ -272,7 +274,7 @@ private fun renderCustomizableOperationSendMethod(
where
E: std::error::Error + Send + Sync + 'static,
O: #{ParseHttpResponse}<Output = Result<T, E>> + Send + Sync + Clone + 'static,
Retry: #{ClassifyRetry}<#{SdkSuccess}<T>, SdkError<E>> + Send + Sync + Clone,
Retry: #{ClassifyRetry}<#{SdkSuccess}<T>, #{SdkError}<E>> + Send + Sync + Clone,
{
self.handle.client.call(self.operation).await
}

View File

@ -0,0 +1,99 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::test_connection::wire_mock::{
check_matches, ReplayedEvent, WireLevelTestConnection,
};
use aws_smithy_client::{ev, match_events};
use aws_smithy_types::retry::{ReconnectMode, RetryConfig};
use aws_types::region::Region;
use aws_types::SdkConfig;
use std::sync::Arc;
#[tokio::test]
/// With `ReconnectMode::ReuseAllConnections` set on the retry config, the client must keep using
/// the same connection across retries: a single DNS lookup and a single connect are recorded even
/// though two 503 responses precede the successful one.
async fn disable_reconnects() {
    let replay = vec![
        ReplayedEvent::status(503),
        ReplayedEvent::status(503),
        ReplayedEvent::with_body("here-is-your-object"),
    ];
    let mock = WireLevelTestConnection::spinup(replay).await;

    // Explicitly opt out of connection poisoning.
    let retry_config =
        RetryConfig::standard().with_reconnect_mode(ReconnectMode::ReuseAllConnections);
    let sdk_config = SdkConfig::builder()
        .region(Region::from_static("us-east-2"))
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .sleep_impl(Arc::new(TokioSleep::new()))
        .endpoint_url(mock.endpoint_url())
        .http_connector(mock.http_connector())
        .retry_config(retry_config)
        .build();
    let client = aws_sdk_s3::Client::new(&sdk_config);

    let output = client
        .get_object()
        .bucket("bucket")
        .key("key")
        .send()
        .await
        .expect("succeeds after retries");
    let body = output.body.collect().await.unwrap().to_vec();
    assert_eq!(body, b"here-is-your-object");

    // One connection serves all three responses.
    match_events!(
        ev!(dns),
        ev!(connect),
        ev!(http(503)),
        ev!(http(503)),
        ev!(http(200))
    )(&mock.events());
}
#[tokio::test]
/// With the standard retry config (no explicit reconnect mode), every 503 response must be
/// followed by a fresh DNS lookup and a fresh connection before the next attempt.
async fn reconnect_on_503() {
    let replay = vec![
        ReplayedEvent::status(503),
        ReplayedEvent::status(503),
        ReplayedEvent::with_body("here-is-your-object"),
    ];
    let mock = WireLevelTestConnection::spinup(replay).await;

    let retry_config = RetryConfig::standard();
    let sdk_config = SdkConfig::builder()
        .region(Region::from_static("us-east-2"))
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .sleep_impl(Arc::new(TokioSleep::new()))
        .endpoint_url(mock.endpoint_url())
        .http_connector(mock.http_connector())
        .retry_config(retry_config)
        .build();
    let client = aws_sdk_s3::Client::new(&sdk_config);

    let output = client
        .get_object()
        .bucket("bucket")
        .key("key")
        .send()
        .await
        .expect("succeeds after retries");
    let body = output.body.collect().await.unwrap().to_vec();
    assert_eq!(body, b"here-is-your-object");

    // Each attempt gets its own dns + connect pair.
    match_events!(
        ev!(dns),
        ev!(connect),
        ev!(http(503)),
        ev!(dns),
        ev!(connect),
        ev!(http(503)),
        ev!(dns),
        ev!(connect),
        ev!(http(200))
    )(&mock.events());
}

View File

@ -41,7 +41,7 @@ class CustomizableOperationGenerator(
"Operation" to smithyHttp.resolve("operation::Operation"),
"Request" to smithyHttp.resolve("operation::Request"),
"Response" to smithyHttp.resolve("operation::Response"),
"ClassifyRetry" to smithyHttp.resolve("retry::ClassifyRetry"),
"ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig),
"RetryKind" to smithyTypes.resolve("retry::RetryKind"),
)
renderCustomizableOperationModule(this)
@ -150,6 +150,9 @@ class CustomizableOperationGenerator(
"ParseHttpResponse" to smithyHttp.resolve("response::ParseHttpResponse"),
"NewRequestPolicy" to smithyClient.resolve("retry::NewRequestPolicy"),
"SmithyRetryPolicy" to smithyClient.resolve("bounds::SmithyRetryPolicy"),
"ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig),
"SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig),
"SdkError" to RuntimeType.sdkError(runtimeConfig),
)
writer.rustTemplate(
@ -164,6 +167,7 @@ class CustomizableOperationGenerator(
E: std::error::Error + Send + Sync + 'static,
O: #{ParseHttpResponse}<Output = Result<T, E>> + Send + Sync + Clone + 'static,
Retry: Send + Sync + Clone,
Retry: #{ClassifyRetry}<#{SdkSuccess}<T>, #{SdkError}<E>> + Send + Sync + Clone,
<R as #{NewRequestPolicy}>::Policy: #{SmithyRetryPolicy}<O, T, E, Retry> + Clone,
{
self.handle.client.call(self.operation).await

View File

@ -1,5 +1,6 @@
[workspace]
members = [
"inlineable",
"aws-smithy-async",

View File

@ -9,12 +9,13 @@ repository = "https://github.com/awslabs/smithy-rs"
[features]
rt-tokio = ["aws-smithy-async/rt-tokio"]
test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls"]
test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls", "hyper/server", "hyper/h2", "tokio/full"]
native-tls = ["client-hyper", "hyper-tls", "rt-tokio"]
rustls = ["client-hyper", "hyper-rustls", "rt-tokio", "lazy_static"]
client-hyper = ["hyper"]
hyper-webpki-doctest-only = ["hyper-rustls/webpki-roots"]
[dependencies]
aws-smithy-async = { path = "../aws-smithy-async" }
aws-smithy-http = { path = "../aws-smithy-http" }
@ -25,7 +26,7 @@ bytes = "1"
fastrand = "1.4.0"
http = "0.2.3"
http-body = "0.4.4"
hyper = { version = "0.14.12", features = ["client", "http2", "http1", "tcp"], optional = true }
hyper = { version = "0.14.25", features = ["client", "http2", "http1", "tcp"], optional = true }
# cargo does not support optional test dependencies, so to completely disable rustls when
# the native-tls feature is enabled, we need to add the webpki-roots feature here.
# https://github.com/rust-lang/cargo/issues/1596
@ -34,7 +35,7 @@ hyper-tls = { version = "0.5.0", optional = true }
lazy_static = { version = "1", optional = true }
pin-project-lite = "0.2.7"
serde = { version = "1", features = ["derive"], optional = true }
tokio = { version = "1.8.4" }
tokio = { version = "1.13.1" }
tower = { version = "0.4.6", features = ["util", "retry"] }
tracing = "0.1"
@ -44,6 +45,9 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1.8.4", features = ["full", "test-util"] }
tower-test = "0.4.0"
tracing-subscriber = "0.3.16"
tracing-test = "0.2.4"
[package.metadata.docs.rs]
all-features = true

View File

@ -21,10 +21,12 @@ allowed_external_types = [
"tokio::io::async_read::AsyncRead",
"tokio::io::async_write::AsyncWrite",
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `test-utils` feature
"bytes::bytes::Bytes",
"serde::ser::Serialize",
"serde::de::Deserialize",
"hyper::client::connect::dns::Name",
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if we want to continue exposing tower_layer
"tower_layer::Layer",

View File

@ -7,6 +7,7 @@ use crate::{bounds, erase, retry, Client};
use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_types::retry::ReconnectMode;
use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig};
use std::sync::Arc;
@ -37,6 +38,12 @@ pub struct Builder<C = (), M = (), R = retry::Standard> {
retry_policy: MaybeRequiresSleep<R>,
operation_timeout_config: Option<OperationTimeoutConfig>,
sleep_impl: Option<Arc<dyn AsyncSleep>>,
reconnect_mode: Option<ReconnectMode>,
}
/// Reconnect mode used when the builder has no explicit [`ReconnectMode`].
///
/// Transitional default: connection poisoning is disabled, so all connections are reused.
const fn default_reconnect_mode() -> ReconnectMode {
ReconnectMode::ReuseAllConnections
}
impl<C, M> Default for Builder<C, M>
@ -55,6 +62,7 @@ where
),
operation_timeout_config: None,
sleep_impl: default_async_sleep(),
reconnect_mode: Some(default_reconnect_mode()),
}
}
}
@ -173,6 +181,7 @@ impl<M, R> Builder<(), M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
@ -229,6 +238,7 @@ impl<C, R> Builder<C, (), R> {
operation_timeout_config: self.operation_timeout_config,
middleware,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
@ -280,6 +290,7 @@ impl<C, M> Builder<C, M, retry::Standard> {
operation_timeout_config: self.operation_timeout_config,
middleware: self.middleware,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
}
@ -347,6 +358,7 @@ impl<C, M, R> Builder<C, M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
@ -361,9 +373,41 @@ impl<C, M, R> Builder<C, M, R> {
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
/// Set the [`ReconnectMode`] for the retry strategy
///
/// This is the consuming (chainable) counterpart of [`Builder::set_reconnect_mode`].
///
/// By default, no reconnection occurs.
///
/// When enabled and a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host.
pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self {
self.set_reconnect_mode(Some(reconnect_mode));
self
}
/// Set the [`ReconnectMode`] for the retry strategy
///
/// By default, no reconnection occurs. Passing `None` clears any explicit choice, in which case
/// [`Builder::build`] falls back to that default.
///
/// When enabled and a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host.
pub fn set_reconnect_mode(&mut self, reconnect_mode: Option<ReconnectMode>) -> &mut Self {
self.reconnect_mode = reconnect_mode;
self
}
/// Enable reconnection on transient errors
///
/// This behavior is opt-in: it is shorthand for setting the reconnect mode to
/// [`ReconnectMode::ReconnectOnTransientError`].
///
/// Once enabled, when a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host but may increase the load on
/// the server.
pub fn reconnect_on_transient_errors(self) -> Self {
self.reconnect_mode(ReconnectMode::ReconnectOnTransientError)
}
/// Build a Smithy service [`Client`].
pub fn build(self) -> Client<C, M, R> {
let operation_timeout_config = self
@ -392,6 +436,7 @@ impl<C, M, R> Builder<C, M, R> {
middleware: self.middleware,
operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode.unwrap_or(default_reconnect_mode()),
}
}
}

View File

@ -61,6 +61,7 @@ where
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}
}
@ -101,6 +102,7 @@ where
retry_policy: self.retry_policy,
operation_timeout_config: self.operation_timeout_config,
sleep_impl: self.sleep_impl,
reconnect_mode: self.reconnect_mode,
}
}

View File

@ -92,13 +92,22 @@ use crate::never::stream::EmptyStream;
use aws_smithy_async::future::timeout::TimedOutError;
use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::ErrorKind;
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use http::{Extensions, Uri};
use hyper::client::connect::{
capture_connection, CaptureConnection, Connected, Connection, HttpInfo,
};
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
use crate::erase::boxclone::BoxFuture;
use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{BoxError, Service};
@ -107,8 +116,30 @@ use tower::{BoxError, Service};
/// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`Adapter::builder`]. For examples
/// see [the module documentation](crate::hyper_ext).
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Adapter<C>(HttpReadTimeout<hyper::Client<ConnectTimeout<C>, SdkBody>>);
pub struct Adapter<C> {
client: HttpReadTimeout<hyper::Client<ConnectTimeout<C>, SdkBody>>,
}
/// Build smithy [`ConnectionMetadata`] from a hyper [`CaptureConnection`].
///
/// Returns `None` when the capture holds no connection metadata. Otherwise the result carries
/// the proxy flag, the remote address (when hyper attached an [`HttpInfo`] extra) and a callback
/// that poisons whichever connection the capture refers to at the time it is invoked.
fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
    // The poison callback must own its handle; keep a clone for it.
    let poison_handle = capture_conn.clone();

    let metadata = capture_conn.connection_metadata();
    let conn = metadata.as_ref()?;

    let mut extras = Extensions::new();
    conn.get_extras(&mut extras);
    let remote_addr = extras.get::<HttpInfo>().map(|info| info.remote_addr());

    Some(ConnectionMetadata::new(
        conn.is_proxied(),
        remote_addr,
        move || match poison_handle.connection_metadata().as_ref() {
            Some(conn) => conn.poison(),
            None => tracing::trace!("no connection existed to poison"),
        },
    ))
}
impl<C> Service<http::Request<SdkBody>> for Adapter<C>
where
@ -121,20 +152,22 @@ where
type Response = http::Response<SdkBody>;
type Error = ConnectorError;
#[allow(clippy::type_complexity)]
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>,
>;
type Future = BoxFuture<Self::Response, Self::Error>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx).map_err(downcast_error)
self.client.poll_ready(cx).map_err(downcast_error)
}
fn call(&mut self, req: http::Request<SdkBody>) -> Self::Future {
let fut = self.0.call(req);
fn call(&mut self, mut req: http::Request<SdkBody>) -> Self::Future {
let capture_connection = capture_connection(&mut req);
if let Some(capture_smithy_connection) = req.extensions().get::<CaptureSmithyConnection>() {
capture_smithy_connection
.set_connection_retriever(move || extract_smithy_connection(&capture_connection));
}
let fut = self.client.call(req);
Box::pin(async move { Ok(fut.await.map_err(downcast_error)?.map(SdkBody::from)) })
}
}
@ -271,7 +304,9 @@ impl Builder {
),
None => HttpReadTimeout::no_timeout(base),
};
Adapter(read_timeout)
Adapter {
client: read_timeout,
}
}
/// Set the async sleep implementation used for timeouts
@ -343,7 +378,6 @@ mod timeout_middleware {
use pin_project_lite::pin_project;
use tower::BoxError;
use aws_smithy_async::future;
use aws_smithy_async::future::timeout::{TimedOutError, Timeout};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::rt::sleep::Sleep;
@ -493,7 +527,7 @@ mod timeout_middleware {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: future::timeout::Timeout::new(self.inner.call(req), sleep),
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP connect",
duration: *duration,
}
@ -522,7 +556,7 @@ mod timeout_middleware {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: future::timeout::Timeout::new(self.inner.call(req), sleep),
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP read",
duration: *duration,
}

View File

@ -26,6 +26,7 @@ pub mod bounds;
pub mod erase;
pub mod http_connector;
pub mod never;
mod poison;
pub mod retry;
pub mod timeout;
@ -50,14 +51,17 @@ pub mod hyper_ext;
#[doc(hidden)]
pub mod static_tests;
use crate::poison::PoisonLayer;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::response::ParseHttpResponse;
pub use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_http_tower::dispatch::DispatchLayer;
use aws_smithy_http_tower::parse_response::ParseResponseLayer;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::ProvideErrorKind;
use aws_smithy_types::retry::{ProvideErrorKind, ReconnectMode};
use aws_smithy_types::timeout::OperationTimeoutConfig;
use std::sync::Arc;
use timeout::ClientTimeoutParams;
@ -93,6 +97,7 @@ pub struct Client<
connector: Connector,
middleware: Middleware,
retry_policy: RetryPolicy,
reconnect_mode: ReconnectMode,
operation_timeout_config: OperationTimeoutConfig,
sleep_impl: Option<Arc<dyn AsyncSleep>>,
}
@ -140,6 +145,7 @@ where
E: std::error::Error + Send + Sync + 'static,
Retry: Send + Sync,
R::Policy: bounds::SmithyRetryPolicy<O, T, E, Retry>,
Retry: ClassifyRetry<SdkSuccess<T>, SdkError<E>>,
bounds::Parsed<<M as bounds::SmithyMiddleware<C>>::Service, O, Retry>:
Service<Operation<O, Retry>, Response = SdkSuccess<T>, Error = SdkError<E>> + Clone,
{
@ -159,6 +165,7 @@ where
E: std::error::Error + Send + Sync + 'static,
Retry: Send + Sync,
R::Policy: bounds::SmithyRetryPolicy<O, T, E, Retry>,
Retry: ClassifyRetry<SdkSuccess<T>, SdkError<E>>,
// This bound is not _technically_ inferred by all the previous bounds, but in practice it
// is because _we_ know that there is only implementation of Service for Parsed
// (ParsedResponseService), and it will apply as long as the bounds on C, M, and R hold,
@ -179,6 +186,7 @@ where
self.retry_policy
.new_request_policy(self.sleep_impl.clone()),
)
.layer(PoisonLayer::new(self.reconnect_mode))
.layer(TimeoutLayer::new(timeout_params.operation_attempt_timeout))
.layer(ParseResponseLayer::<O, Retry>::new())
// These layers can be considered as occurring in order. That is, first invoke the

View File

@ -0,0 +1,143 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Connection Poisoning
//!
//! The client supports behavior where on transient errors (e.g. timeouts, 503, etc.) it will ensure
//! that the offending connection is not reused. This happens to ensure that in the case where the
//! connection itself is broken (e.g. connected to a bad host) we don't reuse it for other requests.
//!
//! This relies on a series of mechanisms:
//! 1. [`CaptureSmithyConnection`] is a container which exists in the operation property bag. It is
//! inserted by this layer before the request is sent.
//! 2. The [`DispatchLayer`](aws_smithy_http_tower::dispatch::DispatchLayer) copies the field from operation extensions to HTTP request extensions.
//! 3. The HTTP layer (e.g. Hyper) sets [`ConnectionMetadata`](aws_smithy_http::connection::ConnectionMetadata)
//! when it is available.
//! 4. When the response comes back, if indicated, this layer invokes
//! [`ConnectionMetadata::poison`](aws_smithy_http::connection::ConnectionMetadata::poison).
//!
//! ### Why isn't this integrated into `retry.rs`?
//! If the request has a streaming body, we won't attempt to retry because [`Operation::try_clone()`] will
//! return `None`. Therefore, we need to handle this inside of the retry loop.
use std::future::Future;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_http::connection::CaptureSmithyConnection;
use aws_smithy_types::retry::{ErrorKind, ReconnectMode, RetryKind};
use pin_project_lite::pin_project;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// PoisonLayer that poisons connections depending on the error kind
///
/// This is the [`tower::Layer`] half: it only stores the [`ReconnectMode`] and hands it to every
/// [`PoisonService`] it creates.
pub(crate) struct PoisonLayer<S> {
// marker only: ties the layer to the service type it wraps without storing one
inner: PhantomData<S>,
// copied into each service produced by `layer`
mode: ReconnectMode,
}
impl<S> PoisonLayer<S> {
pub(crate) fn new(mode: ReconnectMode) -> Self {
Self {
inner: Default::default(),
mode,
}
}
}
// Implemented by hand so that cloning does not require `S: Clone`; only the mode is carried over.
impl<S> Clone for PoisonLayer<S> {
    fn clone(&self) -> Self {
        Self::new(self.mode)
    }
}
impl<S> tower::Layer<S> for PoisonLayer<S> {
    type Service = PoisonService<S>;

    /// Wrap `inner` in a [`PoisonService`] that uses this layer's reconnect mode.
    fn layer(&self, inner: S) -> Self::Service {
        let mode = self.mode;
        PoisonService { inner, mode }
    }
}
/// Service wrapper that installs a `CaptureSmithyConnection` on every operation and poisons the
/// captured connection when the response is classified as a transient error.
#[derive(Clone)]
pub(crate) struct PoisonService<S> {
// the wrapped service that actually dispatches the operation
inner: S,
// poisoning only happens in `ReconnectMode::ReconnectOnTransientError`
mode: ReconnectMode,
}
impl<H, R, S, O, E> tower::Service<Operation<H, R>> for PoisonService<S>
where
    R: ClassifyRetry<SdkSuccess<O>, SdkError<E>>,
    S: tower::Service<Operation<H, R>, Response = SdkSuccess<O>, Error = SdkError<E>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = PoisonServiceFuture<S::Future, R>;

    /// Readiness is entirely delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatch `req` through the wrapped service.
    ///
    /// Before dispatching, a fresh `CaptureSmithyConnection` is stored in the operation's
    /// property bag; a clone of it is kept in the returned future so the connection used for the
    /// request can be looked up once the response has been classified.
    fn call(&mut self, mut req: Operation<H, R>) -> Self::Future {
        // Grab the classifier first: `req` is moved into the inner service below.
        let classifier = req.retry_classifier().clone();

        let conn = CaptureSmithyConnection::new();
        req.properties_mut().insert(conn.clone());

        let inner = self.inner.call(req);
        PoisonServiceFuture {
            inner,
            classifier,
            conn,
            mode: self.mode,
        }
    }
}
pin_project! {
// Future returned by `PoisonService::call`. It resolves to the inner future's output unchanged;
// the poisoning decision is made as a side effect when the inner future completes.
pub struct PoisonServiceFuture<F, R> {
// response future of the wrapped service
#[pin]
inner: F,
// retry classifier cloned from the operation before it was dispatched
classifier: R,
// shared handle through which the HTTP layer exposes the connection that was used
conn: CaptureSmithyConnection,
// reconnect mode copied from the service
mode: ReconnectMode
}
}
impl<F, R, T, E> Future for PoisonServiceFuture<F, R>
where
F: Future<Output = Result<SdkSuccess<T>, SdkError<E>>>,
R: ClassifyRetry<SdkSuccess<T>, SdkError<E>>,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.inner.poll(cx) {
Poll::Ready(resp) => {
let retry_kind = this.classifier.classify_retry(resp.as_ref());
if this.mode == &ReconnectMode::ReconnectOnTransientError
&& retry_kind == RetryKind::Error(ErrorKind::TransientError)
{
if let Some(smithy_conn) = this.conn.get() {
tracing::info!("poisoning connection: {:?}", smithy_conn);
smithy_conn.poison();
} else {
tracing::trace!("No smithy connection found! The underlying HTTP connection never set a connection.");
}
}
Poll::Ready(resp)
}
Poll::Pending => Poll::Pending,
}
}
}

View File

@ -17,14 +17,15 @@ use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::{SdkError, SdkSuccess};
use tracing::Instrument;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use tracing::Instrument;
use crate::{SdkError, SdkSuccess};
/// A policy instantiator.
///
@ -292,9 +293,20 @@ impl RetryHandler {
fn should_retry_error(&self, error_kind: &ErrorKind) -> Option<(Self, Duration)> {
let quota_used = {
if self.local.attempts == self.config.max_attempts {
tracing::trace!(
attempts = self.local.attempts,
max_attempts = self.config.max_attempts,
"not retrying becuase we are out of attempts"
);
return None;
}
self.shared.quota_acquire(error_kind, &self.config)?
match self.shared.quota_acquire(error_kind, &self.config) {
Some(quota) => quota,
None => {
tracing::trace!(state = ?self.shared, "not retrying because no quota is available");
return None;
}
}
};
let backoff = calculate_exponential_backoff(
// Generate a random base multiplier to create jitter
@ -334,7 +346,9 @@ impl RetryHandler {
}
fn retry_for(&self, retry_kind: RetryKind) -> Option<BoxFuture<Self>> {
let (next, dur) = self.should_retry(&retry_kind)?;
let retry = self.should_retry(&retry_kind);
tracing::trace!(retry=?retry, retry_kind = ?retry_kind, "retry action");
let (next, dur) = retry?;
let sleep = match &self.sleep_impl {
Some(sleep) => sleep,
@ -377,6 +391,7 @@ where
) -> Option<Self::Future> {
let classifier = req.retry_classifier();
let retry_kind = classifier.classify_retry(result);
tracing::trace!(retry_kind = ?retry_kind, "retry classification");
self.retry_for(retry_kind)
}

View File

@ -90,7 +90,7 @@ impl tower::Service<http::Request<SdkBody>> for CaptureRequestHandler {
/// If response is `None`, it will reply with a 200 response with an empty body
///
/// Example:
/// ```rust,compile_fail
/// ```compile_fail
/// let (server, request) = capture_request(None);
/// let conf = aws_sdk_sts::Config::builder()
/// .http_connector(server)
@ -271,6 +271,347 @@ where
}
}
/// [`wire_mock`] contains utilities for mocking at the socket level
///
/// Other tools in this module actually operate at the `http::Request` / `http::Response` level. This
/// is useful, but it shortcuts the HTTP implementation (e.g. Hyper). [`wire_mock::WireLevelTestConnection`] binds
/// to an actual socket on the host
///
/// # Examples
/// ```
/// use tower::layer::util::Identity;
/// use aws_smithy_client::http_connector::ConnectorSettings;
/// use aws_smithy_client::{match_events, ev};
/// use aws_smithy_client::test_connection::wire_mock::check_matches;
/// # async fn example() {
/// use aws_smithy_client::test_connection::wire_mock::{ReplayedEvent, WireLevelTestConnection};
/// // This connection binds to a local address
/// let mock = WireLevelTestConnection::spinup(vec![
/// ReplayedEvent::status(503),
/// ReplayedEvent::status(200)
/// ]).await;
/// let client = aws_smithy_client::Client::builder()
/// .connector(mock.http_connector().connector(&ConnectorSettings::default(), None).unwrap())
/// .middleware(Identity::new())
/// .build();
/// /* do something with <client> */
/// // assert that you got the events you expected
/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());
/// # }
/// ```
pub mod wire_mock {
use bytes::Bytes;
use http::{Request, Response};
use hyper::client::connect::dns::Name;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Server};
use std::collections::HashSet;
use std::convert::Infallible;
use std::error::Error;
use hyper::client::HttpConnector as HyperHttpConnector;
use std::iter;
use std::iter::Once;
use std::net::{SocketAddr, TcpListener};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::spawn;
use tower::Service;
/// An event recorded by [`WireLevelTestConnection`]
#[derive(Debug, Clone)]
pub enum RecordedEvent {
/// A DNS lookup was performed.
// NOTE(review): presumably the string is the looked-up hostname — confirm against the resolver.
DnsLookup(String),
/// The mock server accepted a new connection.
NewConnection,
/// The mock server replied to a request with the contained replayed event.
Response(ReplayedEvent),
}
/// A single expectation consumed by [`check_matches`]: a predicate that returns `Err` with an
/// explanation when an event does not match, paired with a static description of what was
/// expected (used in the failure message when no event is left to match).
type Matcher = (
Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
&'static str,
);
/// Assert that `events` satisfy `matchers` pairwise and in order.
///
/// This method should only be used by the macro
///
/// # Panics
/// Panics with the zero-based index of the offending position when:
/// - a matcher rejects the event at the same position,
/// - there are more events than matchers, or
/// - there are more matchers than events.
#[doc(hidden)]
pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
    let mut events_iter = events.iter();
    let mut matcher_iter = matchers.iter();
    for idx in 0usize.. {
        let bail = |err: Box<dyn Error>| panic!("failed on event {}:\n {}", idx, err);
        match (events_iter.next(), matcher_iter.next()) {
            (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
            // both sequences are exhausted at the same time: everything matched
            (None, None) => return,
            (Some(event), None) => {
                bail(format!("got {:?} but no more events were expected", event).into())
            }
            // a matcher is left over, i.e. fewer events were recorded than expected
            (None, Some((_expect, msg))) => {
                bail(format!("expected {:?} but no more events were recorded", msg).into())
            }
        }
    }
}
/// Build one `(predicate, description)` pair for `check_matches` from a pattern.
///
/// The predicate returns `Ok(())` when the event matches `$expect` and otherwise an error
/// message containing both the stringified pattern and the actual event. The description is the
/// stringified pattern.
#[macro_export]
macro_rules! matcher {
($expect:tt) => {
(
Box::new(
|event: &::aws_smithy_client::test_connection::wire_mock::RecordedEvent| {
if !matches!(event, $expect) {
return Err(format!(
"expected `{}` but got {:?}",
stringify!($expect),
event
)
.into());
}
Ok(())
},
),
stringify!($expect),
)
};
}
/// Helper macro to generate a series of test expectations
///
/// Takes a comma-separated list of patterns (typically produced with `ev!`) and expands to a
/// closure that accepts the recorded events and checks them against the patterns, in order.
///
/// Note: the expansion calls `check_matches` by its unqualified name, so that function must be
/// in scope at the call site.
#[macro_export]
macro_rules! match_events {
($( $expect:pat),*) => {
|events| {
check_matches(events, &[$( ::aws_smithy_client::matcher!($expect) ),*]);
}
};
}
/// Helper to generate match expressions for events
///
/// Supported forms, each expanding to a pattern over `RecordedEvent`:
/// - `ev!(http(<status>))`: a response with the given status (the body is ignored)
/// - `ev!(dns)`: any DNS lookup
/// - `ev!(connect)`: a new connection
/// - `ev!(timeout)`: a replayed timeout
///
/// Note: the `http` and `timeout` forms name `ReplayedEvent` unqualified, so that type must be
/// in scope at the call site.
#[macro_export]
macro_rules! ev {
(http($status:expr)) => {
::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response(
ReplayedEvent::HttpResponse {
status: $status,
..
},
)
};
(dns) => {
::aws_smithy_client::test_connection::wire_mock::RecordedEvent::DnsLookup(_)
};
(connect) => {
::aws_smithy_client::test_connection::wire_mock::RecordedEvent::NewConnection
};
(timeout) => {
::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response(
ReplayedEvent::Timeout,
)
};
}
pub use {ev, match_events, matcher};
/// A scripted reaction of the mock server to one incoming request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ReplayedEvent {
/// Simulated timeout. The server marks the connection that received it as poisoned and panics
/// if that connection is used for a later request.
// NOTE(review): how the timeout itself is produced is defined outside this view — confirm.
Timeout,
/// Reply with the given HTTP status code and body.
HttpResponse { status: u16, body: Bytes },
}
impl ReplayedEvent {
    /// A `200` response with an empty body.
    pub fn ok() -> Self {
        Self::status(200)
    }

    /// A `200` response whose body is a copy of the bytes of `body`.
    pub fn with_body(body: &str) -> Self {
        let body = Bytes::copy_from_slice(body.as_bytes());
        Self::HttpResponse { status: 200, body }
    }

    /// A response with the given `status` and an empty body.
    pub fn status(status: u16) -> Self {
        let body = Bytes::new();
        Self::HttpResponse { status, body }
    }
}
use crate::erase::boxclone::BoxFuture;
use crate::http_connector::HttpConnector;
use crate::hyper_ext;
use aws_smithy_async::future::never::Never;
use tokio::sync::oneshot;
/// Test connection that starts a server bound to 127.0.0.1
///
/// See the [module docs](crate::test_connection::wire_mock) for a usage example.
///
/// Usage:
/// - Call [`WireLevelTestConnection::spinup`] to start the server
/// - Use [`WireLevelTestConnection::http_connector`] or [`dns_resolver`](WireLevelTestConnection::dns_resolver) to configure your client.
/// - Make requests to [`endpoint_url`](WireLevelTestConnection::endpoint_url).
/// - Once the test is complete, retrieve a list of events from [`WireLevelTestConnection::events`]
#[derive(Debug)]
pub struct WireLevelTestConnection {
// log of everything observed so far, shared with the server task
event_log: Arc<Mutex<Vec<RecordedEvent>>>,
// local address the listener is actually bound to
bind_addr: SocketAddr,
// when the sender is dropped, that stops the server
shutdown_hook: oneshot::Sender<()>,
}
impl WireLevelTestConnection {
/// Start a mock HTTP server on `127.0.0.1` (port chosen by the OS) that answers each incoming
/// request with the next entry of `response_events`, in the order given.
///
/// Every accepted connection and every response is appended to the event log returned by
/// `events`. A connection that served a [`ReplayedEvent::Timeout`] is remembered as poisoned.
/// The server keeps running until the returned value (and with it the shutdown sender) is
/// dropped.
///
/// # Panics
/// - Panics if the listener cannot be bound or its local address cannot be read.
/// - The request handler panics if a poisoned connection is reused, or if a request arrives
///   after all `response_events` have been consumed.
pub async fn spinup(mut response_events: Vec<ReplayedEvent>) -> Self {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let (tx, rx) = oneshot::channel();
    let listener_addr = listener.local_addr().unwrap();

    // Reverse once so that `pop()` hands out events in the order they were supplied.
    response_events.reverse();
    let handler_events = Arc::new(Mutex::new(response_events));

    let wire_events = Arc::new(Mutex::new(vec![]));
    let wire_log_for_service = wire_events.clone();
    let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();

    // Invoked once per accepted connection.
    let make_service = make_service_fn(move |connection: &AddrStream| {
        let poisoned_conns = poisoned_conns.clone();
        let events = handler_events.clone();
        let wire_log = wire_log_for_service.clone();
        let remote_addr = connection.remote_addr();
        tracing::info!("established connection: {:?}", connection);
        wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
        async move {
            // Invoked once per request on this connection.
            Ok::<_, Infallible>(service_fn(move |_: Request<hyper::Body>| {
                if poisoned_conns.lock().unwrap().contains(&remote_addr) {
                    tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
                    panic!("poisoned connection was reused!");
                }
                let next_event = events.lock().unwrap().pop();
                let wire_log = wire_log.clone();
                let poisoned_conns = poisoned_conns.clone();
                async move {
                    let next_event = next_event
                        .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
                    wire_log
                        .lock()
                        .unwrap()
                        .push(RecordedEvent::Response(next_event.clone()));
                    if next_event == ReplayedEvent::Timeout {
                        tracing::info!("{} is poisoned", remote_addr);
                        poisoned_conns.lock().unwrap().insert(remote_addr);
                    }
                    tracing::debug!("replying with {:?}", next_event);
                    generate_response_event(next_event).await
                }
            }))
        }
    });

    let server = Server::from_tcp(listener)
        .unwrap()
        .serve(make_service)
        .with_graceful_shutdown(async {
            // Resolves when the sender is dropped (or fires), which stops the server.
            rx.await.ok();
            tracing::info!("server shutdown!");
        });
    spawn(async move { server.await });

    Self {
        event_log: wire_events,
        bind_addr: listener_addr,
        shutdown_hook: tx,
    }
}
/// Retrieve the events recorded by this connection
pub fn events(&self) -> Vec<RecordedEvent> {
self.event_log.lock().unwrap().clone()
}
fn bind_addr(&self) -> SocketAddr {
self.bind_addr
}
pub fn dns_resolver(&self) -> LoggingDnsResolver {
let event_log = self.event_log.clone();
let bind_addr = self.bind_addr;
LoggingDnsResolver {
log: event_log,
socket_addr: bind_addr,
}
}
/// Prebuilt HTTP connector with correctly wired DNS resolver
///
/// **Note**: This must be used in tandem with [`Self::dns_resolver`]
pub fn http_connector(&self) -> HttpConnector {
let http_connector = HyperHttpConnector::new_with_resolver(self.dns_resolver());
hyper_ext::Adapter::builder().build(http_connector).into()
}
/// Endpoint to use when connecting
///
/// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address
pub fn endpoint_url(&self) -> String {
format!(
"http://this-url-is-converted-to-localhost.com:{}",
self.bind_addr().port()
)
}
pub fn shutdown(self) {
let _ = self.shutdown_hook.send(());
}
}
/// Turn a replayed event into the HTTP response the mock server sends.
///
/// For [`ReplayedEvent::Timeout`] this awaits a `Never` future, so no response is produced.
async fn generate_response_event(event: ReplayedEvent) -> Result<Response<Body>, Infallible> {
    match event {
        ReplayedEvent::Timeout => {
            Never::new().await;
            unreachable!()
        }
        ReplayedEvent::HttpResponse { status, body } => {
            let response = http::Response::builder()
                .status(status)
                .body(hyper::Body::from(body))
                .unwrap();
            Ok::<_, Infallible>(response)
        }
    }
}
/// DNS resolver that keeps a log of all lookups
///
/// Regardless of what hostname is requested, it will always return the same socket address.
#[derive(Clone, Debug)]
pub struct LoggingDnsResolver {
// shared event log; each lookup appends a `RecordedEvent::DnsLookup`
log: Arc<Mutex<Vec<RecordedEvent>>>,
// the single address returned for every lookup
socket_addr: SocketAddr,
}
impl Service<Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<Self::Response, Self::Error>;

    /// The resolver holds no resources, so it is always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Record the lookup of `req` in the event log and resolve it to the fixed socket address.
    ///
    /// The lookup is recorded when the returned future is polled, not when `call` is invoked.
    fn call(&mut self, req: Name) -> Self::Future {
        let sock_addr = self.socket_addr;
        let log = self.log.clone();
        Box::pin(async move {
            log.lock()
                .unwrap()
                .push(RecordedEvent::DnsLookup(req.to_string()));
            Ok(iter::once(sock_addr))
        })
    }
}
}
#[cfg(test)]
mod tests {
use hyper::service::Service;

View File

@ -208,7 +208,7 @@ where
InnerService: tower::Service<Operation<H, R>, Error = SdkError<E>>,
{
type Response = InnerService::Response;
type Error = aws_smithy_http::result::SdkError<E>;
type Error = SdkError<E>;
type Future = TimeoutServiceFuture<InnerService::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {

View File

@ -3,6 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
mod test_operation;
use crate::test_operation::{TestOperationParser, TestRetryClassifier};
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::test_connection::TestConnection;
@ -15,78 +16,6 @@ use std::sync::Arc;
use std::time::Duration;
use tower::layer::util::Identity;
mod test_operation {
    use aws_smithy_http::operation;
    use aws_smithy_http::response::ParseHttpResponse;
    use aws_smithy_http::result::SdkError;
    use aws_smithy_http::retry::ClassifyRetry;
    use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
    use bytes::Bytes;
    use std::error::Error;
    use std::fmt::{self, Debug, Display, Formatter};

    /// Parser that maps any successful status to `"Hello!"` and anything else to an error.
    #[derive(Clone)]
    pub(super) struct TestOperationParser;

    /// Modeled error returned for non-success responses.
    #[derive(Debug)]
    pub(super) struct OperationError;

    impl Display for OperationError {
        // `Display` reuses the `Debug` rendering.
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_fmt(format_args!("{:?}", self))
        }
    }

    impl Error for OperationError {}

    impl ProvideErrorKind for OperationError {
        /// Every `OperationError` is reported as a throttling error.
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            Some(ErrorKind::ThrottlingError)
        }

        fn code(&self) -> Option<&str> {
            None
        }
    }

    impl ParseHttpResponse for TestOperationParser {
        type Output = Result<String, OperationError>;

        fn parse_unloaded(&self, response: &mut operation::Response) -> Option<Self::Output> {
            let parsed = if response.http().status().is_success() {
                Ok(String::from("Hello!"))
            } else {
                Err(OperationError)
            };
            Some(parsed)
        }

        fn parse_loaded(&self, _response: &http::Response<Bytes>) -> Self::Output {
            Ok(String::from("Hello!"))
        }
    }

    /// Classifier that only understands modeled service errors and successes.
    #[derive(Clone)]
    pub(super) struct TestRetryClassifier;

    impl<T, E> ClassifyRetry<T, SdkError<E>> for TestRetryClassifier
    where
        E: ProvideErrorKind + Debug,
        T: Debug,
    {
        fn classify_retry(&self, err: Result<&T, &SdkError<E>>) -> RetryKind {
            let kind = match err {
                Ok(_) => return RetryKind::Unnecessary,
                Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(),
                _ => panic!("test handler only handles modeled errors got: {:?}", err),
            };
            kind.map_or(RetryKind::UnretryableFailure, RetryKind::Error)
        }
    }
}
fn test_operation() -> Operation<TestOperationParser, TestRetryClassifier> {
let req = operation::Request::new(
http::Request::builder()
@ -108,14 +37,14 @@ async fn end_to_end_retry_test() {
fn ok() -> http::Response<&'static str> {
http::Response::builder()
.status(200)
.body("response body")
.body("Hello!")
.unwrap()
}
fn err() -> http::Response<&'static str> {
http::Response::builder()
.status(500)
.body("response body")
.body("This was an error")
.unwrap()
}
// 1 failing response followed by 1 successful response

View File

@ -0,0 +1,230 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#![cfg(feature = "test-util")]
mod test_operation;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_client::test_connection::wire_mock;
use aws_smithy_client::test_connection::wire_mock::{check_matches, RecordedEvent, ReplayedEvent};
use aws_smithy_client::{hyper_ext, Builder};
use aws_smithy_client::{match_events, Client};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation;
use aws_smithy_http::operation::Operation;
use aws_smithy_types::retry::ReconnectMode;
use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig};
use http::Uri;
use http_body::combinators::BoxBody;
use hyper::client::{Builder as HyperBuilder, HttpConnector};
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use test_operation::{TestOperationParser, TestRetryClassifier};
use tower::layer::util::Identity;
use wire_mock::ev;
/// Sentinel response body that tells the request loop the scripted scenario is finished.
fn end_of_test() -> &'static str {
    const SENTINEL: &str = "end_of_test";
    SENTINEL
}
/// Build the operation sent by every test request to `uri`.
///
/// When `retryable` is `false`, the request body is replaced by a boxed dynamic body
/// (presumably so the request cannot be cloned for a retry — TODO confirm against `SdkBody`).
fn test_operation(
    uri: Uri,
    retryable: bool,
) -> Operation<TestOperationParser, TestRetryClassifier> {
    let http_request = http::Request::builder()
        .uri(uri)
        .body(SdkBody::from("request body"))
        .unwrap();
    let request = operation::Request::new(http_request);
    let request = if retryable {
        request
    } else {
        request
            .augment(|req, _conf| {
                Ok::<_, Infallible>(
                    req.map(|_| SdkBody::from_dyn(BoxBody::new(SdkBody::from("body")))),
                )
            })
            .unwrap()
    };
    Operation::new(request, TestOperationParser).with_retry_classifier(TestRetryClassifier)
}
async fn h1_and_h2(events: Vec<ReplayedEvent>, match_clause: impl Fn(&[RecordedEvent])) {
wire_level_test(events.clone(), |_b| {}, |b| b, &match_clause).await;
wire_level_test(
events,
|b| {
b.http2_only(true);
},
|b| b,
match_clause,
)
.await;
println!("h2 ok!");
}
/// Repeatedly send test operation until `end_of_test` is received
///
/// A mock server is started with `events`; `hyper_builder_settings` and
/// `client_builder_settings` customize the hyper client and the smithy client respectively.
/// Errors returned by individual calls are logged and the loop continues.
///
/// When the test is over, match_clause is evaluated
async fn wire_level_test(
    events: Vec<ReplayedEvent>,
    hyper_builder_settings: impl Fn(&mut HyperBuilder),
    client_builder_settings: impl Fn(Builder) -> Builder,
    match_clause: impl Fn(&[RecordedEvent]),
) {
    let connection = wire_mock::WireLevelTestConnection::spinup(events).await;

    // hyper side: resolver that always points at the mock server, plus caller tweaks
    let http_connector = HttpConnector::new_with_resolver(connection.dns_resolver());
    let mut hyper_builder = hyper::Client::builder();
    hyper_builder_settings(&mut hyper_builder);
    let hyper_adapter = hyper_ext::Adapter::builder()
        .hyper_builder(hyper_builder)
        .build(http_connector);

    // smithy side: reconnect-on-transient-error by default, 100ms per-attempt timeout
    let attempt_timeout = TimeoutConfig::builder()
        .operation_attempt_timeout(Duration::from_millis(100))
        .build();
    let base_builder = Client::builder().reconnect_mode(ReconnectMode::ReconnectOnTransientError);
    let client = client_builder_settings(base_builder)
        .connector(hyper_adapter)
        .middleware(Identity::new())
        .operation_timeout_config(OperationTimeoutConfig::from(&attempt_timeout))
        .sleep_impl(Arc::new(TokioSleep::new()))
        .build();

    loop {
        let uri = connection.endpoint_url().parse().unwrap();
        let outcome = client.call(test_operation(uri, false)).await;
        match outcome {
            Err(e) => tracing::info!("error: {:?}", e),
            Ok(resp) => {
                tracing::info!("response: {:?}", resp);
                if resp == end_of_test() {
                    break;
                }
            }
        }
    }

    let recorded = connection.events();
    match_clause(&recorded);
}
// A 400 is followed by a 200 with no second DNS lookup or connect in between.
#[tokio::test]
async fn non_transient_errors_no_reconect() {
    let replay = vec![
        ReplayedEvent::status(400),
        ReplayedEvent::with_body(end_of_test()),
    ];
    h1_and_h2(
        replay,
        match_events!(ev!(dns), ev!(connect), ev!(http(400)), ev!(http(200))),
    )
    .await
}
// Every 503 is expected to be followed by a fresh DNS lookup and a new connection.
#[tokio::test]
async fn reestablish_dns_on_503() {
    let replay = vec![
        ReplayedEvent::status(503),
        ReplayedEvent::status(503),
        ReplayedEvent::status(503),
        ReplayedEvent::with_body(end_of_test()),
    ];
    h1_and_h2(
        replay,
        match_events!(
            // first request
            ev!(dns),
            ev!(connect),
            ev!(http(503)),
            // second request
            ev!(dns),
            ev!(connect),
            ev!(http(503)),
            // third request
            ev!(dns),
            ev!(connect),
            ev!(http(503)),
            // all good
            ev!(dns),
            ev!(connect),
            ev!(http(200))
        ),
    )
    .await;
}
// Successful responses share one connection; only the 503 triggers a new lookup and connect.
#[tokio::test]
async fn connection_shared_on_success() {
    let replay = vec![
        ReplayedEvent::ok(),
        ReplayedEvent::ok(),
        ReplayedEvent::status(503),
        ReplayedEvent::with_body(end_of_test()),
    ];
    h1_and_h2(
        replay,
        match_events!(
            ev!(dns),
            ev!(connect),
            ev!(http(200)),
            ev!(http(200)),
            ev!(http(503)),
            ev!(dns),
            ev!(connect),
            ev!(http(200))
        ),
    )
    .await;
}
// With `ReuseAllConnections`, a 503 is followed by a 200 without a second lookup or connect.
#[tokio::test]
async fn no_reconnect_when_disabled() {
    use wire_mock::ev;
    let replay = vec![
        ReplayedEvent::status(503),
        ReplayedEvent::with_body(end_of_test()),
    ];
    wire_level_test(
        replay,
        |_b| {},
        |b| b.reconnect_mode(ReconnectMode::ReuseAllConnections),
        match_events!(ev!(dns), ev!(connect), ev!(http(503)), ev!(http(200))),
    )
    .await;
}
// Each timeout is expected to be followed by a new DNS lookup and connection.
#[tokio::test]
async fn connection_reestablished_after_timeout() {
    use wire_mock::ev;
    let replay = vec![
        ReplayedEvent::ok(),
        ReplayedEvent::Timeout,
        ReplayedEvent::ok(),
        ReplayedEvent::Timeout,
        ReplayedEvent::with_body(end_of_test()),
    ];
    h1_and_h2(
        replay,
        match_events!(
            // first connection
            ev!(dns),
            ev!(connect),
            ev!(http(200)),
            // reuse but got a timeout
            ev!(timeout),
            // so we reconnect
            ev!(dns),
            ev!(connect),
            ev!(http(200)),
            ev!(timeout),
            ev!(dns),
            ev!(connect),
            ev!(http(200))
        ),
    )
    .await;
}

View File

@ -0,0 +1,84 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_http::operation;
use aws_smithy_http::response::ParseHttpResponse;
use aws_smithy_http::result::SdkError;
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use bytes::Bytes;
use std::error::Error;
use std::fmt::{self, Debug, Display, Formatter};
use std::str;
/// Response parser for the test operation; its behavior lives in the `ParseHttpResponse` impl.
#[derive(Clone)]
pub(super) struct TestOperationParser;
/// Modeled error for the test operation, wrapping the `ErrorKind` it reports as retryable.
#[derive(Debug)]
pub(super) struct OperationError(ErrorKind);
impl Display for OperationError {
    // `Display` reuses the `Debug` rendering of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
// Marker impl with no overrides: only the `Error` trait's provided methods are used.
impl Error for OperationError {}
impl ProvideErrorKind for OperationError {
fn retryable_error_kind(&self) -> Option<ErrorKind> {
Some(self.0)
}
fn code(&self) -> Option<&str> {
None
}
}
impl ParseHttpResponse for TestOperationParser {
    type Output = Result<String, OperationError>;

    /// Classify the response by status before the body is loaded.
    ///
    /// - success: `None`, deferring to [`Self::parse_loaded`]
    /// - 4xx: error with `ErrorKind::ServerError`
    /// - 5xx: error with `ErrorKind::TransientError`
    ///
    /// Panics on any other status.
    fn parse_unloaded(&self, response: &mut operation::Response) -> Option<Self::Output> {
        tracing::debug!("got response: {:?}", response);
        let status = response.http().status();
        if status.is_success() {
            return None;
        }
        let kind = if status.is_client_error() {
            ErrorKind::ServerError
        } else if status.is_server_error() {
            ErrorKind::TransientError
        } else {
            panic!("unexpected status: {}", status)
        };
        Some(Err(OperationError(kind)))
    }

    /// Return the response body as a `String`; panics if the body is not valid UTF-8.
    fn parse_loaded(&self, response: &http::Response<Bytes>) -> Self::Output {
        let text = str::from_utf8(response.body().as_ref()).unwrap();
        Ok(text.to_string())
    }
}
/// Retry classifier for the test operation; its behavior lives in the `ClassifyRetry` impl.
#[derive(Clone)]
pub(super) struct TestRetryClassifier;
impl<T, E> ClassifyRetry<T, SdkError<E>> for TestRetryClassifier
where
E: ProvideErrorKind + Debug,
T: Debug,
{
fn classify_retry(&self, err: Result<&T, &SdkError<E>>) -> RetryKind {
tracing::info!("got response: {:?}", err);
let kind = match err {
Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(),
Err(SdkError::DispatchFailure(err)) if err.is_timeout() => {
Some(ErrorKind::TransientError)
}
Err(SdkError::TimeoutError(_)) => Some(ErrorKind::TransientError),
Ok(_) => return RetryKind::Unnecessary,
_ => panic!("test handler only handles modeled errors got: {:?}", err),
};
match kind {
Some(kind) => RetryKind::Error(kind),
None => RetryKind::UnretryableFailure,
}
}
}

View File

@ -5,6 +5,7 @@
use crate::SendOperationError;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::connection::CaptureSmithyConnection;
use aws_smithy_http::operation;
use aws_smithy_http::result::ConnectorError;
use std::future::Future;
@ -41,7 +42,13 @@ where
}
fn call(&mut self, req: operation::Request) -> Self::Future {
let (req, property_bag) = req.into_parts();
let (mut req, property_bag) = req.into_parts();
// copy the smithy connection
if let Some(smithy_conn) = property_bag.acquire().get::<CaptureSmithyConnection>() {
req.extensions_mut().insert(smithy_conn.clone());
} else {
println!("nothing to copy!");
}
let mut inner = self.inner.clone();
let future = async move {
trace!(request = ?req, "dispatching request");

View File

@ -0,0 +1,96 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
/// Metadata about a connection, together with a callback that poisons it.
#[derive(Clone)]
pub struct ConnectionMetadata {
    // NOTE(review): presumably whether the connection goes through a proxy — confirm with callers
    is_proxied: bool,
    // address of the remote peer, when known
    remote_addr: Option<SocketAddr>,
    // callback invoked by `poison()`; shared between clones
    poison_fn: Arc<dyn Fn() + Send + Sync>,
}

impl Debug for ConnectionMetadata {
    // `poison_fn` is a closure and cannot be printed, so only the plain fields are shown.
    // The struct name in the output matches the type name.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionMetadata")
            .field("is_proxied", &self.is_proxied)
            .field("remote_addr", &self.remote_addr)
            .finish()
    }
}
// Callback type stored by `CaptureSmithyConnection`: yields `Some(ConnectionMetadata)` or `None`.
type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;
/// Cloneable handle around an optional connection-metadata loader.
///
/// Clones share the same underlying `Arc<Mutex<..>>`, so a loader installed through one clone
/// is visible through all of them.
#[derive(Clone, Default)]
pub struct CaptureSmithyConnection {
// `None` until `set_connection_retriever` installs a loader
loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
}
impl CaptureSmithyConnection {
    /// Create a handle with no loader installed.
    pub fn new() -> Self {
        Self {
            loader: Default::default(),
        }
    }

    /// Install `f` as the loader, replacing any previously installed one.
    ///
    /// The loader is shared with every clone of this handle.
    pub fn set_connection_retriever<F>(&self, f: F)
    where
        F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
    {
        *self.loader.lock().unwrap() = Some(Box::new(f));
    }

    /// Run the installed loader and return its result.
    ///
    /// Returns `None` when no loader has been installed or when the loader itself
    /// returns `None`. Nothing is printed in either case.
    pub fn get(&self) -> Option<ConnectionMetadata> {
        self.loader
            .lock()
            .unwrap()
            .as_ref()
            .and_then(|loader| loader())
    }
}
impl ConnectionMetadata {
    /// Log that the connection was poisoned, then invoke the stored poison callback.
    pub fn poison(&self) {
        tracing::info!("smithy connection was poisoned");
        let poison = &self.poison_fn;
        poison()
    }
}
impl ConnectionMetadata {
    /// Create metadata from its parts; `poison` is the callback later run by `poison()`.
    pub fn new(
        is_proxied: bool,
        remote_addr: Option<SocketAddr>,
        poison: impl Fn() + Send + Sync + 'static,
    ) -> Self {
        let poison_fn: Arc<dyn Fn() + Send + Sync> = Arc::new(poison);
        Self {
            is_proxied,
            remote_addr,
            poison_fn,
        }
    }

    /// Remote address of the connection, if one was provided at construction.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }
}
#[cfg(test)]
mod test {
    use crate::connection::{CaptureSmithyConnection, ConnectionMetadata};

    // A loader installed on one handle must be observable through a clone made earlier.
    #[test]
    fn retrieve_connection_metadata() {
        let capture = CaptureSmithyConnection::new();
        let shared = capture.clone();
        assert!(capture.get().is_none());

        let make_metadata = || Some(ConnectionMetadata::new(true, None, || {}));
        capture.set_connection_retriever(make_metadata);

        assert!(capture.get().is_some());
        assert!(shared.get().is_some());
    }
}

View File

@ -39,4 +39,5 @@ pub mod event_stream;
pub mod byte_stream;
pub mod connection;
mod urlencode;

View File

@ -12,6 +12,7 @@
//! `Result` wrapper types for [success](SdkSuccess) and [failure](SdkError) responses.
use crate::connection::ConnectionMetadata;
use crate::operation;
use aws_smithy_types::error::metadata::{ProvideErrorMetadata, EMPTY_ERROR_METADATA};
use aws_smithy_types::error::ErrorMetadata;
@ -240,6 +241,11 @@ impl DispatchFailure {
pub fn is_other(&self) -> Option<ErrorKind> {
// Delegates to the `is_other` method of the wrapped `source` error.
self.source.is_other()
}
/// Returns the inner error if it is a connector error
///
/// The current implementation always returns `Some`, holding a reference to `self.source`.
pub fn as_connector_error(&self) -> Option<&ConnectorError> {
Some(&self.source)
}
}
/// Error context for [`SdkError::ResponseError`]
@ -505,6 +511,22 @@ enum ConnectorErrorKind {
pub struct ConnectorError {
// category of the failure (timeout, user, io, or other)
kind: ConnectorErrorKind,
// underlying error that caused the failure
source: BoxError,
// what is known about the connection used by the request; `Unknown` unless changed via
// `with_connection` or `never_connected`
connection: ConnectionStatus,
}
/// What is known about the connection a failed request was using.
#[non_exhaustive]
#[derive(Debug)]
pub(crate) enum ConnectionStatus {
/// This request was never connected to the remote
///
/// This indicates the failure was during connection establishment
NeverConnected,
/// It is unknown whether a connection was established
Unknown,
/// The request connected to the remote prior to failure
Connected(ConnectionMetadata),
}
impl Display for ConnectorError {
@ -532,14 +554,28 @@ impl ConnectorError {
Self {
kind: ConnectorErrorKind::Timeout,
source,
connection: ConnectionStatus::Unknown,
}
}
/// Include connection information along with this error
pub fn with_connection(mut self, info: ConnectionMetadata) -> Self {
self.connection = ConnectionStatus::Connected(info);
self
}
/// Set the connection status on this error to report that a connection was never established
pub fn never_connected(mut self) -> Self {
self.connection = ConnectionStatus::NeverConnected;
self
}
/// Construct a [`ConnectorError`] from an error caused by the user (e.g. invalid HTTP request)
///
/// The connection status of the new error starts out as unknown.
pub fn user(source: BoxError) -> Self {
    let kind = ConnectorErrorKind::User;
    let connection = ConnectionStatus::Unknown;
    Self {
        kind,
        source,
        connection,
    }
}
@ -548,6 +584,7 @@ impl ConnectorError {
Self {
kind: ConnectorErrorKind::Io,
source,
connection: ConnectionStatus::Unknown,
}
}
@ -558,6 +595,7 @@ impl ConnectorError {
Self {
source,
kind: ConnectorErrorKind::Other(kind),
connection: ConnectionStatus::Unknown,
}
}
@ -583,4 +621,16 @@ impl ConnectorError {
_ => None,
}
}
/// Returns metadata about the connection
///
/// If a connection was established and provided by the internal connector, a connection will
/// be returned. `None` is returned when the request never connected or the status is unknown.
pub fn connection_metadata(&self) -> Option<&ConnectionMetadata> {
    if let ConnectionStatus::Connected(conn) = &self.connection {
        Some(conn)
    } else {
        None
    }
}
}

View File

@ -143,6 +143,7 @@ pub struct RetryConfigBuilder {
mode: Option<RetryMode>,
max_attempts: Option<u32>,
initial_backoff: Option<Duration>,
reconnect_mode: Option<ReconnectMode>,
}
impl RetryConfigBuilder {
@ -163,6 +164,30 @@ impl RetryConfigBuilder {
self
}
/// Set the [`ReconnectMode`] for the retry strategy
///
/// By default, when a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host but may increase the load on
/// the server.
///
/// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead.
pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self {
    self.reconnect_mode = Some(reconnect_mode);
    self
}
/// Set the [`ReconnectMode`] for the retry strategy
///
/// By default, when a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host but may increase the load on
/// the server.
///
/// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead.
///
/// Passing `None` clears any previously set value on this builder.
pub fn set_reconnect_mode(&mut self, reconnect_mode: Option<ReconnectMode>) -> &mut Self {
self.reconnect_mode = reconnect_mode;
self
}
/// Sets the max attempts. This value must be greater than zero.
pub fn set_max_attempts(&mut self, max_attempts: Option<u32>) -> &mut Self {
self.max_attempts = max_attempts;
@ -208,6 +233,7 @@ impl RetryConfigBuilder {
mode: self.mode.or(other.mode),
max_attempts: self.max_attempts.or(other.max_attempts),
initial_backoff: self.initial_backoff.or(other.initial_backoff),
reconnect_mode: self.reconnect_mode.or(other.reconnect_mode),
}
}
@ -219,6 +245,9 @@ impl RetryConfigBuilder {
initial_backoff: self
.initial_backoff
.unwrap_or_else(|| Duration::from_secs(1)),
reconnect_mode: self
.reconnect_mode
.unwrap_or(ReconnectMode::ReconnectOnTransientError),
}
}
}
@ -230,6 +259,23 @@ pub struct RetryConfig {
mode: RetryMode,
max_attempts: u32,
initial_backoff: Duration,
reconnect_mode: ReconnectMode,
}
/// Mode for connection re-establishment
///
/// By default, when a transient error is encountered, the connection in use will be poisoned. This
/// behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead.
// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality is total.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ReconnectMode {
    /// Reconnect on [`ErrorKind::TransientError`]
    ReconnectOnTransientError,

    /// Disable reconnect on error
    ///
    /// When this setting is applied, 503s, timeouts, and other transient errors will _not_
    /// lead to a new connection being established unless the connection is closed by the remote.
    ReuseAllConnections,
}
impl RetryConfig {
@ -239,6 +285,7 @@ impl RetryConfig {
mode: RetryMode::Standard,
max_attempts: 3,
initial_backoff: Duration::from_secs(1),
reconnect_mode: ReconnectMode::ReconnectOnTransientError,
}
}
@ -260,6 +307,18 @@ impl RetryConfig {
self
}
/// Set the [`ReconnectMode`] for the retry strategy
///
/// By default, when a transient error is encountered, the connection in use will be poisoned.
/// This prevents reusing a connection to a potentially bad host but may increase the load on
/// the server.
///
/// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead.
pub fn with_reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self {
self.reconnect_mode = reconnect_mode;
self
}
/// Set the multiplier used when calculating backoff times as part of an
/// [exponential backoff with jitter](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/)
/// strategy. Most services should work fine with the default duration of 1 second, but if you
@ -287,6 +346,11 @@ impl RetryConfig {
self.mode
}
/// Returns the [`ReconnectMode`]
///
/// The mode is `Copy`, so it is returned by value.
pub fn reconnect_mode(&self) -> ReconnectMode {
self.reconnect_mode
}
/// Returns the max attempts.
pub fn max_attempts(&self) -> u32 {
self.max_attempts