Port middleware connectors to the orchestrator (#2970)

This PR ports all the connectors from the `aws-smithy-client` crate into
`aws-smithy-runtime` implementing the new `HttpConnector` trait. The old
connectors are left in place for now, and follow up PRs will remove them
as well as revise the generated configs to take `HttpConnector` impls
rather than `DynConnector`.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: Zelda Hessler <zhessler@amazon.com>
This commit is contained in:
John DiSanti 2023-09-07 09:54:56 -07:00 committed by GitHub
parent 8a3b8f3a00
commit e322a2d733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 2394 additions and 314 deletions

View File

@ -103,6 +103,30 @@ references = ["smithy-rs#2964"]
meta = { "breaking" = false, "tada" = false, "bug" = false, target = "client" }
author = "rcoh"
[[smithy-rs]]
message = "`aws_smithy_client::hyper_ext::Adapter` was moved/renamed to `aws_smithy_runtime::client::connectors::hyper_connector::HyperConnector`."
references = ["smithy-rs#2970"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[smithy-rs]]
message = "Test connectors moved into `aws_smithy_runtime::client::connectors::test_util` behind the `test-util` feature."
references = ["smithy-rs#2970"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[smithy-rs]]
message = "DVR's RecordingConnection and ReplayingConnection were renamed to RecordingConnector and ReplayingConnector respectively."
references = ["smithy-rs#2970"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[smithy-rs]]
message = "TestConnection was renamed to EventConnector."
references = ["smithy-rs#2970"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[aws-sdk-rust]]
message = "Remove `once_cell` from public API"
references = ["smithy-rs#2973"]

View File

@ -169,10 +169,8 @@ impl DynConnector {
pub fn call_lite(
&mut self,
req: http::Request<SdkBody>,
) -> BoxFuture<http::Response<SdkBody>, Box<dyn std::error::Error + Send + Sync + 'static>>
{
let future = Service::call(self, req);
Box::pin(async move { future.await.map_err(|err| Box::new(err) as _) })
) -> BoxFuture<http::Response<SdkBody>, ConnectorError> {
Service::call(self, req)
}
}

View File

@ -21,6 +21,7 @@ aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-types = { path = "../aws-smithy-types" }
bytes = "1"
http = "0.2.3"
pin-project-lite = "0.2"
tokio = { version = "1.25", features = ["sync"] }
tracing = "0.1"
zeroize = { version = "1", optional = true }

View File

@ -20,7 +20,6 @@ pub mod runtime_plugin;
pub mod auth;
/// Smithy connectors and related code.
pub mod connectors;
pub mod ser_de;

View File

@ -3,9 +3,91 @@
* SPDX-License-Identifier: Apache-2.0
*/
use crate::client::orchestrator::{BoxFuture, HttpRequest, HttpResponse};
//! Smithy connectors and related code.
//!
//! # What is a connector?
//!
//! When we talk about connectors, we are referring to the [`HttpConnector`] trait, and implementations of
//! that trait. This trait simply takes a HTTP request, and returns a future with the response for that
//! request.
//!
//! This is slightly different from what a connector is in other libraries such as
//! [`hyper`](https://crates.io/crates/hyper). In hyper 0.x, the connector is a
//! [`tower`](https://crates.io/crates/tower) `Service` that takes a `Uri` and returns
//! a future with something that implements `AsyncRead + AsyncWrite`.
//!
//! The [`HttpConnector`](crate::client::connectors::HttpConnector) is designed to be a layer on top of
//! whole HTTP libraries, such as hyper, which allows Smithy clients to be agnostic to the underlying HTTP
//! transport layer. This also makes it easy to write tests with a fake HTTP connector, and several
//! such test connector implementations are availble in [`aws-smithy-runtime`](https://crates.io/crates/aws-smithy-runtime).
//!
//! # Responsibilities of a connector
//!
//! A connector primarily makes HTTP requests, but can also be used to implement connect and read
//! timeouts. The `HyperConnector` in [`aws-smithy-runtime`](https://crates.io/crates/aws-smithy-runtime)
//! is an example where timeouts are implemented as part of the connector.
//!
//! Connectors are also responsible for DNS lookup, TLS, connection reuse, pooling, and eviction.
//! The Smithy clients have no knowledge of such concepts.
use crate::client::orchestrator::{HttpRequest, HttpResponse};
use aws_smithy_async::future::now_or_later::NowOrLater;
use aws_smithy_http::result::ConnectorError;
use pin_project_lite::pin_project;
use std::fmt;
use std::future::Future as StdFuture;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
type BoxFuture = Pin<Box<dyn StdFuture<Output = Result<HttpResponse, ConnectorError>> + Send>>;
pin_project! {
/// Future for [`HttpConnector::call`].
pub struct HttpConnectorFuture {
#[pin]
inner: NowOrLater<Result<HttpResponse, ConnectorError>, BoxFuture>,
}
}
impl HttpConnectorFuture {
/// Create a new `HttpConnectorFuture` with the given future.
pub fn new<F>(future: F) -> Self
where
F: StdFuture<Output = Result<HttpResponse, ConnectorError>> + Send + 'static,
{
Self {
inner: NowOrLater::new(Box::pin(future)),
}
}
/// Create a new `HttpConnectorFuture` with the given boxed future.
///
/// Use this if you already have a boxed future to avoid double boxing it.
pub fn new_boxed(
future: Pin<Box<dyn StdFuture<Output = Result<HttpResponse, ConnectorError>> + Send>>,
) -> Self {
Self {
inner: NowOrLater::new(future),
}
}
/// Create a `HttpConnectorFuture` that is immediately ready with the given result.
pub fn ready(result: Result<HttpResponse, ConnectorError>) -> Self {
Self {
inner: NowOrLater::ready(result),
}
}
}
impl StdFuture for HttpConnectorFuture {
type Output = Result<HttpResponse, ConnectorError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.inner.poll(cx)
}
}
/// Trait with a `call` function that asynchronously converts a request into a response.
///
@ -16,7 +98,7 @@ use std::sync::Arc;
/// for testing.
pub trait HttpConnector: Send + Sync + fmt::Debug {
/// Asynchronously converts a request into a response.
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse>;
fn call(&self, request: HttpRequest) -> HttpConnectorFuture;
}
/// A shared [`HttpConnector`] implementation.
@ -31,7 +113,7 @@ impl SharedHttpConnector {
}
impl HttpConnector for SharedHttpConnector {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse> {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
(*self.0).call(request)
}
}

View File

@ -511,11 +511,11 @@ impl RuntimeComponentsBuilder {
#[cfg(feature = "test-util")]
pub fn for_tests() -> Self {
use crate::client::auth::AuthSchemeOptionResolver;
use crate::client::connectors::HttpConnector;
use crate::client::connectors::{HttpConnector, HttpConnectorFuture};
use crate::client::endpoint::{EndpointResolver, EndpointResolverParams};
use crate::client::identity::Identity;
use crate::client::identity::IdentityResolver;
use crate::client::orchestrator::Future;
use crate::client::orchestrator::{Future, HttpRequest};
use crate::client::retries::RetryStrategy;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::time::TimeSource;
@ -537,11 +537,7 @@ impl RuntimeComponentsBuilder {
#[derive(Debug)]
struct FakeConnector;
impl HttpConnector for FakeConnector {
fn call(
&self,
_: crate::client::orchestrator::HttpRequest,
) -> crate::client::orchestrator::BoxFuture<crate::client::orchestrator::HttpResponse>
{
fn call(&self, _: HttpRequest) -> HttpConnectorFuture {
unreachable!("fake connector must be overridden for this test")
}
}

View File

@ -258,8 +258,8 @@ impl RuntimePlugins {
#[cfg(test)]
mod tests {
use super::{RuntimePlugin, RuntimePlugins};
use crate::client::connectors::{HttpConnector, SharedHttpConnector};
use crate::client::orchestrator::{BoxFuture, HttpRequest, HttpResponse};
use crate::client::connectors::{HttpConnector, HttpConnectorFuture, SharedHttpConnector};
use crate::client::orchestrator::HttpRequest;
use crate::client::runtime_components::RuntimeComponentsBuilder;
use crate::client::runtime_plugin::Order;
use aws_smithy_http::body::SdkBody;
@ -338,12 +338,12 @@ mod tests {
#[tokio::test]
async fn components_can_wrap_components() {
// CN1, the inner connector, creates a response with a `rp1` header
// Connector1, the inner connector, creates a response with a `rp1` header
#[derive(Debug)]
struct CN1;
impl HttpConnector for CN1 {
fn call(&self, _: HttpRequest) -> BoxFuture<HttpResponse> {
Box::pin(async {
struct Connector1;
impl HttpConnector for Connector1 {
fn call(&self, _: HttpRequest) -> HttpConnectorFuture {
HttpConnectorFuture::new(async {
Ok(http::Response::builder()
.status(200)
.header("rp1", "1")
@ -353,13 +353,13 @@ mod tests {
}
}
// CN2, the outer connector, calls the inner connector and adds the `rp2` header to the response
// Connector2, the outer connector, calls the inner connector and adds the `rp2` header to the response
#[derive(Debug)]
struct CN2(SharedHttpConnector);
impl HttpConnector for CN2 {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse> {
struct Connector2(SharedHttpConnector);
impl HttpConnector for Connector2 {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
let inner = self.0.clone();
Box::pin(async move {
HttpConnectorFuture::new(async move {
let mut resp = inner.call(request).await.unwrap();
resp.headers_mut()
.append("rp2", HeaderValue::from_static("1"));
@ -368,10 +368,10 @@ mod tests {
}
}
// RP1 registers CN1
// Plugin1 registers Connector1
#[derive(Debug)]
struct RP1;
impl RuntimePlugin for RP1 {
struct Plugin1;
impl RuntimePlugin for Plugin1 {
fn order(&self) -> Order {
Order::Overrides
}
@ -381,16 +381,16 @@ mod tests {
_: &RuntimeComponentsBuilder,
) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Owned(
RuntimeComponentsBuilder::new("RP1")
.with_http_connector(Some(SharedHttpConnector::new(CN1))),
RuntimeComponentsBuilder::new("Plugin1")
.with_http_connector(Some(SharedHttpConnector::new(Connector1))),
)
}
}
// RP2 registers CN2
// Plugin2 registers Connector2
#[derive(Debug)]
struct RP2;
impl RuntimePlugin for RP2 {
struct Plugin2;
impl RuntimePlugin for Plugin2 {
fn order(&self) -> Order {
Order::NestedComponents
}
@ -400,8 +400,10 @@ mod tests {
current_components: &RuntimeComponentsBuilder,
) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Owned(
RuntimeComponentsBuilder::new("RP2").with_http_connector(Some(
SharedHttpConnector::new(CN2(current_components.http_connector().unwrap())),
RuntimeComponentsBuilder::new("Plugin2").with_http_connector(Some(
SharedHttpConnector::new(Connector2(
current_components.http_connector().unwrap(),
)),
)),
)
}
@ -410,8 +412,8 @@ mod tests {
// Emulate assembling a full runtime plugins list and using it to apply configuration
let plugins = RuntimePlugins::new()
// intentionally configure the plugins in the reverse order
.with_client_plugin(RP2)
.with_client_plugin(RP1);
.with_client_plugin(Plugin2)
.with_client_plugin(Plugin1);
let mut cfg = ConfigBag::base();
let components = plugins.apply_client_configuration(&mut cfg).unwrap();

View File

@ -12,7 +12,9 @@ repository = "https://github.com/awslabs/smithy-rs"
[features]
client = ["aws-smithy-runtime-api/client"]
http-auth = ["aws-smithy-runtime-api/http-auth"]
test-util = ["aws-smithy-runtime-api/test-util", "dep:aws-smithy-protocol-test", "dep:tracing-subscriber"]
test-util = ["aws-smithy-runtime-api/test-util", "dep:aws-smithy-protocol-test", "dep:tracing-subscriber", "dep:serde", "dep:serde_json"]
connector-hyper = ["dep:hyper", "hyper?/client", "hyper?/http2", "hyper?/http1", "hyper?/tcp"]
tls-rustls = ["dep:hyper-rustls", "dep:rustls", "connector-hyper"]
[dependencies]
aws-smithy-async = { path = "../aws-smithy-async" }
@ -25,9 +27,14 @@ bytes = "1"
fastrand = "2.0.0"
http = "0.2.8"
http-body = "0.4.5"
hyper = { version = "0.14.26", default-features = false, optional = true }
hyper-rustls = { version = "0.24", features = ["rustls-native-certs", "http2"], optional = true }
once_cell = "1.18.0"
pin-project-lite = "0.2.7"
pin-utils = "0.1.0"
rustls = { version = "0.21.1", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true }
tokio = { version = "1.25", features = [] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", optional = true, features = ["fmt", "json"] }
@ -37,6 +44,7 @@ approx = "0.5.1"
aws-smithy-async = { path = "../aws-smithy-async", features = ["rt-tokio", "test-util"] }
aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["test-util"] }
aws-smithy-types = { path = "../aws-smithy-types", features = ["test-util"] }
hyper-tls = { version = "0.5.0" }
tokio = { version = "1.25", features = ["macros", "rt", "test-util"] }
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
tracing-test = "0.2.1"

View File

@ -4,9 +4,27 @@ allowed_external_types = [
"aws_smithy_http::*",
"aws_smithy_types::*",
"aws_smithy_client::erase::DynConnector",
"aws_smithy_client::http_connector::ConnectorSettings",
# TODO(audit-external-type-usage) We should newtype these or otherwise avoid exposing them
"http::header::name::HeaderName",
"http::request::Request",
"http::response::Response",
"http::uri::Uri",
# Used for creating hyper connectors
"tower_service::Service",
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `test-util` feature
"aws_smithy_protocol_test::MediaType",
"bytes::bytes::Bytes",
"serde::ser::Serialize",
"serde::de::Deserialize",
"hyper::client::connect::dns::Name",
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `connector-hyper` feature
"hyper::client::client::Builder",
"hyper::client::connect::Connection",
"tokio::io::async_read::AsyncRead",
"tokio::io::async_write::AsyncWrite",
]

View File

@ -6,13 +6,10 @@
/// Smithy auth scheme implementations.
pub mod auth;
/// Smithy code related to connectors and connections.
/// Built-in Smithy connectors.
///
/// A "connector" manages one or more "connections", handles connection timeouts, re-establishes
/// connections, etc.
///
/// "Connections" refers to the actual transport layer implementation of the connector.
/// By default, the orchestrator uses a connector provided by `hyper`.
/// See the [module docs in `aws-smithy-runtime-api`](aws_smithy_runtime_api::client::connectors)
/// for more information about connectors.
pub mod connectors;
/// Utility to simplify config building for config and config overrides.

View File

@ -9,14 +9,18 @@ pub mod connection_poisoning;
#[cfg(feature = "test-util")]
pub mod test_util;
/// Default HTTP and TLS connectors that use hyper and rustls.
#[cfg(feature = "connector-hyper")]
pub mod hyper_connector;
// TODO(enableNewSmithyRuntimeCleanup): Delete this module
/// Unstable API for interfacing the old middleware connectors with the newer orchestrator connectors.
///
/// Important: This module and its contents will be removed in the next release.
pub mod adapter {
use aws_smithy_client::erase::DynConnector;
use aws_smithy_runtime_api::client::connectors::HttpConnector;
use aws_smithy_runtime_api::client::orchestrator::{BoxFuture, HttpRequest, HttpResponse};
use aws_smithy_runtime_api::client::connectors::{HttpConnector, HttpConnectorFuture};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use std::sync::{Arc, Mutex};
/// Adapts a [`DynConnector`] to the [`HttpConnector`] trait.
@ -40,9 +44,9 @@ pub mod adapter {
}
impl HttpConnector for DynConnectorAdapter {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse> {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
let future = self.dyn_connector.lock().unwrap().call_lite(request);
future
HttpConnectorFuture::new(future)
}
}
}

View File

@ -0,0 +1,869 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_async::future::timeout::TimedOutError;
use aws_smithy_async::rt::sleep::{default_async_sleep, SharedAsyncSleep};
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata};
use aws_smithy_http::result::ConnectorError;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::connectors::SharedHttpConnector;
use aws_smithy_runtime_api::client::connectors::{HttpConnector, HttpConnectorFuture};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::ErrorKind;
use http::{Extensions, Uri};
use hyper::client::connect::{capture_connection, CaptureConnection, Connection, HttpInfo};
use hyper::service::Service;
use std::error::Error;
use std::fmt;
use std::fmt::Debug;
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls-rustls")]
mod default_connector {
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_client::http_connector::ConnectorSettings;
// Creating a `with_native_roots` HTTP client takes 300ms on OS X. Cache this so that we
// don't need to repeatedly incur that cost.
static HTTPS_NATIVE_ROOTS: once_cell::sync::Lazy<
hyper_rustls::HttpsConnector<hyper::client::HttpConnector>,
> = once_cell::sync::Lazy::new(|| {
use hyper_rustls::ConfigBuilderExt;
hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(
rustls::ClientConfig::builder()
.with_cipher_suites(&[
// TLS1.3 suites
rustls::cipher_suite::TLS13_AES_256_GCM_SHA384,
rustls::cipher_suite::TLS13_AES_128_GCM_SHA256,
// TLS1.2 suites
rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
])
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.expect("Error with the TLS configuration. Please file a bug report under https://github.com/awslabs/smithy-rs/issues.")
.with_native_roots()
.with_no_client_auth()
)
.https_or_http()
.enable_http1()
.enable_http2()
.build()
});
pub(super) fn base(
settings: &ConnectorSettings,
sleep: Option<SharedAsyncSleep>,
) -> super::HyperConnectorBuilder {
let mut hyper = super::HyperConnector::builder().connector_settings(settings.clone());
if let Some(sleep) = sleep {
hyper = hyper.sleep_impl(sleep);
}
hyper
}
/// Return a default HTTPS connector backed by the `rustls` crate.
///
/// It requires a minimum TLS version of 1.2.
/// It allows you to connect to both `http` and `https` URLs.
pub(super) fn https() -> hyper_rustls::HttpsConnector<hyper::client::HttpConnector> {
HTTPS_NATIVE_ROOTS.clone()
}
}
/// Given `ConnectorSettings` and an `SharedAsyncSleep`, create a `SharedHttpConnector` from defaults depending on what cargo features are activated.
pub fn default_connector(
settings: &ConnectorSettings,
sleep: Option<SharedAsyncSleep>,
) -> Option<SharedHttpConnector> {
#[cfg(feature = "tls-rustls")]
{
tracing::trace!(settings = ?settings, sleep = ?sleep, "creating a new default connector");
let hyper = default_connector::base(settings, sleep).build_https();
Some(SharedHttpConnector::new(hyper))
}
#[cfg(not(feature = "tls-rustls"))]
{
tracing::trace!(settings = ?settings, sleep = ?sleep, "no default connector available");
None
}
}
/// [`HttpConnector`] that uses [`hyper`] to make HTTP requests.
///
/// This connector also implements socket connect and read timeouts.
///
/// # Examples
///
/// Construct a `HyperConnector` with the default TLS implementation (rustls).
/// This can be useful when you want to share a Hyper connector between multiple
/// generated Smithy clients.
///
/// ```no_run,ignore
/// use aws_smithy_runtime::client::connectors::hyper_connector::{DefaultHttpsTcpConnector, HyperConnector};
///
/// let hyper_connector = HyperConnector::builder().build(DefaultHttpsTcpConnector::new());
///
/// // This connector can then be given to a generated service Config
/// let config = my_service_client::Config::builder()
/// .endpoint_url("http://localhost:1234")
/// .http_connector(hyper_connector)
/// .build();
/// let client = my_service_client::Client::from_conf(config);
/// ```
///
/// ## Use a Hyper client with WebPKI roots
///
/// A use case for where you may want to use the [`HyperConnector`] is when setting Hyper client settings
/// that aren't otherwise exposed by the `Config` builder interface. Some examples include changing:
///
/// - Hyper client settings
/// - Allowed TLS cipher suites
/// - Using an alternative TLS connector library (not the default, rustls)
/// - CA trust root certificates (illustrated using WebPKI below)
///
/// ```no_run,ignore
/// use aws_smithy_runtime::client::connectors::hyper_connector::HyperConnector;
///
/// let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
/// .with_webpki_roots()
/// .https_only()
/// .enable_http1()
/// .enable_http2()
/// .build();
/// let hyper_connector = HyperConnector::builder().build(https_connector);
///
/// // This connector can then be given to a generated service Config
/// let config = my_service_client::Config::builder()
/// .endpoint_url("https://example.com")
/// .http_connector(hyper_connector)
/// .build();
/// let client = my_service_client::Client::from_conf(config);
/// ```
#[derive(Debug)]
pub struct HyperConnector {
adapter: Box<dyn HttpConnector>,
}
impl HyperConnector {
/// Builder for a Hyper connector.
pub fn builder() -> HyperConnectorBuilder {
Default::default()
}
}
impl HttpConnector for HyperConnector {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
self.adapter.call(request)
}
}
/// Builder for [`HyperConnector`].
#[derive(Default, Debug)]
pub struct HyperConnectorBuilder {
connector_settings: Option<ConnectorSettings>,
sleep_impl: Option<SharedAsyncSleep>,
client_builder: Option<hyper::client::Builder>,
}
impl HyperConnectorBuilder {
/// Create a [`HyperConnector`] from this builder and a given connector.
pub fn build<C>(self, tcp_connector: C) -> HyperConnector
where
C: Clone + Send + Sync + 'static,
C: Service<Uri>,
C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
C::Future: Unpin + Send + 'static,
C::Error: Into<BoxError>,
{
let client_builder = self.client_builder.unwrap_or_default();
let sleep_impl = self.sleep_impl.or_else(default_async_sleep);
let (connect_timeout, read_timeout) = self
.connector_settings
.map(|c| (c.connect_timeout(), c.read_timeout()))
.unwrap_or((None, None));
let connector = match connect_timeout {
Some(duration) => timeout_middleware::ConnectTimeout::new(
tcp_connector,
sleep_impl
.clone()
.expect("a sleep impl must be provided in order to have a connect timeout"),
duration,
),
None => timeout_middleware::ConnectTimeout::no_timeout(tcp_connector),
};
let base = client_builder.build(connector);
let read_timeout = match read_timeout {
Some(duration) => timeout_middleware::HttpReadTimeout::new(
base,
sleep_impl.expect("a sleep impl must be provided in order to have a read timeout"),
duration,
),
None => timeout_middleware::HttpReadTimeout::no_timeout(base),
};
HyperConnector {
adapter: Box::new(Adapter {
client: read_timeout,
}),
}
}
/// Create a [`HyperConnector`] with the default rustls HTTPS implementation.
#[cfg(feature = "tls-rustls")]
pub fn build_https(self) -> HyperConnector {
self.build(default_connector::https())
}
/// Set the async sleep implementation used for timeouts
///
/// Calling this is only necessary for testing or to use something other than
/// [`default_async_sleep`].
pub fn sleep_impl(mut self, sleep_impl: SharedAsyncSleep) -> Self {
self.sleep_impl = Some(sleep_impl);
self
}
/// Set the async sleep implementation used for timeouts
///
/// Calling this is only necessary for testing or to use something other than
/// [`default_async_sleep`].
pub fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) -> &mut Self {
self.sleep_impl = sleep_impl;
self
}
/// Configure the HTTP settings for the `HyperAdapter`
pub fn connector_settings(mut self, connector_settings: ConnectorSettings) -> Self {
self.connector_settings = Some(connector_settings);
self
}
/// Configure the HTTP settings for the `HyperAdapter`
pub fn set_connector_settings(
&mut self,
connector_settings: Option<ConnectorSettings>,
) -> &mut Self {
self.connector_settings = connector_settings;
self
}
/// Override the Hyper client [`Builder`](hyper::client::Builder) used to construct this client.
///
/// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
pub fn hyper_builder(mut self, hyper_builder: hyper::client::Builder) -> Self {
self.client_builder = Some(hyper_builder);
self
}
/// Override the Hyper client [`Builder`](hyper::client::Builder) used to construct this client.
///
/// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
pub fn set_hyper_builder(
&mut self,
hyper_builder: Option<hyper::client::Builder>,
) -> &mut Self {
self.client_builder = hyper_builder;
self
}
}
/// Adapter from a [`hyper::Client`](hyper::Client) to [`HttpConnector`].
///
/// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`HyperConnector::builder`].
struct Adapter<C> {
client: timeout_middleware::HttpReadTimeout<
hyper::Client<timeout_middleware::ConnectTimeout<C>, SdkBody>,
>,
}
impl<C> fmt::Debug for Adapter<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Adapter")
.field("client", &"** hyper client **")
.finish()
}
}
/// Extract a smithy connection from a hyper CaptureConnection
fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
let capture_conn = capture_conn.clone();
if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() {
let mut extensions = Extensions::new();
conn.get_extras(&mut extensions);
let http_info = extensions.get::<HttpInfo>();
let smithy_connection = ConnectionMetadata::new(
conn.is_proxied(),
http_info.map(|info| info.remote_addr()),
move || match capture_conn.connection_metadata().as_ref() {
Some(conn) => conn.poison(),
None => tracing::trace!("no connection existed to poison"),
},
);
Some(smithy_connection)
} else {
None
}
}
impl<C> HttpConnector for Adapter<C>
where
C: Clone + Send + Sync + 'static,
C: hyper::service::Service<Uri>,
C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
C::Future: Unpin + Send + 'static,
C::Error: Into<BoxError>,
{
fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
let capture_connection = capture_connection(&mut request);
if let Some(capture_smithy_connection) =
request.extensions().get::<CaptureSmithyConnection>()
{
capture_smithy_connection
.set_connection_retriever(move || extract_smithy_connection(&capture_connection));
}
let mut client = self.client.clone();
let fut = client.call(request);
HttpConnectorFuture::new(async move {
Ok(fut.await.map_err(downcast_error)?.map(SdkBody::from))
})
}
}
/// Downcast errors coming out of hyper into an appropriate `ConnectorError`
fn downcast_error(err: BoxError) -> ConnectorError {
// is a `TimedOutError` (from aws_smithy_async::timeout) in the chain? if it is, this is a timeout
if find_source::<TimedOutError>(err.as_ref()).is_some() {
return ConnectorError::timeout(err);
}
// is the top of chain error actually already a `ConnectorError`? return that directly
let err = match err.downcast::<ConnectorError>() {
Ok(connector_error) => return *connector_error,
Err(box_error) => box_error,
};
// generally, the top of chain will probably be a hyper error. Go through a set of hyper specific
// error classifications
let err = match err.downcast::<hyper::Error>() {
Ok(hyper_error) => return to_connector_error(*hyper_error),
Err(box_error) => box_error,
};
// otherwise, we have no idea!
ConnectorError::other(err, None)
}
/// Convert a [`hyper::Error`] into a [`ConnectorError`]
fn to_connector_error(err: hyper::Error) -> ConnectorError {
if err.is_timeout() || find_source::<timeout_middleware::HttpTimeoutError>(&err).is_some() {
ConnectorError::timeout(err.into())
} else if err.is_user() {
ConnectorError::user(err.into())
} else if err.is_closed() || err.is_canceled() || find_source::<std::io::Error>(&err).is_some()
{
ConnectorError::io(err.into())
}
// We sometimes receive this from S3: hyper::Error(IncompleteMessage)
else if err.is_incomplete_message() {
ConnectorError::other(err.into(), Some(ErrorKind::TransientError))
} else {
tracing::warn!(err = %DisplayErrorContext(&err), "unrecognized error from Hyper. If this error should be retried, please file an issue.");
ConnectorError::other(err.into(), None)
}
}
fn find_source<'a, E: Error + 'static>(err: &'a (dyn Error + 'static)) -> Option<&'a E> {
let mut next = Some(err);
while let Some(err) = next {
if let Some(matching_err) = err.downcast_ref::<E>() {
return Some(matching_err);
}
next = err.source();
}
None
}
mod timeout_middleware {
use aws_smithy_async::future::timeout::{TimedOutError, Timeout};
use aws_smithy_async::rt::sleep::Sleep;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_runtime_api::box_error::BoxError;
use http::Uri;
use pin_project_lite::pin_project;
use std::error::Error;
use std::fmt::Formatter;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
#[derive(Debug)]
pub(crate) struct HttpTimeoutError {
kind: &'static str,
duration: Duration,
}
impl std::fmt::Display for HttpTimeoutError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} timeout occurred after {:?}",
self.kind, self.duration
)
}
}
impl Error for HttpTimeoutError {
// We implement the `source` function as returning a `TimedOutError` because when `downcast_error`
// or `find_source` is called with an `HttpTimeoutError` (or another error wrapping an `HttpTimeoutError`)
// this method will be checked to determine if it's a timeout-related error.
fn source(&self) -> Option<&(dyn Error + 'static)> {
Some(&TimedOutError)
}
}
/// Timeout wrapper that will timeout on the initial TCP connection
///
/// # Stability
/// This interface is unstable.
#[derive(Clone, Debug)]
pub(super) struct ConnectTimeout<I> {
inner: I,
timeout: Option<(SharedAsyncSleep, Duration)>,
}
impl<I> ConnectTimeout<I> {
/// Create a new `ConnectTimeout` around `inner`.
///
/// Typically, `I` will implement [`hyper::client::connect::Connect`].
pub(crate) fn new(inner: I, sleep: SharedAsyncSleep, timeout: Duration) -> Self {
Self {
inner,
timeout: Some((sleep, timeout)),
}
}
pub(crate) fn no_timeout(inner: I) -> Self {
Self {
inner,
timeout: None,
}
}
}
#[derive(Clone, Debug)]
pub(crate) struct HttpReadTimeout<I> {
inner: I,
timeout: Option<(SharedAsyncSleep, Duration)>,
}
impl<I> HttpReadTimeout<I> {
/// Create a new `HttpReadTimeout` around `inner`.
///
/// Typically, `I` will implement [`hyper::service::Service<http::Request<SdkBody>>`].
pub(crate) fn new(inner: I, sleep: SharedAsyncSleep, timeout: Duration) -> Self {
Self {
inner,
timeout: Some((sleep, timeout)),
}
}
pub(crate) fn no_timeout(inner: I) -> Self {
Self {
inner,
timeout: None,
}
}
}
pin_project! {
/// Timeout future for Tower services
///
/// Timeout future to handle timing out, mapping errors, and the possibility of not timing out
/// without incurring an additional allocation for each timeout layer.
#[project = MaybeTimeoutFutureProj]
pub enum MaybeTimeoutFuture<F> {
Timeout {
#[pin]
timeout: Timeout<F, Sleep>,
error_type: &'static str,
duration: Duration,
},
NoTimeout {
#[pin]
future: F
}
}
}
impl<F, T, E> Future for MaybeTimeoutFuture<F>
where
F: Future<Output = Result<T, E>>,
E: Into<BoxError>,
{
type Output = Result<T, BoxError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let (timeout_future, kind, &mut duration) = match self.project() {
MaybeTimeoutFutureProj::NoTimeout { future } => {
return future.poll(cx).map_err(|err| err.into());
}
MaybeTimeoutFutureProj::Timeout {
timeout,
error_type,
duration,
} => (timeout, error_type, duration),
};
match timeout_future.poll(cx) {
Poll::Ready(Ok(response)) => Poll::Ready(response.map_err(|err| err.into())),
Poll::Ready(Err(_timeout)) => {
Poll::Ready(Err(HttpTimeoutError { kind, duration }.into()))
}
Poll::Pending => Poll::Pending,
}
}
}
impl<I> hyper::service::Service<Uri> for ConnectTimeout<I>
where
I: hyper::service::Service<Uri>,
I::Error: Into<BoxError>,
{
type Response = I::Response;
type Error = BoxError;
type Future = MaybeTimeoutFuture<I::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|err| err.into())
}
fn call(&mut self, req: Uri) -> Self::Future {
match &self.timeout {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP connect",
duration: *duration,
}
}
None => MaybeTimeoutFuture::NoTimeout {
future: self.inner.call(req),
},
}
}
}
impl<I, B> hyper::service::Service<http::Request<B>> for HttpReadTimeout<I>
where
I: hyper::service::Service<http::Request<B>, Error = hyper::Error>,
{
type Response = I::Response;
type Error = BoxError;
type Future = MaybeTimeoutFuture<I::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|err| err.into())
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
match &self.timeout {
Some((sleep, duration)) => {
let sleep = sleep.sleep(*duration);
MaybeTimeoutFuture::Timeout {
timeout: Timeout::new(self.inner.call(req), sleep),
error_type: "HTTP read",
duration: *duration,
}
}
None => MaybeTimeoutFuture::NoTimeout {
future: self.inner.call(req),
},
}
}
}
#[cfg(test)]
mod test {
use super::super::*;
use super::*;
use aws_smithy_async::assert_elapsed;
use aws_smithy_async::future::never::Never;
use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::timeout::TimeoutConfig;
use hyper::client::connect::Connected;
use std::time::Duration;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
#[allow(unused)]
fn connect_timeout_is_correct<T: Send + Sync + Clone + 'static>() {
is_send_sync::<super::ConnectTimeout<T>>();
}
#[allow(unused)]
fn is_send_sync<T: Send + Sync>() {}
/// A service that will never return whatever it is you want
///
/// Returned futures will return Pending forever
#[non_exhaustive]
#[derive(Clone, Default, Debug)]
struct NeverConnects;
impl hyper::service::Service<Uri> for NeverConnects {
type Response = TcpStream;
type Error = ConnectorError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _uri: Uri) -> Self::Future {
Box::pin(async move {
Never::new().await;
unreachable!()
})
}
}
/// A service that will connect but never send any data
#[derive(Clone, Debug, Default)]
struct NeverReplies;
impl hyper::service::Service<Uri> for NeverReplies {
type Response = EmptyStream;
type Error = BoxError;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Uri) -> Self::Future {
std::future::ready(Ok(EmptyStream))
}
}
/// A stream that will never return or accept any data
#[non_exhaustive]
#[derive(Debug, Default)]
struct EmptyStream;
impl AsyncRead for EmptyStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
impl AsyncWrite for EmptyStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Poll::Pending
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Poll::Pending
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Poll::Pending
}
}
impl Connection for EmptyStream {
fn connected(&self) -> Connected {
Connected::new()
}
}
#[tokio::test]
async fn http_connect_timeout_works() {
let tcp_connector = NeverConnects::default();
let connector_settings = ConnectorSettings::from_timeout_config(
&TimeoutConfig::builder()
.connect_timeout(Duration::from_secs(1))
.build(),
);
let hyper = HyperConnector::builder()
.connector_settings(connector_settings)
.sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
.build(tcp_connector)
.adapter;
let now = tokio::time::Instant::now();
tokio::time::pause();
let resp = hyper
.call(
http::Request::builder()
.uri("http://foo.com")
.body(SdkBody::empty())
.unwrap(),
)
.await
.unwrap_err();
assert!(
resp.is_timeout(),
"expected resp.is_timeout() to be true but it was false, resp == {:?}",
resp
);
let message = DisplayErrorContext(&resp).to_string();
let expected =
"timeout: error trying to connect: HTTP connect timeout occurred after 1s";
assert!(
message.contains(expected),
"expected '{message}' to contain '{expected}'"
);
assert_elapsed!(now, Duration::from_secs(1));
}
#[tokio::test]
async fn http_read_timeout_works() {
let tcp_connector = NeverReplies::default();
let connector_settings = ConnectorSettings::from_timeout_config(
&TimeoutConfig::builder()
.connect_timeout(Duration::from_secs(1))
.read_timeout(Duration::from_secs(2))
.build(),
);
let hyper = HyperConnector::builder()
.connector_settings(connector_settings)
.sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
.build(tcp_connector)
.adapter;
let now = tokio::time::Instant::now();
tokio::time::pause();
let err = hyper
.call(
http::Request::builder()
.uri("http://foo.com")
.body(SdkBody::empty())
.unwrap(),
)
.await
.unwrap_err();
assert!(
err.is_timeout(),
"expected err.is_timeout() to be true but it was false, err == {err:?}",
);
let message = format!("{}", DisplayErrorContext(&err));
let expected = "timeout: HTTP read timeout occurred after 2s";
assert!(
message.contains(expected),
"expected '{message}' to contain '{expected}'"
);
assert_elapsed!(now, Duration::from_secs(2));
}
}
}
#[cfg(test)]
mod test {
use super::*;
use aws_smithy_http::body::SdkBody;
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[tokio::test]
async fn hyper_io_error() {
let connector = TestConnection {
inner: HangupStream,
};
let adapter = HyperConnector::builder().build(connector).adapter;
let err = adapter
.call(
http::Request::builder()
.uri("http://amazon.com")
.body(SdkBody::empty())
.unwrap(),
)
.await
.expect_err("socket hangup");
assert!(err.is_io(), "{:?}", err);
}
// ---- machinery to make a Hyper connector that responds with an IO Error
#[derive(Clone)]
struct HangupStream;
impl Connection for HangupStream {
fn connected(&self) -> Connected {
Connected::new()
}
}
impl AsyncRead for HangupStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Err(Error::new(
ErrorKind::ConnectionReset,
"connection reset",
)))
}
}
impl AsyncWrite for HangupStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<Result<usize, Error>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Pending
}
}
#[derive(Clone)]
struct TestConnection<T> {
inner: T,
}
impl<T> hyper::service::Service<Uri> for TestConnection<T>
where
T: Clone + Connection,
{
type Response = T;
type Error = BoxError;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Uri) -> Self::Future {
std::future::ready(Ok(self.inner.clone()))
}
}
}

View File

@ -4,261 +4,38 @@
*/
//! Module with client connectors useful for testing.
//!
//! Each test connector is useful for different test use cases:
//! - [`capture_request`](capture_request::capture_request): If you don't care what the
//! response is, but just want to check that the serialized request is what you expect,
//! then use `capture_request`. Or, alternatively, if you don't care what the request
//! is, but want to always respond with a given response, then capture request can also
//! be useful since you can optionally give it a response to return.
//! - [`dvr`]: If you want to record real-world traffic and then replay it later, then DVR's
//! [`RecordingConnector`](dvr::RecordingConnector) and [`ReplayingConnector`](dvr::ReplayingConnector)
//! can accomplish this, and the recorded traffic can be saved to JSON and checked in. Note: if
//! the traffic recording has sensitive information in it, such as signatures or authorization,
//! you will need to manually scrub this out if you intend to store the recording alongside
//! your tests.
//! - [`EventConnector`]: If you want to have a set list of requests and their responses in a test,
//! then the event connector will be useful. On construction, it takes a list of tuples that represent
//! each expected request and the response for that request. At the end of the test, you can ask the
//! connector to verify that the requests matched the expectations.
//! - [`infallible_connection_fn`]: Allows you to create a connector from an infallible function
//! that takes a request and returns a response.
//! - [`NeverConnector`]: Useful for testing timeouts, where you want the connector to never respond.
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_protocol_test::{assert_ok, validate_body, MediaType};
use aws_smithy_runtime_api::client::connectors::HttpConnector;
use aws_smithy_runtime_api::client::orchestrator::{BoxFuture, HttpRequest, HttpResponse};
use http::header::{HeaderName, CONTENT_TYPE};
use std::fmt::Debug;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::oneshot;
mod capture_request;
pub use capture_request::{capture_request, CaptureRequestHandler, CaptureRequestReceiver};
/// Test Connection to capture a single request
#[derive(Debug, Clone)]
pub struct CaptureRequestHandler(Arc<Mutex<Inner>>);
#[cfg(feature = "connector-hyper")]
pub mod dvr;
#[derive(Debug)]
struct Inner {
_response: Option<http::Response<SdkBody>>,
_sender: Option<oneshot::Sender<HttpRequest>>,
}
mod event_connector;
pub use event_connector::{ConnectionEvent, EventConnector};
/// Receiver for [`CaptureRequestHandler`](CaptureRequestHandler)
#[derive(Debug)]
pub struct CaptureRequestReceiver {
receiver: oneshot::Receiver<HttpRequest>,
}
mod infallible;
pub use infallible::infallible_connection_fn;
impl CaptureRequestReceiver {
/// Expect that a request was sent. Returns the captured request.
///
/// # Panics
/// If no request was received
#[track_caller]
pub fn expect_request(mut self) -> HttpRequest {
self.receiver.try_recv().expect("no request was received")
}
/// Expect that no request was captured. Panics if a request was received.
///
/// # Panics
/// If a request was received
#[track_caller]
pub fn expect_no_request(mut self) {
self.receiver
.try_recv()
.expect_err("expected no request to be received!");
}
}
/// Test connection used to capture a single request
///
/// If response is `None`, it will reply with a 200 response with an empty body
///
/// Example:
/// ```compile_fail
/// let (server, request) = capture_request(None);
/// let conf = aws_sdk_sts::Config::builder()
/// .http_connector(server)
/// .build();
/// let client = aws_sdk_sts::Client::from_conf(conf);
/// let _ = client.assume_role_with_saml().send().await;
/// // web identity should be unsigned
/// assert_eq!(
/// request.expect_request().headers().get("AUTHORIZATION"),
/// None
/// );
/// ```
pub fn capture_request(
response: Option<http::Response<SdkBody>>,
) -> (CaptureRequestHandler, CaptureRequestReceiver) {
let (tx, rx) = oneshot::channel();
(
CaptureRequestHandler(Arc::new(Mutex::new(Inner {
_response: Some(response.unwrap_or_else(|| {
http::Response::builder()
.status(200)
.body(SdkBody::empty())
.expect("unreachable")
})),
_sender: Some(tx),
}))),
CaptureRequestReceiver { receiver: rx },
)
}
type ConnectionEvents = Vec<ConnectionEvent>;
/// Test data for the [`TestConnector`].
///
/// Each `ConnectionEvent` represents one HTTP request and response
/// through the connector. Optionally, a latency value can be set to simulate
/// network latency (done via async sleep in the `TestConnector`).
#[derive(Debug)]
pub struct ConnectionEvent {
latency: Duration,
req: HttpRequest,
res: HttpResponse,
}
impl ConnectionEvent {
/// Creates a new `ConnectionEvent`.
pub fn new(req: HttpRequest, res: HttpResponse) -> Self {
Self {
res,
req,
latency: Duration::from_secs(0),
}
}
/// Add simulated latency to this `ConnectionEvent`
pub fn with_latency(mut self, latency: Duration) -> Self {
self.latency = latency;
self
}
/// Returns the test request.
pub fn request(&self) -> &HttpRequest {
&self.req
}
/// Returns the test response.
pub fn response(&self) -> &HttpResponse {
&self.res
}
}
impl From<(HttpRequest, HttpResponse)> for ConnectionEvent {
fn from((req, res): (HttpRequest, HttpResponse)) -> Self {
Self::new(req, res)
}
}
#[derive(Debug)]
struct ValidateRequest {
expected: HttpRequest,
actual: HttpRequest,
}
impl ValidateRequest {
fn assert_matches(&self, index: usize, ignore_headers: &[HeaderName]) {
let (actual, expected) = (&self.actual, &self.expected);
assert_eq!(
actual.uri(),
expected.uri(),
"Request #{index} - URI doesn't match expected value"
);
for (name, value) in expected.headers() {
if !ignore_headers.contains(name) {
let actual_header = actual
.headers()
.get(name)
.unwrap_or_else(|| panic!("Request #{index} - Header {name:?} is missing"));
assert_eq!(
actual_header.to_str().unwrap(),
value.to_str().unwrap(),
"Request #{index} - Header {name:?} doesn't match expected value",
);
}
}
let actual_str = std::str::from_utf8(actual.body().bytes().unwrap_or(&[]));
let expected_str = std::str::from_utf8(expected.body().bytes().unwrap_or(&[]));
let media_type = if actual
.headers()
.get(CONTENT_TYPE)
.map(|v| v.to_str().unwrap().contains("json"))
.unwrap_or(false)
{
MediaType::Json
} else {
MediaType::Other("unknown".to_string())
};
match (actual_str, expected_str) {
(Ok(actual), Ok(expected)) => assert_ok(validate_body(actual, expected, media_type)),
_ => assert_eq!(
actual.body().bytes(),
expected.body().bytes(),
"Request #{index} - Body contents didn't match expected value"
),
};
}
}
/// Test connector for use as a [`HttpConnector`].
///
/// A basic test connection. It will:
/// - Respond to requests with a preloaded series of responses
/// - Record requests for future examination
#[derive(Debug, Clone)]
pub struct TestConnector {
data: Arc<Mutex<ConnectionEvents>>,
requests: Arc<Mutex<Vec<ValidateRequest>>>,
sleep_impl: SharedAsyncSleep,
}
impl TestConnector {
/// Creates a new test connector.
pub fn new(mut data: ConnectionEvents, sleep_impl: impl Into<SharedAsyncSleep>) -> Self {
data.reverse();
TestConnector {
data: Arc::new(Mutex::new(data)),
requests: Default::default(),
sleep_impl: sleep_impl.into(),
}
}
fn requests(&self) -> impl Deref<Target = Vec<ValidateRequest>> + '_ {
self.requests.lock().unwrap()
}
/// Asserts the expected requests match the actual requests.
///
/// The expected requests are given as the connection events when the `TestConnector`
/// is created. The `TestConnector` will record the actual requests and assert that
/// they match the expected requests.
///
/// A list of headers that should be ignored when comparing requests can be passed
/// for cases where headers are non-deterministic or are irrelevant to the test.
#[track_caller]
pub fn assert_requests_match(&self, ignore_headers: &[HeaderName]) {
for (i, req) in self.requests().iter().enumerate() {
req.assert_matches(i, ignore_headers)
}
let remaining_requests = self.data.lock().unwrap();
let number_of_remaining_requests = remaining_requests.len();
let actual_requests = self.requests().len();
assert!(
remaining_requests.is_empty(),
"Expected {number_of_remaining_requests} additional requests (only {actual_requests} sent)",
);
}
}
impl HttpConnector for TestConnector {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse> {
let (res, simulated_latency) = if let Some(event) = self.data.lock().unwrap().pop() {
self.requests.lock().unwrap().push(ValidateRequest {
expected: event.req,
actual: request,
});
(Ok(event.res.map(SdkBody::from)), event.latency)
} else {
(
Err(ConnectorError::other("No more data".into(), None).into()),
Duration::from_secs(0),
)
};
let sleep = self.sleep_impl.sleep(simulated_latency);
Box::pin(async move {
sleep.await;
res
})
}
}
mod never;
pub use never::NeverConnector;

View File

@ -0,0 +1,84 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;
/// Test Connection to capture a single request
#[derive(Debug, Clone)]
pub struct CaptureRequestHandler(Arc<Mutex<Inner>>);
#[derive(Debug)]
struct Inner {
_response: Option<http::Response<SdkBody>>,
_sender: Option<oneshot::Sender<HttpRequest>>,
}
/// Receiver for [`CaptureRequestHandler`](CaptureRequestHandler)
#[derive(Debug)]
pub struct CaptureRequestReceiver {
receiver: oneshot::Receiver<HttpRequest>,
}
impl CaptureRequestReceiver {
/// Expect that a request was sent. Returns the captured request.
///
/// # Panics
/// If no request was received
#[track_caller]
pub fn expect_request(mut self) -> HttpRequest {
self.receiver.try_recv().expect("no request was received")
}
/// Expect that no request was captured. Panics if a request was received.
///
/// # Panics
/// If a request was received
#[track_caller]
pub fn expect_no_request(mut self) {
self.receiver
.try_recv()
.expect_err("expected no request to be received!");
}
}
/// Test connection used to capture a single request
///
/// If response is `None`, it will reply with a 200 response with an empty body
///
/// Example:
/// ```compile_fail
/// let (server, request) = capture_request(None);
/// let conf = aws_sdk_sts::Config::builder()
/// .http_connector(server)
/// .build();
/// let client = aws_sdk_sts::Client::from_conf(conf);
/// let _ = client.assume_role_with_saml().send().await;
/// // web identity should be unsigned
/// assert_eq!(
/// request.expect_request().headers().get("AUTHORIZATION"),
/// None
/// );
/// ```
pub fn capture_request(
response: Option<http::Response<SdkBody>>,
) -> (CaptureRequestHandler, CaptureRequestReceiver) {
let (tx, rx) = oneshot::channel();
(
CaptureRequestHandler(Arc::new(Mutex::new(Inner {
_response: Some(response.unwrap_or_else(|| {
http::Response::builder()
.status(200)
.body(SdkBody::empty())
.expect("unreachable")
})),
_sender: Some(tx),
}))),
CaptureRequestReceiver { receiver: rx },
)
}

View File

@ -0,0 +1,274 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Extremely Experimental Test Connection
//!
//! Warning: Extremely experimental, API likely to change.
//!
//! DVR is an extremely experimental record & replay framework that supports multi-frame HTTP request / response traffic.
use aws_smithy_types::base64;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
mod record;
mod replay;
pub use aws_smithy_protocol_test::MediaType;
pub use record::RecordingConnector;
pub use replay::ReplayingConnector;
/// A complete traffic recording
///
/// A traffic recording can be replayed with [`RecordingConnector`](RecordingConnector)
#[derive(Debug, Serialize, Deserialize)]
pub struct NetworkTraffic {
events: Vec<Event>,
docs: Option<String>,
version: Version,
}
impl NetworkTraffic {
/// Network events
pub fn events(&self) -> &Vec<Event> {
&self.events
}
}
/// Serialization version of DVR data
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Version {
/// Initial network traffic version
V0,
}
/// A network traffic recording may contain multiple different connections occurring simultaneously
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct ConnectionId(usize);
/// A network event
///
/// Network events consist of a connection identifier and an action. An event is sufficient to
/// reproduce traffic later during replay
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct Event {
connection_id: ConnectionId,
action: Action,
}
/// An initial HTTP request, roughly equivalent to `http::Request<()>`
///
/// The initial request phase of an HTTP request. The body will be
/// sent later as a separate action.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Request {
uri: String,
headers: HashMap<String, Vec<String>>,
method: String,
}
/// An initial HTTP response roughly equivalent to `http::Response<()>`
///
/// The initial response phase of an HTTP request. The body will be
/// sent later as a separate action.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Response {
status: u16,
version: String,
headers: HashMap<String, Vec<String>>,
}
impl From<&Request> for http::Request<()> {
fn from(request: &Request) -> Self {
let mut builder = http::Request::builder().uri(request.uri.as_str());
for (k, values) in request.headers.iter() {
for v in values {
builder = builder.header(k, v);
}
}
builder.method(request.method.as_str()).body(()).unwrap()
}
}
impl<'a, B> From<&'a http::Request<B>> for Request {
fn from(req: &'a http::Request<B>) -> Self {
let uri = req.uri().to_string();
let headers = headers_to_map(req.headers());
let method = req.method().to_string();
Self {
uri,
headers,
method,
}
}
}
fn headers_to_map(headers: &http::HeaderMap<http::HeaderValue>) -> HashMap<String, Vec<String>> {
let mut out: HashMap<_, Vec<_>> = HashMap::new();
for (header_name, header_value) in headers.iter() {
let entry = out.entry(header_name.to_string()).or_default();
entry.push(header_value.to_str().unwrap().to_string());
}
out
}
impl<'a, B> From<&'a http::Response<B>> for Response {
fn from(resp: &'a http::Response<B>) -> Self {
let status = resp.status().as_u16();
let version = format!("{:?}", resp.version());
let headers = headers_to_map(resp.headers());
Self {
status,
version,
headers,
}
}
}
/// Error response wrapper
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Error(String);
/// Network Action
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Action {
/// Initial HTTP Request
Request {
/// HTTP Request headers, method, and URI
request: Request,
},
/// Initial HTTP response or failure
Response {
/// HTTP response or failure
response: Result<Response, Error>,
},
/// Data segment
Data {
/// Body Data
data: BodyData,
/// Direction: request vs. response
direction: Direction,
},
/// End of data
Eof {
/// Succesful vs. failed termination
ok: bool,
/// Direction: request vs. response
direction: Direction,
},
}
/// Event direction
///
/// During replay, this is used to replay data in the right direction
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Direction {
/// Request phase
Request,
/// Response phase
Response,
}
impl Direction {
/// The opposite of a given direction
pub fn opposite(self) -> Self {
match self {
Direction::Request => Direction::Response,
Direction::Response => Direction::Request,
}
}
}
/// HTTP Body Data Abstraction
///
/// When the data is a UTF-8 encoded string, it will be serialized as a string for readability.
/// Otherwise, it will be base64 encoded.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[non_exhaustive]
pub enum BodyData {
/// UTF-8 encoded data
Utf8(String),
/// Base64 encoded binary data
Base64(String),
}
impl BodyData {
/// Convert [`BodyData`](BodyData) into Bytes
pub fn into_bytes(self) -> Vec<u8> {
match self {
BodyData::Utf8(string) => string.into_bytes(),
BodyData::Base64(string) => base64::decode(string).unwrap(),
}
}
/// Copy [`BodyData`](BodyData) into a `Vec<u8>`
pub fn copy_to_vec(&self) -> Vec<u8> {
match self {
BodyData::Utf8(string) => string.as_bytes().into(),
BodyData::Base64(string) => base64::decode(string).unwrap(),
}
}
}
impl From<Bytes> for BodyData {
fn from(data: Bytes) -> Self {
match std::str::from_utf8(data.as_ref()) {
Ok(string) => BodyData::Utf8(string.to_string()),
Err(_) => BodyData::Base64(base64::encode(data)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::byte_stream::ByteStream;
use aws_smithy_runtime_api::client::connectors::{HttpConnector, SharedHttpConnector};
use bytes::Bytes;
use http::Uri;
use std::error::Error;
use std::fs;
#[tokio::test]
async fn turtles_all_the_way_down() -> Result<(), Box<dyn Error>> {
// create a replaying connection from a recording, wrap a recording connection around it,
// make a request, then verify that the same traffic was recorded.
let network_traffic = fs::read_to_string("test-data/example.com.json")?;
let network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;
let inner = ReplayingConnector::new(network_traffic.events.clone());
let connection = RecordingConnector::new(SharedHttpConnector::new(inner.clone()));
let req = http::Request::post("https://www.example.com")
.body(SdkBody::from("hello world"))
.unwrap();
let mut resp = connection.call(req).await.expect("ok");
let body = std::mem::replace(resp.body_mut(), SdkBody::taken());
let data = ByteStream::new(body).collect().await.unwrap().into_bytes();
assert_eq!(
String::from_utf8(data.to_vec()).unwrap(),
"hello from example.com"
);
assert_eq!(
connection.events().as_slice(),
network_traffic.events.as_slice()
);
let requests = inner.take_requests().await;
assert_eq!(
requests[0].uri(),
&Uri::from_static("https://www.example.com")
);
assert_eq!(
requests[0].body(),
&Bytes::from_static("hello world".as_bytes())
);
Ok(())
}
}

View File

@ -0,0 +1,202 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use super::{
Action, BodyData, ConnectionId, Direction, Error, Event, NetworkTraffic, Request, Response,
Version,
};
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::connectors::{
HttpConnector, HttpConnectorFuture, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use http_body::Body;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::{fs, io};
use tokio::task::JoinHandle;
/// Recording Connection Wrapper
///
/// RecordingConnector wraps an inner connection and records all traffic, enabling traffic replay.
#[derive(Clone, Debug)]
pub struct RecordingConnector {
pub(crate) data: Arc<Mutex<Vec<Event>>>,
pub(crate) num_events: Arc<AtomicUsize>,
pub(crate) inner: SharedHttpConnector,
}
#[cfg(all(feature = "tls-rustls"))]
impl RecordingConnector {
/// Construct a recording connection wrapping a default HTTPS implementation
pub fn https() -> Self {
use crate::client::connectors::hyper_connector::HyperConnector;
Self {
data: Default::default(),
num_events: Arc::new(AtomicUsize::new(0)),
inner: SharedHttpConnector::new(HyperConnector::builder().build_https()),
}
}
}
impl RecordingConnector {
/// Create a new recording connection from a connection
pub fn new(underlying_connector: SharedHttpConnector) -> Self {
Self {
data: Default::default(),
num_events: Arc::new(AtomicUsize::new(0)),
inner: underlying_connector,
}
}
/// Return the traffic recorded by this connection
pub fn events(&self) -> MutexGuard<'_, Vec<Event>> {
self.data.lock().unwrap()
}
/// NetworkTraffic struct suitable for serialization
pub fn network_traffic(&self) -> NetworkTraffic {
NetworkTraffic {
events: self.events().clone(),
docs: Some("todo docs".into()),
version: Version::V0,
}
}
/// Dump the network traffic to a file
pub fn dump_to_file(&self, path: impl AsRef<Path>) -> Result<(), io::Error> {
fs::write(
path,
serde_json::to_string(&self.network_traffic()).unwrap(),
)
}
fn next_id(&self) -> ConnectionId {
ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
}
}
fn record_body(
body: &mut SdkBody,
event_id: ConnectionId,
direction: Direction,
event_bus: Arc<Mutex<Vec<Event>>>,
) -> JoinHandle<()> {
let (sender, output_body) = hyper::Body::channel();
let real_body = std::mem::replace(body, SdkBody::from(output_body));
tokio::spawn(async move {
let mut real_body = real_body;
let mut sender = sender;
loop {
let data = real_body.data().await;
match data {
Some(Ok(data)) => {
event_bus.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Data {
data: BodyData::from(data.clone()),
direction,
},
});
// This happens if the real connection is closed during recording.
// Need to think more carefully if this is the correct thing to log in this
// case.
if sender.send_data(data).await.is_err() {
event_bus.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Eof {
direction: direction.opposite(),
ok: false,
},
})
};
}
None => {
event_bus.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Eof {
ok: true,
direction,
},
});
drop(sender);
break;
}
Some(Err(_err)) => {
event_bus.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Eof {
ok: false,
direction,
},
});
sender.abort();
break;
}
}
}
})
}
impl HttpConnector for RecordingConnector {
fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
let event_id = self.next_id();
// A request has three phases:
// 1. A "Request" phase. This is initial HTTP request, headers, & URI
// 2. A body phase. This may contain multiple data segments.
// 3. A finalization phase. An EOF of some sort is sent on the body to indicate that
// the channel should be closed.
// Phase 1: the initial http request
self.data.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Request {
request: Request::from(&request),
},
});
// Phase 2: Swap out the real request body for one that will log all traffic that passes
// through it
// This will also handle phase three when the request body runs out of data.
record_body(
request.body_mut(),
event_id,
Direction::Request,
self.data.clone(),
);
let events = self.data.clone();
// create a channel we'll use to stream the data while reading it
let resp_fut = self.inner.call(request);
let fut = async move {
let resp = resp_fut.await;
match resp {
Ok(mut resp) => {
// push the initial response event
events.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Response {
response: Ok(Response::from(&resp)),
},
});
// instrument the body and record traffic
record_body(resp.body_mut(), event_id, Direction::Response, events);
Ok(resp)
}
Err(e) => {
events.lock().unwrap().push(Event {
connection_id: event_id,
action: Action::Response {
response: Err(Error(format!("{}", &e))),
},
});
Err(e)
}
}
};
HttpConnectorFuture::new(fut)
}
}

View File

@ -0,0 +1,351 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_protocol_test::MediaType;
use aws_smithy_runtime_api::client::connectors::{HttpConnector, HttpConnectorFuture};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_types::error::display::DisplayErrorContext;
use bytes::{Bytes, BytesMut};
use http::{Request, Version};
use http_body::Body;
use std::collections::{HashMap, VecDeque};
use std::error::Error;
use std::ops::DerefMut;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
/// Wrapper type to enable optionally waiting for a future to complete
#[derive(Debug)]
enum Waitable<T> {
Loading(JoinHandle<T>),
Value(T),
}
impl<T> Waitable<T> {
/// Consumes the future and returns the value
async fn take(self) -> T {
match self {
Waitable::Loading(f) => f.await.expect("join failed"),
Waitable::Value(value) => value,
}
}
/// Waits for the future to be ready
async fn wait(&mut self) {
match self {
Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
Waitable::Value(_) => {}
}
}
}
/// Replay traffic recorded by a [`RecordingConnector`](super::RecordingConnector)
#[derive(Clone, Debug)]
pub struct ReplayingConnector {
live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
verifiable_events: Arc<HashMap<ConnectionId, Request<Bytes>>>,
num_events: Arc<AtomicUsize>,
recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http::Request<Bytes>>>>>,
}
impl ReplayingConnector {
fn next_id(&self) -> ConnectionId {
ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
}
/// Validate all headers and bodies
pub async fn full_validate(self, media_type: MediaType) -> Result<(), Box<dyn Error>> {
self.validate_body_and_headers(None, media_type).await
}
/// Validate actual requests against expected requests
pub async fn validate(
self,
checked_headers: &[&str],
body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
) -> Result<(), Box<dyn Error>> {
self.validate_base(Some(checked_headers), body_comparer)
.await
}
/// Validate that the bodies match, using a given [`MediaType`] for comparison
///
/// The specified headers are also validated
pub async fn validate_body_and_headers(
self,
checked_headers: Option<&[&str]>,
media_type: MediaType,
) -> Result<(), Box<dyn Error>> {
self.validate_base(checked_headers, |b1, b2| {
aws_smithy_protocol_test::validate_body(
b1,
std::str::from_utf8(b2).unwrap(),
media_type.clone(),
)
.map_err(|e| Box::new(e) as _)
})
.await
}
async fn validate_base(
self,
checked_headers: Option<&[&str]>,
body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
) -> Result<(), Box<dyn Error>> {
let mut actual_requests =
std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
for conn_id in 0..self.verifiable_events.len() {
let conn_id = ConnectionId(conn_id);
let expected = self.verifiable_events.get(&conn_id).unwrap();
let actual = actual_requests
.remove(&conn_id)
.ok_or(format!(
"expected connection {:?} but request was never sent",
conn_id
))?
.take()
.await;
aws_smithy_protocol_test::assert_uris_match(expected.uri(), actual.uri());
body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
let expected_headers = expected
.headers()
.keys()
.map(|k| k.as_str())
.filter(|k| match checked_headers {
Some(list) => list.contains(k),
None => true,
})
.flat_map(|key| {
let _ = expected.headers().get(key)?;
Some((
key,
expected
.headers()
.get_all(key)
.iter()
.map(|h| h.to_str().unwrap())
.collect::<Vec<_>>()
.join(", "),
))
})
.collect::<Vec<_>>();
aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
.map_err(|err| {
format!(
"event {} validation failed with: {}",
conn_id.0,
DisplayErrorContext(&err)
)
})?;
}
Ok(())
}
/// Return all the recorded requests for further analysis
pub async fn take_requests(self) -> Vec<http::Request<Bytes>> {
let mut recorded_requests =
std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
let mut out = Vec::with_capacity(recorded_requests.len());
for conn_id in 0..recorded_requests.len() {
out.push(
recorded_requests
.remove(&ConnectionId(conn_id))
.expect("should exist")
.take()
.await,
)
}
out
}
/// Build a replay connection from a JSON file
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
let events: NetworkTraffic =
serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
Ok(Self::new(events.events))
}
/// Build a replay connection from a sequence of events
pub fn new(events: Vec<Event>) -> Self {
let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
for event in events {
let event_buffer = event_map.entry(event.connection_id).or_default();
event_buffer.push_back(event);
}
let verifiable_events = event_map
.iter()
.map(|(id, events)| {
let mut body = BytesMut::new();
for event in events {
if let Action::Data {
direction: Direction::Request,
data,
} = &event.action
{
body.extend_from_slice(&data.copy_to_vec());
}
}
let initial_request = events.iter().next().expect("must have one event");
let request = match &initial_request.action {
Action::Request { request } => {
http::Request::from(request).map(|_| Bytes::from(body))
}
_ => panic!("invalid first event"),
};
(*id, request)
})
.collect();
let verifiable_events = Arc::new(verifiable_events);
ReplayingConnector {
live_events: Arc::new(Mutex::new(event_map)),
num_events: Arc::new(AtomicUsize::new(0)),
recorded_requests: Default::default(),
verifiable_events,
}
}
}
async fn replay_body(events: VecDeque<Event>, mut sender: hyper::body::Sender) {
for event in events {
match event.action {
Action::Request { .. } => panic!(),
Action::Response { .. } => panic!(),
Action::Data {
data,
direction: Direction::Response,
} => {
sender
.send_data(Bytes::from(data.into_bytes()))
.await
.expect("this is in memory traffic that should not fail to send");
}
Action::Data {
data: _data,
direction: Direction::Request,
} => {}
Action::Eof {
direction: Direction::Request,
..
} => {}
Action::Eof {
direction: Direction::Response,
ok: true,
..
} => {
drop(sender);
break;
}
Action::Eof {
direction: Direction::Response,
ok: false,
..
} => {
sender.abort();
break;
}
}
}
}
fn convert_version(version: &str) -> Version {
match version {
"HTTP/1.1" => Version::HTTP_11,
"HTTP/2.0" => Version::HTTP_2,
_ => panic!("unsupported: {}", version),
}
}
impl HttpConnector for ReplayingConnector {
fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
let event_id = self.next_id();
tracing::debug!("received event {}: {request:?}", event_id.0);
let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
Some(traffic) => traffic,
None => {
return HttpConnectorFuture::ready(Err(ConnectorError::other(
format!("no data for event {}. request: {:?}", event_id.0, request).into(),
None,
)));
}
};
let _initial_request = events.pop_front().unwrap();
let (sender, response_body) = hyper::Body::channel();
let body = SdkBody::from(response_body);
let recording = self.recorded_requests.clone();
let recorded_request = tokio::spawn(async move {
let mut data_read = vec![];
while let Some(data) = request.body_mut().data().await {
data_read
.extend_from_slice(data.expect("in memory request should not fail").as_ref())
}
request.map(|_| Bytes::from(data_read))
});
let mut recorded_request = Waitable::Loading(recorded_request);
let fut = async move {
let resp: Result<_, ConnectorError> = loop {
let event = events
.pop_front()
.expect("no events, needed a response event");
match event.action {
// to ensure deterministic behavior if the request EOF happens first in the log,
// wait for the request body to be done before returning a response.
Action::Eof {
direction: Direction::Request,
..
} => {
recorded_request.wait().await;
}
Action::Request { .. } => panic!("invalid"),
Action::Response {
response: Err(error),
} => break Err(ConnectorError::other(error.0.into(), None)),
Action::Response {
response: Ok(response),
} => {
let mut builder = http::Response::builder()
.status(response.status)
.version(convert_version(&response.version));
for (name, values) in response.headers {
for value in values {
builder = builder.header(&name, &value);
}
}
tokio::spawn(async move {
replay_body(events, sender).await;
// insert the finalized body into
});
break Ok(builder.body(body).expect("valid builder"));
}
Action::Data {
direction: Direction::Request,
data: _data,
} => {
tracing::info!("get request data");
}
Action::Eof {
direction: Direction::Response,
..
} => panic!("got eof before response"),
Action::Data {
data: _,
direction: Direction::Response,
} => panic!("got response data before response"),
}
};
recording.lock().unwrap().insert(event_id, recorded_request);
resp
};
HttpConnectorFuture::new(fut)
}
}

View File

@ -0,0 +1,187 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_protocol_test::{assert_ok, validate_body, MediaType};
use aws_smithy_runtime_api::client::connectors::{HttpConnector, HttpConnectorFuture};
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
use http::header::{HeaderName, CONTENT_TYPE};
use std::fmt::Debug;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use std::time::Duration;
type ConnectionEvents = Vec<ConnectionEvent>;
/// Test data for the [`EventConnector`].
///
/// Each `ConnectionEvent` represents one HTTP request and response
/// through the connector. Optionally, a latency value can be set to simulate
/// network latency (done via async sleep in the `EventConnector`).
#[derive(Debug)]
pub struct ConnectionEvent {
latency: Duration,
req: HttpRequest,
res: HttpResponse,
}
impl ConnectionEvent {
/// Creates a new `ConnectionEvent`.
pub fn new(req: HttpRequest, res: HttpResponse) -> Self {
Self {
res,
req,
latency: Duration::from_secs(0),
}
}
/// Add simulated latency to this `ConnectionEvent`
pub fn with_latency(mut self, latency: Duration) -> Self {
self.latency = latency;
self
}
/// Returns the test request.
pub fn request(&self) -> &HttpRequest {
&self.req
}
/// Returns the test response.
pub fn response(&self) -> &HttpResponse {
&self.res
}
}
impl From<(HttpRequest, HttpResponse)> for ConnectionEvent {
fn from((req, res): (HttpRequest, HttpResponse)) -> Self {
Self::new(req, res)
}
}
#[derive(Debug)]
struct ValidateRequest {
expected: HttpRequest,
actual: HttpRequest,
}
impl ValidateRequest {
fn assert_matches(&self, index: usize, ignore_headers: &[HeaderName]) {
let (actual, expected) = (&self.actual, &self.expected);
assert_eq!(
actual.uri(),
expected.uri(),
"Request #{index} - URI doesn't match expected value"
);
for (name, value) in expected.headers() {
if !ignore_headers.contains(name) {
let actual_header = actual
.headers()
.get(name)
.unwrap_or_else(|| panic!("Request #{index} - Header {name:?} is missing"));
assert_eq!(
actual_header.to_str().unwrap(),
value.to_str().unwrap(),
"Request #{index} - Header {name:?} doesn't match expected value",
);
}
}
let actual_str = std::str::from_utf8(actual.body().bytes().unwrap_or(&[]));
let expected_str = std::str::from_utf8(expected.body().bytes().unwrap_or(&[]));
let media_type = if actual
.headers()
.get(CONTENT_TYPE)
.map(|v| v.to_str().unwrap().contains("json"))
.unwrap_or(false)
{
MediaType::Json
} else {
MediaType::Other("unknown".to_string())
};
match (actual_str, expected_str) {
(Ok(actual), Ok(expected)) => assert_ok(validate_body(actual, expected, media_type)),
_ => assert_eq!(
actual.body().bytes(),
expected.body().bytes(),
"Request #{index} - Body contents didn't match expected value"
),
};
}
}
/// Request/response event-driven connector for use in tests.
///
/// A basic test connection. It will:
/// - Respond to requests with a preloaded series of responses
/// - Record requests for future examination
#[derive(Debug, Clone)]
pub struct EventConnector {
data: Arc<Mutex<ConnectionEvents>>,
requests: Arc<Mutex<Vec<ValidateRequest>>>,
sleep_impl: SharedAsyncSleep,
}
impl EventConnector {
/// Creates a new event connector.
pub fn new(mut data: ConnectionEvents, sleep_impl: impl Into<SharedAsyncSleep>) -> Self {
data.reverse();
EventConnector {
data: Arc::new(Mutex::new(data)),
requests: Default::default(),
sleep_impl: sleep_impl.into(),
}
}
fn requests(&self) -> impl Deref<Target = Vec<ValidateRequest>> + '_ {
self.requests.lock().unwrap()
}
/// Asserts the expected requests match the actual requests.
///
/// The expected requests are given as the connection events when the `EventConnector`
/// is created. The `EventConnector` will record the actual requests and assert that
/// they match the expected requests.
///
/// A list of headers that should be ignored when comparing requests can be passed
/// for cases where headers are non-deterministic or are irrelevant to the test.
#[track_caller]
pub fn assert_requests_match(&self, ignore_headers: &[HeaderName]) {
for (i, req) in self.requests().iter().enumerate() {
req.assert_matches(i, ignore_headers)
}
let remaining_requests = self.data.lock().unwrap();
let number_of_remaining_requests = remaining_requests.len();
let actual_requests = self.requests().len();
assert!(
remaining_requests.is_empty(),
"Expected {number_of_remaining_requests} additional requests (only {actual_requests} sent)",
);
}
}
impl HttpConnector for EventConnector {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
let (res, simulated_latency) = if let Some(event) = self.data.lock().unwrap().pop() {
self.requests.lock().unwrap().push(ValidateRequest {
expected: event.req,
actual: request,
});
(Ok(event.res.map(SdkBody::from)), event.latency)
} else {
(
Err(ConnectorError::other("No more data".into(), None)),
Duration::from_secs(0),
)
};
let sleep = self.sleep_impl.sleep(simulated_latency);
HttpConnectorFuture::new(async move {
sleep.await;
res
})
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_runtime_api::client::connectors::{
HttpConnector, HttpConnectorFuture, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use std::fmt;
use std::sync::Arc;
/// Create a [`SharedHttpConnector`] from `Fn(http:Request) -> http::Response`
///
/// # Examples
///
/// ```rust
/// use aws_smithy_runtime::client::connectors::test_util::infallible_connection_fn;
/// let connector = infallible_connection_fn(|_req| http::Response::builder().status(200).body("OK!").unwrap());
/// ```
pub fn infallible_connection_fn<B>(
f: impl Fn(http::Request<SdkBody>) -> http::Response<B> + Send + Sync + 'static,
) -> SharedHttpConnector
where
B: Into<SdkBody>,
{
SharedHttpConnector::new(InfallibleConnectorFn::new(f))
}
#[derive(Clone)]
struct InfallibleConnectorFn {
#[allow(clippy::type_complexity)]
response: Arc<
dyn Fn(http::Request<SdkBody>) -> Result<http::Response<SdkBody>, ConnectorError>
+ Send
+ Sync,
>,
}
impl fmt::Debug for InfallibleConnectorFn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InfallibleConnectorFn").finish()
}
}
impl InfallibleConnectorFn {
fn new<B: Into<SdkBody>>(
f: impl Fn(http::Request<SdkBody>) -> http::Response<B> + Send + Sync + 'static,
) -> Self {
Self {
response: Arc::new(move |request| Ok(f(request).map(|b| b.into()))),
}
}
}
impl HttpConnector for InfallibleConnectorFn {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
HttpConnectorFuture::ready((self.response)(request))
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Test connectors that never return data
use aws_smithy_async::future::never::Never;
use aws_smithy_runtime_api::client::connectors::{HttpConnector, HttpConnectorFuture};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// A connector that will never respond.
///
/// Returned futures will return Pending forever
#[derive(Clone, Debug, Default)]
pub struct NeverConnector {
invocations: Arc<AtomicUsize>,
}
impl NeverConnector {
/// Create a new never connector.
pub fn new() -> Self {
Default::default()
}
/// Returns the number of invocations made to this connector.
pub fn num_calls(&self) -> usize {
self.invocations.load(Ordering::SeqCst)
}
}
impl HttpConnector for NeverConnector {
fn call(&self, _request: HttpRequest) -> HttpConnectorFuture {
self.invocations.fetch_add(1, Ordering::SeqCst);
HttpConnectorFuture::new(async move {
Never::new().await;
unreachable!()
})
}
}

View File

@ -356,12 +356,7 @@ async fn try_attempt(
OrchestratorError::other("No HTTP connector was available to send this request. \
Enable the `rustls` crate feature or set a connector to fix this.")
));
connector.call(request).await.map_err(|err| {
match err.downcast() {
Ok(connector_error) => OrchestratorError::connector(*connector_error),
Err(box_err) => OrchestratorError::other(box_err)
}
})
connector.call(request).await.map_err(OrchestratorError::connector)
});
trace!(response = ?response, "received response from service");
ctx.set_response(response);
@ -442,7 +437,9 @@ mod tests {
use aws_smithy_runtime_api::client::auth::{
AuthSchemeOptionResolverParams, SharedAuthSchemeOptionResolver,
};
use aws_smithy_runtime_api::client::connectors::{HttpConnector, SharedHttpConnector};
use aws_smithy_runtime_api::client::connectors::{
HttpConnector, HttpConnectorFuture, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::endpoint::{
EndpointResolverParams, SharedEndpointResolver,
};
@ -454,7 +451,7 @@ mod tests {
FinalizerInterceptorContextRef,
};
use aws_smithy_runtime_api::client::interceptors::{Interceptor, SharedInterceptor};
use aws_smithy_runtime_api::client::orchestrator::{BoxFuture, Future, HttpRequest};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, RuntimePlugins};
@ -492,11 +489,11 @@ mod tests {
}
impl HttpConnector for OkConnector {
fn call(&self, _request: HttpRequest) -> BoxFuture<HttpResponse> {
Box::pin(Future::ready(Ok(::http::Response::builder()
fn call(&self, _request: HttpRequest) -> HttpConnectorFuture {
HttpConnectorFuture::ready(Ok(::http::Response::builder()
.status(200)
.body(SdkBody::empty())
.expect("OK response is valid"))))
.expect("OK response is valid")))
}
}

View File

@ -0,0 +1,106 @@
{
"events": [
{
"connection_id": 0,
"action": {
"Request": {
"request": {
"uri": "https://www.example.com/",
"headers": {},
"method": "POST"
}
}
}
},
{
"connection_id": 0,
"action": {
"Data": {
"data": {
"Utf8": "hello world"
},
"direction": "Request"
}
}
},
{
"connection_id": 0,
"action": {
"Eof": {
"ok": true,
"direction": "Request"
}
}
},
{
"connection_id": 0,
"action": {
"Response": {
"response": {
"Ok": {
"status": 200,
"version": "HTTP/2.0",
"headers": {
"etag": [
"\"3147526947+ident\""
],
"vary": [
"Accept-Encoding"
],
"server": [
"ECS (bsa/EB20)"
],
"x-cache": [
"HIT"
],
"age": [
"355292"
],
"content-length": [
"1256"
],
"cache-control": [
"max-age=604800"
],
"expires": [
"Mon, 16 Aug 2021 18:51:30 GMT"
],
"content-type": [
"text/html; charset=UTF-8"
],
"date": [
"Mon, 09 Aug 2021 18:51:30 GMT"
],
"last-modified": [
"Thu, 17 Oct 2019 07:18:26 GMT"
]
}
}
}
}
}
},
{
"connection_id": 0,
"action": {
"Data": {
"data": {
"Utf8": "hello from example.com"
},
"direction": "Response"
}
}
},
{
"connection_id": 0,
"action": {
"Eof": {
"ok": true,
"direction": "Response"
}
}
}
],
"docs": "test of example.com. response body has been manually changed",
"version": "V0"
}