diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 07e4f95af..8407592ed 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -241,3 +241,15 @@ message = "The `futures_core::stream::Stream` trait has been removed from [`Byte references = ["smithy-rs#2983"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } author = "ysaito1001" + +[[smithy-rs]] +message = "`StaticUriEndpointResolver`'s `uri` constructor now takes a `String` instead of a `Uri`." +references = ["smithy-rs#2997"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } +author = "jdisanti" + +[[aws-sdk-rust]] +message = "The IMDS Client builder's `build()` method is no longer async." +references = ["smithy-rs#2997"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" diff --git a/aws/rust-runtime/aws-config/Cargo.toml b/aws/rust-runtime/aws-config/Cargo.toml index 60fb0f526..2e9d00f08 100644 --- a/aws/rust-runtime/aws-config/Cargo.toml +++ b/aws/rust-runtime/aws-config/Cargo.toml @@ -9,7 +9,7 @@ license = "Apache-2.0" repository = "https://github.com/awslabs/smithy-rs" [features] -client-hyper = ["aws-smithy-client/client-hyper"] +client-hyper = ["aws-smithy-client/client-hyper", "aws-smithy-runtime/connector-hyper"] rustls = ["aws-smithy-client/rustls", "client-hyper"] native-tls = [] allow-compilation = [] # our tests use `cargo test --all-features` and native-tls breaks CI @@ -27,7 +27,10 @@ aws-smithy-client = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-client", de aws-smithy-http = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http" } aws-smithy-http-tower = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http-tower" } aws-smithy-json = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-json" } +aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client"] } +aws-smithy-runtime-api = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime-api", features = ["client"] } aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" } +aws-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-runtime" } aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" } hyper = { version = "0.14.26", default-features = false } time = { version = "0.3.4", features = ["parsing"] } @@ -48,6 +51,7 @@ hex = { version = "0.4.3", optional = true } zeroize = { version = "1", optional = true } [dev-dependencies] +aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client", "test-util"] } futures-util = { version = "0.3.16", default-features = false } tracing-test = "0.2.1" tracing-subscriber = { version = "0.3.16", features = ["fmt", "json"] } diff --git a/aws/rust-runtime/aws-config/examples/imds.rs b/aws/rust-runtime/aws-config/examples/imds.rs index 3722835e5..b1da23336 100644 --- a/aws/rust-runtime/aws-config/examples/imds.rs +++ b/aws/rust-runtime/aws-config/examples/imds.rs @@ -12,8 +12,8 @@ async fn main() -> Result<(), Box> { use aws_config::imds::Client; - let imds = Client::builder().build().await?; + let imds = Client::builder().build(); let instance_id = imds.get("/latest/meta-data/instance-id").await?; - println!("current instance id: {}", instance_id); + println!("current instance id: {}", instance_id.as_ref()); Ok(()) } diff --git a/aws/rust-runtime/aws-config/src/ecs.rs b/aws/rust-runtime/aws-config/src/ecs.rs index 17fab949f..f5ca10e54 100644 --- a/aws/rust-runtime/aws-config/src/ecs.rs +++ b/aws/rust-runtime/aws-config/src/ecs.rs @@ -55,7 +55,7 @@ use aws_credential_types::provider::{self, error::CredentialsError, future, Prov use aws_smithy_client::erase::boxclone::BoxCloneService; use aws_smithy_http::endpoint::apply_endpoint; use aws_smithy_types::error::display::DisplayErrorContext; -use http::uri::{InvalidUri, Scheme}; +use http::uri::{InvalidUri, PathAndQuery, Scheme}; use http::{HeaderValue, Uri}; use tower::{Service, ServiceExt}; @@ -166,6 +166,15 @@ impl Provider { Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured, Err(err) => return Provider::InvalidConfiguration(err), }; + let path = uri.path().to_string(); + let endpoint = { + let mut parts = uri.into_parts(); + parts.path_and_query = Some(PathAndQuery::from_static("/")); + Uri::from_parts(parts) + } + .expect("parts will be valid") + .to_string(); + let http_provider = HttpCredentialProvider::builder() .configure(&provider_config) .connector_settings( @@ -174,7 +183,7 @@ impl Provider { .read_timeout(DEFAULT_READ_TIMEOUT) .build(), ) - .build("EcsContainer", uri); + .build("EcsContainer", &endpoint, path); Provider::Configured(http_provider) } diff --git a/aws/rust-runtime/aws-config/src/http_credential_provider.rs b/aws/rust-runtime/aws-config/src/http_credential_provider.rs index 2568cc435..87950ea17 100644 --- a/aws/rust-runtime/aws-config/src/http_credential_provider.rs +++ b/aws/rust-runtime/aws-config/src/http_credential_provider.rs @@ -8,35 +8,43 @@ //! //! Future work will stabilize this interface and enable it to be used directly. -use aws_credential_types::provider::{self, error::CredentialsError}; -use aws_credential_types::Credentials; -use aws_smithy_client::erase::DynConnector; -use aws_smithy_client::http_connector::ConnectorSettings; -use aws_smithy_http::body::SdkBody; -use aws_smithy_http::operation::{Operation, Request}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http::result::{SdkError, SdkSuccess}; -use aws_smithy_http::retry::ClassifyRetry; -use aws_smithy_types::retry::{ErrorKind, RetryKind}; - use crate::connector::expect_connector; use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials}; use crate::provider_config::ProviderConfig; - -use bytes::Bytes; +use aws_credential_types::provider::{self, error::CredentialsError}; +use aws_credential_types::Credentials; +use aws_smithy_client::http_connector::ConnectorSettings; +use aws_smithy_http::body::SdkBody; +use aws_smithy_http::result::SdkError; +use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime::client::retries::classifier::{ + HttpStatusCodeClassifier, SmithyErrorClassifier, +}; +use aws_smithy_runtime_api::client::connectors::SharedHttpConnector; +use aws_smithy_runtime_api::client::interceptors::context::{Error, InterceptorContext}; +use aws_smithy_runtime_api::client::orchestrator::{ + HttpResponse, OrchestratorError, SensitiveOutput, +}; +use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryClassifiers, RetryReason}; +use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin; +use aws_smithy_types::config_bag::Layer; +use aws_smithy_types::retry::{ErrorKind, RetryConfig}; use http::header::{ACCEPT, AUTHORIZATION}; -use http::{HeaderValue, Response, Uri}; +use http::{HeaderValue, Response}; use std::time::Duration; -use tower::layer::util::Identity; const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2); +#[derive(Debug)] +struct HttpProviderAuth { + auth: Option, +} + #[derive(Debug)] pub(crate) struct HttpCredentialProvider { - uri: Uri, - client: aws_smithy_client::Client, - provider_name: &'static str, + operation: Operation, } impl HttpCredentialProvider { @@ -45,34 +53,13 @@ impl HttpCredentialProvider { } pub(crate) async fn credentials(&self, auth: Option) -> provider::Result { - let credentials = self.client.call(self.operation(auth)).await; + let credentials = self.operation.invoke(HttpProviderAuth { auth }).await; match credentials { Ok(creds) => Ok(creds), Err(SdkError::ServiceError(context)) => Err(context.into_err()), Err(other) => Err(CredentialsError::unhandled(other)), } } - - fn operation( - &self, - auth: Option, - ) -> Operation { - let mut http_req = http::Request::builder() - .uri(&self.uri) - .header(ACCEPT, "application/json"); - - if let Some(auth) = auth { - http_req = http_req.header(AUTHORIZATION, auth); - } - let http_req = http_req.body(SdkBody::empty()).expect("valid request"); - Operation::new( - Request::new(http_req), - CredentialsResponseParser { - provider_name: self.provider_name, - }, - ) - .with_retry_classifier(HttpCredentialRetryClassifier) - } } #[derive(Default)] @@ -92,7 +79,12 @@ impl Builder { self } - pub(crate) fn build(self, provider_name: &'static str, uri: Uri) -> HttpCredentialProvider { + pub(crate) fn build( + self, + provider_name: &'static str, + endpoint: &str, + path: impl Into, + ) -> HttpCredentialProvider { let provider_config = self.provider_config.unwrap_or_default(); let connector_settings = self.connector_settings.unwrap_or_else(|| { ConnectorSettings::builder() @@ -104,198 +96,241 @@ impl Builder { "The HTTP credentials provider", provider_config.connector(&connector_settings), ); - let mut client_builder = aws_smithy_client::Client::builder() - .connector(connector) - .middleware(Identity::new()); - client_builder.set_sleep_impl(provider_config.sleep()); - let client = client_builder.build(); - HttpCredentialProvider { - uri, - client, - provider_name, + + // The following errors are retryable: + // - Socket errors + // - Networking timeouts + // - 5xx errors + // - Non-parseable 200 responses. + let retry_classifiers = RetryClassifiers::new() + .with_classifier(HttpCredentialRetryClassifier) + // Socket errors and network timeouts + .with_classifier(SmithyErrorClassifier::::new()) + // 5xx errors + .with_classifier(HttpStatusCodeClassifier::default()); + + let mut builder = Operation::builder() + .service_name("HttpCredentialProvider") + .operation_name("LoadCredentials") + .http_connector(SharedHttpConnector::new(DynConnectorAdapter::new( + connector, + ))) + .endpoint_url(endpoint) + .no_auth() + .runtime_plugin(StaticRuntimePlugin::new().with_config({ + let mut layer = Layer::new("SensitiveOutput"); + layer.store_put(SensitiveOutput); + layer.freeze() + })); + if let Some(sleep_impl) = provider_config.sleep() { + builder = builder + .standard_retry(&RetryConfig::standard()) + .retry_classifiers(retry_classifiers) + .sleep_impl(sleep_impl); + } else { + builder = builder.no_retry(); } + let path = path.into(); + let operation = builder + .serializer(move |input: HttpProviderAuth| { + let mut http_req = http::Request::builder() + .uri(path.clone()) + .header(ACCEPT, "application/json"); + if let Some(auth) = input.auth { + http_req = http_req.header(AUTHORIZATION, auth); + } + Ok(http_req.body(SdkBody::empty()).expect("valid request")) + }) + .deserializer(move |response| parse_response(provider_name, response)) + .build(); + HttpCredentialProvider { operation } } } -#[derive(Clone, Debug)] -struct CredentialsResponseParser { +fn parse_response( provider_name: &'static str, -} -impl ParseStrictResponse for CredentialsResponseParser { - type Output = provider::Result; - - fn parse(&self, response: &Response) -> Self::Output { - if !response.status().is_success() { - return Err(CredentialsError::provider_error(format!( + response: &Response, +) -> Result> { + if !response.status().is_success() { + return Err(OrchestratorError::operation( + CredentialsError::provider_error(format!( "Non-success status from HTTP credential provider: {:?}", response.status() - ))); - } - let str_resp = - std::str::from_utf8(response.body().as_ref()).map_err(CredentialsError::unhandled)?; - let json_creds = parse_json_credentials(str_resp).map_err(CredentialsError::unhandled)?; - match json_creds { - JsonCredentials::RefreshableCredentials(RefreshableCredentials { - access_key_id, - secret_access_key, - session_token, - expiration, - }) => Ok(Credentials::new( - access_key_id, - secret_access_key, - Some(session_token.to_string()), - Some(expiration), - self.provider_name, )), - JsonCredentials::Error { code, message } => Err(CredentialsError::provider_error( - format!("failed to load credentials [{}]: {}", code, message), - )), - } + )); } - - fn sensitive(&self) -> bool { - true + let resp_bytes = response.body().bytes().expect("non-streaming deserializer"); + let str_resp = std::str::from_utf8(resp_bytes) + .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?; + let json_creds = parse_json_credentials(str_resp) + .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?; + match json_creds { + JsonCredentials::RefreshableCredentials(RefreshableCredentials { + access_key_id, + secret_access_key, + session_token, + expiration, + }) => Ok(Credentials::new( + access_key_id, + secret_access_key, + Some(session_token.to_string()), + Some(expiration), + provider_name, + )), + JsonCredentials::Error { code, message } => Err(OrchestratorError::operation( + CredentialsError::provider_error(format!( + "failed to load credentials [{}]: {}", + code, message + )), + )), } } #[derive(Clone, Debug)] struct HttpCredentialRetryClassifier; -impl ClassifyRetry, SdkError> - for HttpCredentialRetryClassifier -{ - fn classify_retry( - &self, - response: Result<&SdkSuccess, &SdkError>, - ) -> RetryKind { - /* The following errors are retryable: - * - Socket errors - * - Networking timeouts - * - 5xx errors - * - Non-parseable 200 responses. - * */ - match response { - Ok(_) => RetryKind::Unnecessary, - // socket errors, networking timeouts - Err(SdkError::DispatchFailure(client_err)) - if client_err.is_timeout() || client_err.is_io() => - { - RetryKind::Error(ErrorKind::TransientError) +impl ClassifyRetry for HttpCredentialRetryClassifier { + fn name(&self) -> &'static str { + "HttpCredentialRetryClassifier" + } + + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { + let output_or_error = ctx.output_or_error()?; + let error = match output_or_error { + Ok(_) => return None, + Err(err) => err, + }; + + // Retry non-parseable 200 responses + if let Some((err, status)) = error + .as_operation_error() + .and_then(|err| err.downcast_ref::()) + .zip(ctx.response().map(HttpResponse::status)) + { + if matches!(err, CredentialsError::Unhandled { .. }) && status.is_success() { + return Some(RetryReason::Error(ErrorKind::ServerError)); } - // non-parseable 200s - Err(SdkError::ServiceError(context)) - if matches!(context.err(), CredentialsError::Unhandled { .. }) - && context.raw().http().status().is_success() => - { - RetryKind::Error(ErrorKind::ServerError) - } - // 5xx errors - Err(SdkError::ResponseError(context)) - if context.raw().http().status().is_server_error() => - { - RetryKind::Error(ErrorKind::ServerError) - } - Err(SdkError::ServiceError(context)) - if context.raw().http().status().is_server_error() => - { - RetryKind::Error(ErrorKind::ServerError) - } - Err(_) => RetryKind::UnretryableFailure, } + + None } } #[cfg(test)] mod test { - use crate::http_credential_provider::{ - CredentialsResponseParser, HttpCredentialRetryClassifier, - }; + use super::*; use aws_credential_types::provider::error::CredentialsError; - use aws_credential_types::Credentials; + use aws_smithy_client::test_connection::TestConnection; use aws_smithy_http::body::SdkBody; - use aws_smithy_http::operation; - use aws_smithy_http::response::ParseStrictResponse; - use aws_smithy_http::result::{SdkError, SdkSuccess}; - use aws_smithy_http::retry::ClassifyRetry; - use aws_smithy_types::retry::{ErrorKind, RetryKind}; - use bytes::Bytes; + use aws_smithy_runtime_api::client::orchestrator::HttpRequest; + use http::{Request, Response, Uri}; + use std::time::SystemTime; - fn sdk_resp( - resp: http::Response<&'static str>, - ) -> Result, SdkError> { - let resp = resp.map(|data| Bytes::from_static(data.as_bytes())); - match (CredentialsResponseParser { - provider_name: "test", - }) - .parse(&resp) - { - Ok(creds) => Ok(SdkSuccess { - raw: operation::Response::new(resp.map(SdkBody::from)), - parsed: creds, - }), - Err(err) => Err(SdkError::service_error( - err, - operation::Response::new(resp.map(SdkBody::from)), - )), - } + async fn provide_creds( + connector: TestConnection, + ) -> Result { + let provider_config = ProviderConfig::default().with_http_connector(connector.clone()); + let provider = HttpCredentialProvider::builder() + .configure(&provider_config) + .build("test", "http://localhost:1234/", "/some-creds"); + provider.credentials(None).await } - #[test] - fn non_parseable_is_retriable() { - let bad_response = http::Response::builder() - .status(200) - .body("notjson") - .unwrap(); - - assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_resp(bad_response).as_ref()), - RetryKind::Error(ErrorKind::ServerError) - ); + fn successful_req_resp() -> (HttpRequest, HttpResponse) { + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(200) + .body(SdkBody::from( + r#"{ + "AccessKeyId" : "MUA...", + "SecretAccessKey" : "/7PC5om....", + "Token" : "AQoDY....=", + "Expiration" : "2016-02-25T06:03:31Z" + }"#, + )) + .unwrap(), + ) } - #[test] - fn ok_response_not_retriable() { - let ok_response = http::Response::builder() - .status(200) - .body( - r#" { - "AccessKeyId" : "MUA...", - "SecretAccessKey" : "/7PC5om....", - "Token" : "AQoDY....=", - "Expiration" : "2016-02-25T06:03:31Z" - }"#, - ) - .unwrap(); - let sdk_result = sdk_resp(ok_response); - + #[tokio::test] + async fn successful_response() { + let connector = TestConnection::new(vec![successful_req_resp()]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + assert_eq!("/7PC5om....", creds.secret_access_key()); + assert_eq!(Some("AQoDY....="), creds.session_token()); assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()), - RetryKind::Unnecessary + Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1456380211)), + creds.expiry() ); - - assert!(sdk_result.is_ok(), "should be ok: {:?}", sdk_result) + connector.assert_requests_match(&[]); } - #[test] - fn explicit_error_not_retriable() { - let error_response = http::Response::builder() - .status(400) - .body(r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#) - .unwrap(); - let sdk_result = sdk_resp(error_response); - assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()), - RetryKind::UnretryableFailure - ); - let sdk_error = sdk_result.expect_err("should be error"); - - assert!( - matches!( - sdk_error, - SdkError::ServiceError(ref context) if matches!(context.err(), CredentialsError::ProviderError { .. }) + #[tokio::test] + async fn retry_nonparseable_response() { + let connector = TestConnection::new(vec![ + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(200) + .body(SdkBody::from(r#"not json"#)) + .unwrap(), ), - "should be provider error: {}", - sdk_error + successful_req_resp(), + ]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + connector.assert_requests_match(&[]); + } + + #[tokio::test] + async fn retry_error_code() { + let connector = TestConnection::new(vec![ + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(500) + .body(SdkBody::from(r#"it broke"#)) + .unwrap(), + ), + successful_req_resp(), + ]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + connector.assert_requests_match(&[]); + } + + #[tokio::test] + async fn explicit_error_not_retriable() { + let connector = TestConnection::new(vec![( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(400) + .body(SdkBody::from( + r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#, + )) + .unwrap(), + )]); + let err = provide_creds(connector.clone()) + .await + .expect_err("it should fail"); + assert!( + matches!(err, CredentialsError::ProviderError { .. }), + "should be CredentialsError::ProviderError: {err}", ); + connector.assert_requests_match(&[]); } } diff --git a/aws/rust-runtime/aws-config/src/imds/client.rs b/aws/rust-runtime/aws-config/src/imds/client.rs index f084bc9a8..ed76fc866 100644 --- a/aws/rust-runtime/aws-config/src/imds/client.rs +++ b/aws/rust-runtime/aws-config/src/imds/client.rs @@ -9,34 +9,48 @@ use crate::connector::expect_connector; use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode}; -use crate::imds::client::token::TokenMiddleware; +use crate::imds::client::token::TokenRuntimePlugin; use crate::provider_config::ProviderConfig; use crate::PKG_VERSION; -use aws_http::user_agent::{ApiMetadata, AwsUserAgent, UserAgentStage}; +use aws_http::user_agent::{ApiMetadata, AwsUserAgent}; +use aws_runtime::user_agent::UserAgentInterceptor; +use aws_smithy_async::rt::sleep::SharedAsyncSleep; +use aws_smithy_async::time::SharedTimeSource; +use aws_smithy_client::erase::DynConnector; use aws_smithy_client::http_connector::ConnectorSettings; -use aws_smithy_client::{erase::DynConnector, SdkSuccess}; -use aws_smithy_client::{retry, SdkError}; +use aws_smithy_client::SdkError; use aws_smithy_http::body::SdkBody; -use aws_smithy_http::endpoint::apply_endpoint; -use aws_smithy_http::operation; -use aws_smithy_http::operation::{Metadata, Operation}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http::retry::ClassifyRetry; -use aws_smithy_http_tower::map_request::{ - AsyncMapRequestLayer, AsyncMapRequestService, MapRequestLayer, MapRequestService, +use aws_smithy_http::result::ConnectorError; +use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy; +use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams; +use aws_smithy_runtime_api::client::connectors::SharedHttpConnector; +use aws_smithy_runtime_api::client::endpoint::{ + EndpointResolver, EndpointResolverParams, SharedEndpointResolver, }; -use aws_smithy_types::error::display::DisplayErrorContext; -use aws_smithy_types::retry::{ErrorKind, RetryKind}; +use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext; +use aws_smithy_runtime_api::client::interceptors::SharedInterceptor; +use aws_smithy_runtime_api::client::orchestrator::{ + Future, HttpResponse, OrchestratorError, SensitiveOutput, +}; +use aws_smithy_runtime_api::client::retries::{ + ClassifyRetry, RetryClassifiers, RetryReason, SharedRetryStrategy, +}; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder; +use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin}; +use aws_smithy_types::config_bag::{FrozenLayer, Layer}; +use aws_smithy_types::endpoint::Endpoint; +use aws_smithy_types::retry::{ErrorKind, RetryConfig}; use aws_smithy_types::timeout::TimeoutConfig; use aws_types::os_shim_internal::Env; -use bytes::Bytes; -use http::{Response, Uri}; +use http::Uri; use std::borrow::Cow; -use std::error::Error; +use std::error::Error as _; +use std::fmt; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::sync::OnceCell; pub mod error; mod token; @@ -85,8 +99,7 @@ fn user_agent() -> AwsUserAgent { /// # async fn docs() { /// let client = Client::builder() /// .endpoint(Uri::from_static("http://customidms:456/")) -/// .build() -/// .await; +/// .build(); /// # } /// ``` /// @@ -104,7 +117,7 @@ fn user_agent() -> AwsUserAgent { /// ```no_run /// use aws_config::imds::client::{Client, EndpointMode}; /// # async fn docs() { -/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build().await; +/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build(); /// # } /// ``` /// @@ -123,49 +136,7 @@ fn user_agent() -> AwsUserAgent { /// #[derive(Clone, Debug)] pub struct Client { - inner: Arc, -} - -#[derive(Debug)] -struct ClientInner { - endpoint: Uri, - smithy_client: aws_smithy_client::Client, -} - -/// Client where build is sync, but usage is async -/// -/// Building an imds::Client is actually an async operation, however, for credentials and region -/// providers, we want build to always be a synchronous operation. This allows building to be deferred -/// and cached until request time. -#[derive(Debug)] -pub(super) struct LazyClient { - client: OnceCell>, - builder: Builder, -} - -impl LazyClient { - pub(super) fn from_ready_client(client: Client) -> Self { - Self { - client: OnceCell::from(Ok(client)), - // the builder will never be used in this case - builder: Builder::default(), - } - } - pub(super) async fn client(&self) -> Result<&Client, &BuildError> { - let builder = &self.builder; - self.client - // the clone will only happen once when we actually construct it for the first time, - // after that, we will use the cache. - .get_or_init(|| async { - let client = builder.clone().build().await; - if let Err(err) = &client { - tracing::warn!(err = %DisplayErrorContext(err), "failed to create IMDS client") - } - client - }) - .await - .as_ref() - } + operation: Operation, } impl Client { @@ -187,18 +158,16 @@ impl Client { /// ```no_run /// use aws_config::imds::client::Client; /// # async fn docs() { - /// let client = Client::builder().build().await.expect("valid client"); + /// let client = Client::builder().build(); /// let ami_id = client /// .get("/latest/meta-data/ami-id") /// .await /// .expect("failure communicating with IMDS"); /// # } /// ``` - pub async fn get(&self, path: &str) -> Result { - let operation = self.make_operation(path)?; - self.inner - .smithy_client - .call(operation) + pub async fn get(&self, path: impl Into) -> Result { + self.operation + .invoke(path.into()) .await .map_err(|err| match err { SdkError::ConstructionFailure(_) if err.source().is_some() => { @@ -213,76 +182,112 @@ impl Client { InnerImdsError::InvalidUtf8 => { ImdsError::unexpected("IMDS returned invalid UTF-8") } - InnerImdsError::BadStatus => { - ImdsError::error_response(context.into_raw().into_parts().0) - } + InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()), }, - SdkError::TimeoutError(_) - | SdkError::DispatchFailure(_) - | SdkError::ResponseError(_) => ImdsError::io_error(err), + // If the error source is an ImdsError, then we need to directly return that source. + // That way, the IMDS token provider's errors can become the top-level ImdsError. + // There is a unit test that checks the correct error is being extracted. + err @ SdkError::DispatchFailure(_) => match err.into_source() { + Ok(source) => match source.downcast::() { + Ok(source) => match source.into_source().downcast::() { + Ok(source) => *source, + Err(err) => ImdsError::unexpected(err), + }, + Err(err) => ImdsError::unexpected(err), + }, + Err(err) => ImdsError::unexpected(err), + }, + SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err), _ => ImdsError::unexpected(err), }) } +} - /// Creates a aws_smithy_http Operation to for `path` - /// - Convert the path to a URI - /// - Set the base endpoint on the URI - /// - Add a user agent - fn make_operation( - &self, - path: &str, - ) -> Result, ImdsError> { - let mut base_uri: Uri = path.parse().map_err(|_| { - ImdsError::unexpected("IMDS path was not a valid URI. Hint: does it begin with `/`?") - })?; - apply_endpoint(&mut base_uri, &self.inner.endpoint, None).map_err(ImdsError::unexpected)?; - let request = http::Request::builder() - .uri(base_uri) - .body(SdkBody::empty()) - .expect("valid request"); - let mut request = operation::Request::new(request); - request.properties_mut().insert(user_agent()); - Ok(Operation::new(request, ImdsGetResponseHandler) - .with_metadata(Metadata::new("get", "imds")) - .with_retry_classifier(ImdsResponseRetryClassifier)) +/// New-type around `String` that doesn't emit the string value in the `Debug` impl. +#[derive(Clone)] +pub struct SensitiveString(String); + +impl fmt::Debug for SensitiveString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("SensitiveString") + .field(&"** redacted **") + .finish() } } -/// IMDS Middleware -/// -/// The IMDS middleware includes a token-loader & a UserAgent stage -#[derive(Clone, Debug)] -struct ImdsMiddleware { - token_loader: TokenMiddleware, -} - -impl tower::Layer for ImdsMiddleware { - type Service = AsyncMapRequestService, TokenMiddleware>; - - fn layer(&self, inner: S) -> Self::Service { - AsyncMapRequestLayer::for_mapper(self.token_loader.clone()) - .layer(MapRequestLayer::for_mapper(UserAgentStage::new()).layer(inner)) +impl AsRef for SensitiveString { + fn as_ref(&self) -> &str { + &self.0 } } -#[derive(Copy, Clone)] -struct ImdsGetResponseHandler; +impl From for SensitiveString { + fn from(value: String) -> Self { + Self(value) + } +} -impl ParseStrictResponse for ImdsGetResponseHandler { - type Output = Result; +impl From for String { + fn from(value: SensitiveString) -> Self { + value.0 + } +} - fn parse(&self, response: &Response) -> Self::Output { - if response.status().is_success() { - std::str::from_utf8(response.body().as_ref()) - .map(|data| data.to_string()) - .map_err(|_| InnerImdsError::InvalidUtf8) - } else { - Err(InnerImdsError::BadStatus) +/// Runtime plugin that is used by both the IMDS client and the inner client that resolves +/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as +/// sensitive, configures user agent headers, and sets up retries and timeouts. +#[derive(Debug)] +struct ImdsCommonRuntimePlugin { + config: FrozenLayer, + components: RuntimeComponentsBuilder, +} + +impl ImdsCommonRuntimePlugin { + fn new( + connector: DynConnector, + endpoint_resolver: ImdsEndpointResolver, + retry_config: &RetryConfig, + timeout_config: TimeoutConfig, + time_source: SharedTimeSource, + sleep_impl: Option, + ) -> Self { + let mut layer = Layer::new("ImdsCommonRuntimePlugin"); + layer.store_put(AuthSchemeOptionResolverParams::new(())); + layer.store_put(EndpointResolverParams::new(())); + layer.store_put(SensitiveOutput); + layer.store_put(timeout_config); + layer.store_put(user_agent()); + + Self { + config: layer.freeze(), + components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin") + .with_http_connector(Some(SharedHttpConnector::new(DynConnectorAdapter::new( + connector, + )))) + .with_endpoint_resolver(Some(SharedEndpointResolver::new(endpoint_resolver))) + .with_interceptor(SharedInterceptor::new(UserAgentInterceptor::new())) + .with_retry_classifiers(Some( + RetryClassifiers::new().with_classifier(ImdsResponseRetryClassifier), + )) + .with_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new( + retry_config, + )))) + .with_time_source(Some(time_source)) + .with_sleep_impl(sleep_impl), } } +} - fn sensitive(&self) -> bool { - true +impl RuntimePlugin for ImdsCommonRuntimePlugin { + fn config(&self) -> Option { + Some(self.config.clone()) + } + + fn runtime_components( + &self, + _current_components: &RuntimeComponentsBuilder, + ) -> Cow<'_, RuntimeComponentsBuilder> { + Cow::Borrowed(&self.components) } } @@ -415,15 +420,8 @@ impl Builder { self }*/ - pub(super) fn build_lazy(self) -> LazyClient { - LazyClient { - client: OnceCell::new(), - builder: self, - } - } - /// Build an IMDSv2 Client - pub async fn build(self) -> Result { + pub fn build(self) -> Client { let config = self.config.unwrap_or_default(); let timeout_config = TimeoutConfig::builder() .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT)) @@ -437,34 +435,46 @@ impl Builder { let endpoint_source = self .endpoint .unwrap_or_else(|| EndpointSource::Env(config.clone())); - let endpoint = endpoint_source.endpoint(self.mode_override).await?; - let retry_config = retry::Config::default() - .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS)); - let token_loader = token::TokenMiddleware::new( - connector.clone(), - config.time_source(), - endpoint.clone(), - self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL), - retry_config.clone(), - timeout_config.clone(), - config.sleep(), - ); - let middleware = ImdsMiddleware { token_loader }; - let mut smithy_builder = aws_smithy_client::Client::builder() - .connector(connector.clone()) - .middleware(middleware) - .retry_config(retry_config) - .operation_timeout_config(timeout_config.into()); - smithy_builder.set_sleep_impl(config.sleep()); - let smithy_client = smithy_builder.build(); - - let client = Client { - inner: Arc::new(ClientInner { - endpoint, - smithy_client, - }), + let endpoint_resolver = ImdsEndpointResolver { + endpoint_source: Arc::new(endpoint_source), + mode_override: self.mode_override, }; - Ok(client) + let retry_config = RetryConfig::standard() + .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS)); + let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new( + connector, + endpoint_resolver, + &retry_config, + timeout_config, + config.time_source(), + config.sleep(), + )); + let operation = Operation::builder() + .service_name("imds") + .operation_name("get") + .runtime_plugin(common_plugin.clone()) + .runtime_plugin(TokenRuntimePlugin::new( + common_plugin, + config.time_source(), + self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL), + )) + .serializer(|path| { + Ok(http::Request::builder() + .uri(path) + .body(SdkBody::empty()) + .expect("valid request")) + }) + .deserializer(|response| { + if response.status().is_success() { + std::str::from_utf8(response.body().bytes().expect("non-streaming response")) + .map(|data| SensitiveString::from(data.to_string())) + .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8)) + } else { + Err(OrchestratorError::operation(InnerImdsError::BadStatus)) + } + }) + .build(); + Client { operation } } } @@ -531,19 +541,22 @@ impl EndpointSource { } } -#[derive(Clone)] -struct ImdsResponseRetryClassifier; +#[derive(Clone, Debug)] +struct ImdsEndpointResolver { + endpoint_source: Arc, + mode_override: Option, +} -impl ImdsResponseRetryClassifier { - fn classify(response: &operation::Response) -> RetryKind { - let status = response.http().status(); - match status { - _ if status.is_server_error() => RetryKind::Error(ErrorKind::ServerError), - // 401 indicates that the token has expired, this is retryable - _ if status.as_u16() == 401 => RetryKind::Error(ErrorKind::ServerError), - // This catch-all includes successful responses that fail to parse. These should not be retried. - _ => RetryKind::UnretryableFailure, - } +impl EndpointResolver for ImdsEndpointResolver { + fn resolve_endpoint(&self, _: &EndpointResolverParams) -> Future { + let this = self.clone(); + Future::new(Box::pin(async move { + this.endpoint_source + .endpoint(this.mode_override) + .await + .map(|uri| Endpoint::builder().url(uri.to_string()).build()) + .map_err(|err| err.into()) + })) } } @@ -556,13 +569,35 @@ impl ImdsResponseRetryClassifier { /// - 403 (IMDS disabled): **Not Retryable** /// - 404 (Not found): **Not Retryable** /// - >=500 (server error): **Retryable** -impl ClassifyRetry, SdkError> for ImdsResponseRetryClassifier { - fn classify_retry(&self, response: Result<&SdkSuccess, &SdkError>) -> RetryKind { - match response { - Ok(_) => RetryKind::Unnecessary, - Err(SdkError::ResponseError(context)) => Self::classify(context.raw()), - Err(SdkError::ServiceError(context)) => Self::classify(context.raw()), - _ => RetryKind::UnretryableFailure, +#[derive(Clone, Debug)] +struct ImdsResponseRetryClassifier; + +impl ImdsResponseRetryClassifier { + fn classify(response: &HttpResponse) -> Option { + let status = response.status(); + match status { + _ if status.is_server_error() => Some(RetryReason::Error(ErrorKind::ServerError)), + // 401 indicates that the token has expired, this is retryable + _ if status.as_u16() == 401 => Some(RetryReason::Error(ErrorKind::ServerError)), + // This catch-all includes successful responses that fail to parse. These should not be retried. + _ => None, + } + } +} + +impl ClassifyRetry for ImdsResponseRetryClassifier { + fn name(&self) -> &'static str { + "ImdsResponseRetryClassifier" + } + + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { + if let Some(response) = ctx.response() { + Self::classify(response) + } else { + // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default + // credentials provider chain to fail to provide credentials. + // Also don't retry non-responses. + None } } } @@ -575,10 +610,15 @@ pub(crate) mod test { use aws_smithy_async::test_util::instant_time_and_sleep; use aws_smithy_client::erase::DynConnector; use aws_smithy_client::test_connection::{capture_request, TestConnection}; - use aws_smithy_client::{SdkError, SdkSuccess}; use aws_smithy_http::body::SdkBody; - use aws_smithy_http::operation; - use aws_smithy_types::retry::RetryKind; + use aws_smithy_http::result::ConnectorError; + use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; + use aws_smithy_runtime_api::client::interceptors::context::{ + Input, InterceptorContext, Output, + }; + use aws_smithy_runtime_api::client::orchestrator::OrchestratorError; + use aws_smithy_runtime_api::client::retries::ClassifyRetry; + use aws_smithy_types::error::display::DisplayErrorContext; use aws_types::os_shim_internal::{Env, Fs}; use http::header::USER_AGENT; use http::Uri; @@ -637,7 +677,7 @@ pub(crate) mod test { http::Response::builder().status(200).body(body).unwrap() } - pub(crate) async fn make_client(conn: &TestConnection) -> super::Client + pub(crate) fn make_client(conn: &TestConnection) -> super::Client where SdkBody: From, T: Send + 'static, @@ -650,8 +690,6 @@ pub(crate) mod test { .with_http_connector(DynConnector::new(conn.clone())), ) .build() - .await - .expect("valid client") } #[tokio::test] @@ -670,13 +708,13 @@ pub(crate) mod test { imds_response("output2"), ), ]); - let client = make_client(&connection).await; + let client = make_client(&connection); // load once let metadata = client.get("/latest/metadata").await.expect("failed"); - assert_eq!(metadata, "test-imds-output"); + assert_eq!("test-imds-output", metadata.as_ref()); // load again: the cached token should be used let metadata = client.get("/latest/metadata2").await.expect("failed"); - assert_eq!(metadata, "output2"); + assert_eq!("output2", metadata.as_ref()); connection.assert_requests_match(&[]); } @@ -710,17 +748,15 @@ pub(crate) mod test { ) .endpoint_mode(EndpointMode::IpV6) .token_ttl(Duration::from_secs(600)) - .build() - .await - .expect("valid client"); + .build(); let resp1 = client.get("/latest/metadata").await.expect("success"); // now the cached credential has expired time_source.advance(Duration::from_secs(600)); let resp2 = client.get("/latest/metadata").await.expect("success"); connection.assert_requests_match(&[]); - assert_eq!(resp1, "test-imds-output1"); - assert_eq!(resp2, "test-imds-output2"); + assert_eq!("test-imds-output1", resp1.as_ref()); + assert_eq!("test-imds-output2", resp2.as_ref()); } /// Tokens are refreshed up to 120 seconds early to avoid using an expired token. @@ -761,9 +797,7 @@ pub(crate) mod test { ) .endpoint_mode(EndpointMode::IpV6) .token_ttl(Duration::from_secs(600)) - .build() - .await - .expect("valid client"); + .build(); let resp1 = client.get("/latest/metadata").await.expect("success"); // now the cached credential has expired @@ -772,9 +806,9 @@ pub(crate) mod test { time_source.advance(Duration::from_secs(150)); let resp3 = client.get("/latest/metadata").await.expect("success"); connection.assert_requests_match(&[]); - assert_eq!(resp1, "test-imds-output1"); - assert_eq!(resp2, "test-imds-output2"); - assert_eq!(resp3, "test-imds-output3"); + assert_eq!("test-imds-output1", resp1.as_ref()); + assert_eq!("test-imds-output2", resp2.as_ref()); + assert_eq!("test-imds-output3", resp3.as_ref()); } /// 500 error during the GET should be retried @@ -795,8 +829,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); // all requests should have a user agent header @@ -823,8 +864,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); } @@ -850,8 +898,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); } @@ -863,7 +918,7 @@ pub(crate) mod test { token_request("http://169.254.169.254", 21600), http::Response::builder().status(403).body("").unwrap(), )]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "forbidden"); connection.assert_requests_match(&[]); @@ -872,30 +927,18 @@ pub(crate) mod test { /// Successful responses should classify as `RetryKind::Unnecessary` #[test] fn successful_response_properly_classified() { - use aws_smithy_http::retry::ClassifyRetry; - + let mut ctx = InterceptorContext::new(Input::doesnt_matter()); + ctx.set_output_or_error(Ok(Output::doesnt_matter())); + ctx.set_response(imds_response("").map(|_| SdkBody::empty())); let classifier = ImdsResponseRetryClassifier; - fn response_200() -> operation::Response { - operation::Response::new(imds_response("").map(|_| SdkBody::empty())) - } - let success = SdkSuccess { - raw: response_200(), - parsed: (), - }; - assert_eq!( - RetryKind::Unnecessary, - classifier.classify_retry(Ok::<_, &SdkError<()>>(&success)) - ); + assert_eq!(None, classifier.classify_retry(&ctx)); // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test) - let failure = SdkError::<()>::response_error( - io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse"), - response_200(), - ); - assert_eq!( - RetryKind::UnretryableFailure, - classifier.classify_retry(Err::<&SdkSuccess<()>, _>(&failure)) - ); + let mut ctx = InterceptorContext::new(Input::doesnt_matter()); + ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io( + io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(), + )))); + assert_eq!(None, classifier.classify_retry(&ctx)); } // since tokens are sent as headers, the tokens need to be valid header values @@ -905,7 +948,7 @@ pub(crate) mod test { token_request("http://169.254.169.254", 21600), token_response(21600, "replaced").map(|_| vec![1, 0]), )]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "invalid token"); connection.assert_requests_match(&[]); @@ -926,7 +969,7 @@ pub(crate) mod test { .unwrap(), ), ]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "invalid UTF-8"); connection.assert_requests_match(&[]); @@ -943,14 +986,20 @@ pub(crate) mod test { let client = Client::builder() // 240.* can never be resolved .endpoint(Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let now = SystemTime::now(); let resp = client .get("/latest/metadata") .await .expect_err("240.0.0.0 will never resolve"); + match resp { + err @ ImdsError::FailedToLoadToken(_) + if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok, + other => panic!( + "wrong error, expected construction failure with TimedOutError inside: {}", + DisplayErrorContext(&other) + ), + } let time_elapsed = now.elapsed().unwrap(); assert!( time_elapsed > Duration::from_secs(1), @@ -962,14 +1011,6 @@ pub(crate) mod test { "time_elapsed should be less than 2s but was {:?}", time_elapsed ); - match resp { - err @ ImdsError::FailedToLoadToken(_) - if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok, - other => panic!( - "wrong error, expected construction failure with TimedOutError inside: {}", - other - ), - } } #[derive(Debug, Deserialize)] @@ -983,8 +1024,10 @@ pub(crate) mod test { } #[tokio::test] - async fn config_tests() -> Result<(), Box> { - let test_cases = std::fs::read_to_string("test-data/imds-config/imds-tests.json")?; + async fn endpoint_config_tests() -> Result<(), Box> { + let _logs = capture_test_logs(); + + let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?; #[derive(Deserialize)] struct TestCases { tests: Vec, @@ -1014,24 +1057,22 @@ pub(crate) mod test { imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap()); } - let imds_client = imds_client.build().await; - let (uri, imds_client) = match (&test_case.result, imds_client) { - (Ok(uri), Ok(client)) => (uri, client), - (Err(test), Ok(_client)) => panic!( - "test should fail: {} but a valid client was made. {}", - test, test_case.docs - ), - (Err(substr), Err(err)) => { - assert_full_error_contains!(err, substr); - return; + let imds_client = imds_client.build(); + match &test_case.result { + Ok(uri) => { + // this request will fail, we just want to capture the endpoint configuration + let _ = imds_client.get("/hello").await; + assert_eq!(&watcher.expect_request().uri().to_string(), uri); + } + Err(expected) => { + let err = imds_client.get("/hello").await.expect_err("it should fail"); + let message = format!("{}", DisplayErrorContext(&err)); + assert!( + message.contains(expected), + "{}\nexpected error: {expected}\nactual error: {message}", + test_case.docs + ); } - (Ok(_uri), Err(e)) => panic!( - "a valid client should be made but: {}. {}", - e, test_case.docs - ), }; - // this request will fail, we just want to capture the endpoint configuration - let _ = imds_client.get("/hello").await; - assert_eq!(&watcher.expect_request().uri().to_string(), uri); } } diff --git a/aws/rust-runtime/aws-config/src/imds/client/error.rs b/aws/rust-runtime/aws-config/src/imds/client/error.rs index b9559486a..4d32aee01 100644 --- a/aws/rust-runtime/aws-config/src/imds/client/error.rs +++ b/aws/rust-runtime/aws-config/src/imds/client/error.rs @@ -8,13 +8,14 @@ use aws_smithy_client::SdkError; use aws_smithy_http::body::SdkBody; use aws_smithy_http::endpoint::error::InvalidEndpointError; +use aws_smithy_runtime_api::client::orchestrator::HttpResponse; use std::error::Error; use std::fmt; /// Error context for [`ImdsError::FailedToLoadToken`] #[derive(Debug)] pub struct FailedToLoadToken { - source: SdkError, + source: SdkError, } impl FailedToLoadToken { @@ -23,7 +24,7 @@ impl FailedToLoadToken { matches!(self.source, SdkError::DispatchFailure(_)) } - pub(crate) fn into_source(self) -> SdkError { + pub(crate) fn into_source(self) -> SdkError { self.source } } @@ -76,7 +77,7 @@ pub enum ImdsError { } impl ImdsError { - pub(super) fn failed_to_load_token(source: SdkError) -> Self { + pub(super) fn failed_to_load_token(source: SdkError) -> Self { Self::FailedToLoadToken(FailedToLoadToken { source }) } diff --git a/aws/rust-runtime/aws-config/src/imds/client/token.rs b/aws/rust-runtime/aws-config/src/imds/client/token.rs index 41e96777b..98e1a79a8 100644 --- a/aws/rust-runtime/aws-config/src/imds/client/token.rs +++ b/aws/rust-runtime/aws-config/src/imds/client/token.rs @@ -15,26 +15,29 @@ //! - Attach the token to the request in the `x-aws-ec2-metadata-token` header use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind}; -use crate::imds::client::ImdsResponseRetryClassifier; use aws_credential_types::cache::ExpiringCache; -use aws_http::user_agent::UserAgentStage; -use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::time::SharedTimeSource; -use aws_smithy_client::erase::DynConnector; -use aws_smithy_client::retry; use aws_smithy_http::body::SdkBody; -use aws_smithy_http::endpoint::apply_endpoint; -use aws_smithy_http::middleware::AsyncMapRequest; -use aws_smithy_http::operation; -use aws_smithy_http::operation::Operation; -use aws_smithy_http::operation::{Metadata, Request}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http_tower::map_request::MapRequestLayer; -use aws_smithy_types::timeout::TimeoutConfig; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver; +use aws_smithy_runtime_api::client::auth::{ + AuthScheme, AuthSchemeEndpointConfig, AuthSchemeId, Signer, +}; +use aws_smithy_runtime_api::client::identity::{ + Identity, IdentityResolver, SharedIdentityResolver, +}; +use aws_smithy_runtime_api::client::orchestrator::{ + Future, HttpRequest, HttpResponse, OrchestratorError, +}; +use aws_smithy_runtime_api::client::runtime_components::{ + GetIdentityResolver, RuntimeComponents, RuntimeComponentsBuilder, +}; +use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin}; +use aws_smithy_types::config_bag::ConfigBag; use http::{HeaderValue, Uri}; -use std::fmt::{Debug, Formatter}; -use std::future::Future; -use std::pin::Pin; +use std::borrow::Cow; +use std::fmt; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -47,6 +50,7 @@ const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(120); const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl-seconds"; const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token"; +const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN); /// IMDS Token #[derive(Clone)] @@ -54,151 +58,214 @@ struct Token { value: HeaderValue, expiry: SystemTime, } - -/// Token Middleware -/// -/// Token middleware will load/cache a token when required and handle caching/expiry. -/// -/// It will attach the token to the incoming request on the `x-aws-ec2-metadata-token` header. -#[derive(Clone)] -pub(super) struct TokenMiddleware { - client: Arc>>, - token_parser: GetTokenResponseHandler, - token: ExpiringCache, - time_source: SharedTimeSource, - endpoint: Uri, - token_ttl: Duration, -} - -impl Debug for TokenMiddleware { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ImdsTokenMiddleware") +impl fmt::Debug for Token { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Token") + .field("value", &"** redacted **") + .field("expiry", &self.expiry) + .finish() } } -impl TokenMiddleware { +/// Token Runtime Plugin +/// +/// This runtime plugin wires up the necessary components to load/cache a token +/// when required and handle caching/expiry. This token will get attached to the +/// request to IMDS on the `x-aws-ec2-metadata-token` header. +#[derive(Debug)] +pub(super) struct TokenRuntimePlugin { + components: RuntimeComponentsBuilder, +} + +impl TokenRuntimePlugin { pub(super) fn new( - connector: DynConnector, + common_plugin: SharedRuntimePlugin, time_source: SharedTimeSource, - endpoint: Uri, token_ttl: Duration, - retry_config: retry::Config, - timeout_config: TimeoutConfig, - sleep_impl: Option, ) -> Self { - let mut inner_builder = aws_smithy_client::Client::builder() - .connector(connector) - .middleware(MapRequestLayer::::default()) - .retry_config(retry_config) - .operation_timeout_config(timeout_config.into()); - inner_builder.set_sleep_impl(sleep_impl); - let inner_client = inner_builder.build(); - let client = Arc::new(inner_client); Self { - client, - token_parser: GetTokenResponseHandler { - time: time_source.clone(), - }, - token: ExpiringCache::new(TOKEN_REFRESH_BUFFER), - time_source, - endpoint, - token_ttl, + components: RuntimeComponentsBuilder::new("TokenRuntimePlugin") + .with_auth_scheme(TokenAuthScheme::new()) + .with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![ + IMDS_TOKEN_AUTH_SCHEME, + ]))) + .with_identity_resolver( + IMDS_TOKEN_AUTH_SCHEME, + TokenResolver::new(common_plugin, time_source, token_ttl), + ), } } - async fn add_token(&self, request: Request) -> Result { - let preloaded_token = self - .token - .yield_or_clear_if_expired(self.time_source.now()) - .await; - let token = match preloaded_token { - Some(token) => Ok(token), - None => { - self.token - .get_or_load(|| async move { self.get_token().await }) - .await - } - }?; - request.augment(|mut request, _| { - request - .headers_mut() - .insert(X_AWS_EC2_METADATA_TOKEN, token.value); - Ok(request) - }) +} + +impl RuntimePlugin for TokenRuntimePlugin { + fn runtime_components( + &self, + _current_components: &RuntimeComponentsBuilder, + ) -> Cow<'_, RuntimeComponentsBuilder> { + Cow::Borrowed(&self.components) + } +} + +#[derive(Debug)] +struct TokenResolverInner { + cache: ExpiringCache, + refresh: Operation<(), Token, TokenError>, + time_source: SharedTimeSource, +} + +#[derive(Clone, Debug)] +struct TokenResolver { + inner: Arc, +} + +impl TokenResolver { + fn new( + common_plugin: SharedRuntimePlugin, + time_source: SharedTimeSource, + token_ttl: Duration, + ) -> Self { + Self { + inner: Arc::new(TokenResolverInner { + cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER), + refresh: Operation::builder() + .service_name("imds") + .operation_name("get-token") + .runtime_plugin(common_plugin) + .no_auth() + .serializer(move |_| { + Ok(http::Request::builder() + .method("PUT") + .uri(Uri::from_static("/latest/api/token")) + .header(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, token_ttl.as_secs()) + .body(SdkBody::empty()) + .expect("valid HTTP request")) + }) + .deserializer({ + let time_source = time_source.clone(); + move |response| { + let now = time_source.now(); + parse_token_response(response, now) + .map_err(OrchestratorError::operation) + } + }) + .build(), + time_source, + }), + } } async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> { - let mut uri = Uri::from_static("/latest/api/token"); - apply_endpoint(&mut uri, &self.endpoint, None).map_err(ImdsError::unexpected)?; - let request = http::Request::builder() - .header( - X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, - self.token_ttl.as_secs(), - ) - .uri(uri) - .method("PUT") - .body(SdkBody::empty()) - .expect("valid HTTP request"); - let mut request = operation::Request::new(request); - request.properties_mut().insert(super::user_agent()); - - let operation = Operation::new(request, self.token_parser.clone()) - .with_retry_classifier(ImdsResponseRetryClassifier) - .with_metadata(Metadata::new("get-token", "imds")); - let response = self - .client - .call(operation) + self.inner + .refresh + .invoke(()) .await - .map_err(ImdsError::failed_to_load_token)?; - let expiry = response.expiry; - Ok((response, expiry)) + .map(|token| { + let expiry = token.expiry; + (token, expiry) + }) + .map_err(ImdsError::failed_to_load_token) } } -impl AsyncMapRequest for TokenMiddleware { - type Error = ImdsError; - type Future = Pin> + Send + 'static>>; - - fn name(&self) -> &'static str { - "attach_imds_token" +fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result { + match response.status().as_u16() { + 400 => return Err(TokenErrorKind::InvalidParameters.into()), + 403 => return Err(TokenErrorKind::Forbidden.into()), + _ => {} } - - fn apply(&self, request: Request) -> Self::Future { - let this = self.clone(); - Box::pin(async move { this.add_token(request).await }) - } -} - -#[derive(Clone)] -struct GetTokenResponseHandler { - time: SharedTimeSource, -} - -impl ParseStrictResponse for GetTokenResponseHandler { - type Output = Result; - - fn parse(&self, response: &http::Response) -> Self::Output { - match response.status().as_u16() { - 400 => return Err(TokenErrorKind::InvalidParameters.into()), - 403 => return Err(TokenErrorKind::Forbidden.into()), - _ => {} - } - let value = HeaderValue::from_maybe_shared(response.body().clone()) + let mut value = + HeaderValue::from_bytes(response.body().bytes().expect("non-streaming response")) .map_err(|_| TokenErrorKind::InvalidToken)?; - let ttl: u64 = response - .headers() - .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS) - .ok_or(TokenErrorKind::NoTtl)? - .to_str() - .map_err(|_| TokenErrorKind::InvalidTtl)? - .parse() - .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?; - Ok(Token { - value, - expiry: self.time.now() + Duration::from_secs(ttl), - }) - } + value.set_sensitive(true); - fn sensitive(&self) -> bool { - true + let ttl: u64 = response + .headers() + .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS) + .ok_or(TokenErrorKind::NoTtl)? + .to_str() + .map_err(|_| TokenErrorKind::InvalidTtl)? + .parse() + .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?; + Ok(Token { + value, + expiry: now + Duration::from_secs(ttl), + }) +} + +impl IdentityResolver for TokenResolver { + fn resolve_identity(&self, _config_bag: &ConfigBag) -> Future { + let this = self.clone(); + Future::new(Box::pin(async move { + let preloaded_token = this + .inner + .cache + .yield_or_clear_if_expired(this.inner.time_source.now()) + .await; + let token = match preloaded_token { + Some(token) => Ok(token), + None => { + this.inner + .cache + .get_or_load(|| { + let this = this.clone(); + async move { this.get_token().await } + }) + .await + } + }?; + + let expiry = token.expiry; + Ok(Identity::new(token, Some(expiry))) + })) + } +} + +#[derive(Debug)] +struct TokenAuthScheme { + signer: TokenSigner, +} + +impl TokenAuthScheme { + fn new() -> Self { + Self { + signer: TokenSigner, + } + } +} + +impl AuthScheme for TokenAuthScheme { + fn scheme_id(&self) -> AuthSchemeId { + IMDS_TOKEN_AUTH_SCHEME + } + + fn identity_resolver( + &self, + identity_resolvers: &dyn GetIdentityResolver, + ) -> Option { + identity_resolvers.identity_resolver(IMDS_TOKEN_AUTH_SCHEME) + } + + fn signer(&self) -> &dyn Signer { + &self.signer + } +} + +#[derive(Debug)] +struct TokenSigner; + +impl Signer for TokenSigner { + fn sign_http_request( + &self, + request: &mut HttpRequest, + identity: &Identity, + _auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>, + _runtime_components: &RuntimeComponents, + _config_bag: &ConfigBag, + ) -> Result<(), BoxError> { + let token = identity.data::().expect("correct type"); + request + .headers_mut() + .append(X_AWS_EC2_METADATA_TOKEN, token.value.clone()); + Ok(()) } } diff --git a/aws/rust-runtime/aws-config/src/imds/credentials.rs b/aws/rust-runtime/aws-config/src/imds/credentials.rs index ccc65eaf9..3bde8a451 100644 --- a/aws/rust-runtime/aws-config/src/imds/credentials.rs +++ b/aws/rust-runtime/aws-config/src/imds/credentials.rs @@ -9,8 +9,7 @@ //! This credential provider will NOT fallback to IMDSv1. Ensure that IMDSv2 is enabled on your instances. use super::client::error::ImdsError; -use crate::imds; -use crate::imds::client::LazyClient; +use crate::imds::{self, Client}; use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials}; use crate::provider_config::ProviderConfig; use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials}; @@ -50,7 +49,7 @@ impl StdError for ImdsCommunicationError { /// _Note: This credentials provider will NOT fallback to the IMDSv1 flow._ #[derive(Debug)] pub struct ImdsCredentialsProvider { - client: LazyClient, + client: Client, env: Env, profile: Option, time_source: SharedTimeSource, @@ -110,12 +109,7 @@ impl Builder { let env = provider_config.env(); let client = self .imds_override - .map(LazyClient::from_ready_client) - .unwrap_or_else(|| { - imds::Client::builder() - .configure(&provider_config) - .build_lazy() - }); + .unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build()); ImdsCredentialsProvider { client, env, @@ -156,23 +150,14 @@ impl ImdsCredentialsProvider { } } - /// Load an inner IMDS client from the OnceCell - async fn client(&self) -> Result<&imds::Client, CredentialsError> { - self.client.client().await.map_err(|build_error| { - // need to format the build error since we don't own it and it can't be cloned - CredentialsError::invalid_configuration(format!("{}", build_error)) - }) - } - /// Retrieve the instance profile from IMDS async fn get_profile_uncached(&self) -> Result { match self - .client() - .await? + .client .get("/latest/meta-data/iam/security-credentials/") .await { - Ok(profile) => Ok(profile), + Ok(profile) => Ok(profile.as_ref().into()), Err(ImdsError::ErrorResponse(context)) if context.response().status().as_u16() == 404 => { @@ -223,9 +208,11 @@ impl ImdsCredentialsProvider { async fn retrieve_credentials(&self) -> provider::Result { if self.imds_disabled() { - tracing::debug!("IMDS disabled because $AWS_EC2_METADATA_DISABLED was set to `true`"); + tracing::debug!( + "IMDS disabled because AWS_EC2_METADATA_DISABLED env var was set to `true`" + ); return Err(CredentialsError::not_loaded( - "IMDS disabled by $AWS_ECS_METADATA_DISABLED", + "IMDS disabled by AWS_ECS_METADATA_DISABLED env var", )); } tracing::debug!("loading credentials from IMDS"); @@ -235,15 +222,14 @@ impl ImdsCredentialsProvider { }; tracing::debug!(profile = %profile, "loaded profile"); let credentials = self - .client() - .await? - .get(&format!( + .client + .get(format!( "/latest/meta-data/iam/security-credentials/{}", profile )) .await .map_err(CredentialsError::provider_error)?; - match parse_json_credentials(&credentials) { + match parse_json_credentials(credentials.as_ref()) { Ok(JsonCredentials::RefreshableCredentials(RefreshableCredentials { access_key_id, secret_access_key, @@ -296,19 +282,16 @@ impl ImdsCredentialsProvider { #[cfg(test)] mod test { - use std::time::{Duration, UNIX_EPOCH}; - + use super::*; use crate::imds::client::test::{ imds_request, imds_response, make_client, token_request, token_response, }; - use crate::imds::credentials::{ - ImdsCredentialsProvider, WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY, - }; use crate::provider_config::ProviderConfig; use aws_credential_types::provider::ProvideCredentials; use aws_smithy_async::test_util::instant_time_and_sleep; use aws_smithy_client::erase::DynConnector; use aws_smithy_client::test_connection::TestConnection; + use std::time::{Duration, UNIX_EPOCH}; use tracing_test::traced_test; const TOKEN_A: &str = "token_a"; @@ -338,7 +321,7 @@ mod test { ), ]); let client = ImdsCredentialsProvider::builder() - .imds_client(make_client(&connection).await) + .imds_client(make_client(&connection)) .build(); let creds1 = client.provide_credentials().await.expect("valid creds"); let creds2 = client.provide_credentials().await.expect("valid creds"); @@ -376,9 +359,7 @@ mod test { .with_time_source(time_source); let client = crate::imds::Client::builder() .configure(&provider_config) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .configure(&provider_config) .imds_client(client) @@ -422,9 +403,7 @@ mod test { .with_time_source(time_source); let client = crate::imds::Client::builder() .configure(&provider_config) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .configure(&provider_config) .imds_client(client) @@ -443,9 +422,7 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let expected = aws_credential_types::Credentials::for_tests(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) @@ -463,18 +440,16 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) // no fallback credentials provided .build(); let actual = provider.provide_credentials().await; - assert!(matches!( - actual, - Err(aws_credential_types::provider::error::CredentialsError::CredentialsNotLoaded(_)) - )); + assert!( + matches!(actual, Err(CredentialsError::CredentialsNotLoaded(_))), + "\nexpected: Err(CredentialsError::CredentialsNotLoaded(_))\nactual: {actual:?}" + ); } #[tokio::test] @@ -484,9 +459,7 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let expected = aws_credential_types::Credentials::for_tests(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) @@ -536,7 +509,7 @@ mod test { ), ]); let provider = ImdsCredentialsProvider::builder() - .imds_client(make_client(&connection).await) + .imds_client(make_client(&connection)) .build(); let creds1 = provider.provide_credentials().await.expect("valid creds"); assert_eq!(creds1.access_key_id(), "ASIARTEST"); diff --git a/aws/rust-runtime/aws-config/src/imds/region.rs b/aws/rust-runtime/aws-config/src/imds/region.rs index bc784f8d4..072dc97a8 100644 --- a/aws/rust-runtime/aws-config/src/imds/region.rs +++ b/aws/rust-runtime/aws-config/src/imds/region.rs @@ -8,8 +8,7 @@ //! Load region from IMDS from `/latest/meta-data/placement/region` //! This provider has a 5 second timeout. -use crate::imds; -use crate::imds::client::LazyClient; +use crate::imds::{self, Client}; use crate::meta::region::{future, ProvideRegion}; use crate::provider_config::ProviderConfig; use aws_smithy_types::error::display::DisplayErrorContext; @@ -22,7 +21,7 @@ use tracing::Instrument; /// This provider is included in the default region chain, so it does not need to be used manually. #[derive(Debug)] pub struct ImdsRegionProvider { - client: LazyClient, + client: Client, env: Env, } @@ -49,11 +48,10 @@ impl ImdsRegionProvider { tracing::debug!("not using IMDS to load region, IMDS is disabled"); return None; } - let client = self.client.client().await.ok()?; - match client.get(REGION_PATH).await { + match self.client.get(REGION_PATH).await { Ok(region) => { - tracing::debug!(region = %region, "loaded region from IMDS"); - Some(Region::new(region)) + tracing::debug!(region = %region.as_ref(), "loaded region from IMDS"); + Some(Region::new(String::from(region))) } Err(err) => { tracing::warn!(err = %DisplayErrorContext(&err), "failed to load region from IMDS"); @@ -99,12 +97,7 @@ impl Builder { let provider_config = self.provider_config.unwrap_or_default(); let client = self .imds_client_override - .map(LazyClient::from_ready_client) - .unwrap_or_else(|| { - imds::Client::builder() - .configure(&provider_config) - .build_lazy() - }); + .unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build()); ImdsRegionProvider { client, env: provider_config.env(), diff --git a/aws/rust-runtime/aws-config/test-data/imds-config/imds-tests.json b/aws/rust-runtime/aws-config/test-data/imds-config/imds-endpoint-tests.json similarity index 100% rename from aws/rust-runtime/aws-config/test-data/imds-config/imds-tests.json rename to aws/rust-runtime/aws-config/test-data/imds-config/imds-endpoint-tests.json diff --git a/aws/rust-runtime/aws-runtime/src/user_agent.rs b/aws/rust-runtime/aws-runtime/src/user_agent.rs index 7310820f8..1b6c36998 100644 --- a/aws/rust-runtime/aws-runtime/src/user_agent.rs +++ b/aws/rust-runtime/aws-runtime/src/user_agent.rs @@ -82,25 +82,25 @@ impl Interceptor for UserAgentInterceptor { _runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - let api_metadata = cfg - .load::() - .ok_or(UserAgentInterceptorError::MissingApiMetadata)?; - // Allow for overriding the user agent by an earlier interceptor (so, for example, // tests can use `AwsUserAgent::for_tests()`) by attempting to grab one out of the // config bag before creating one. let ua: Cow<'_, AwsUserAgent> = cfg .load::() .map(Cow::Borrowed) + .map(Result::<_, UserAgentInterceptorError>::Ok) .unwrap_or_else(|| { + let api_metadata = cfg + .load::() + .ok_or(UserAgentInterceptorError::MissingApiMetadata)?; let mut ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata.clone()); let maybe_app_name = cfg.load::(); if let Some(app_name) = maybe_app_name { ua.set_app_name(app_name.clone()); } - Cow::Owned(ua) - }); + Ok(Cow::Owned(ua)) + })?; let headers = context.request_mut().headers_mut(); let (user_agent, x_amz_user_agent) = header_values(&ua)?; @@ -250,4 +250,30 @@ mod tests { "`{error}` should contain message `This is a bug`" ); } + + #[test] + fn test_api_metadata_missing_with_ua_override() { + let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); + let mut context = context(); + + let mut layer = Layer::new("test"); + layer.store_put(AwsUserAgent::for_tests()); + let mut config = ConfigBag::of_layers(vec![layer]); + + let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); + + interceptor + .modify_before_signing(&mut ctx, &rc, &mut config) + .expect("it should succeed"); + + let header = expect_header(&context, "user-agent"); + assert_eq!(AwsUserAgent::for_tests().ua_header(), header); + assert!(!header.contains("unused")); + + assert_eq!( + AwsUserAgent::for_tests().aws_ua_header(), + expect_header(&context, "x-amz-user-agent") + ); + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt index a902c1432..88a1a77c9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt @@ -12,7 +12,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection import software.amazon.smithy.rust.codegen.client.smithy.generators.SensitiveIndex import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt index d872311cd..48bfc7589 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt @@ -10,7 +10,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section @@ -72,17 +71,9 @@ sealed class OperationSection(name: String) : Section(name) { val operationShape: OperationShape, ) : OperationSection("AdditionalInterceptors") { fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) { - val smithyRuntimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig) writer.rustTemplate( - """ - .with_interceptor( - #{SharedInterceptor}::new( - #{interceptor} - ) as _ - ) - """, + ".with_interceptor(#{interceptor})", "interceptor" to interceptor, - "SharedInterceptor" to smithyRuntimeApi.resolve("client::interceptors::SharedInterceptor"), ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt index 555ea7e03..921d16a3e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -43,10 +43,9 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) { fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) { writer.rustTemplate( """ - runtime_components.push_interceptor(#{SharedInterceptor}::new(#{interceptor}) as _); + runtime_components.push_interceptor(#{interceptor}); """, "interceptor" to interceptor, - "SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::SharedInterceptor"), ) } diff --git a/rust-runtime/aws-smithy-client/src/test_connection.rs b/rust-runtime/aws-smithy-client/src/test_connection.rs index b0600b369..622f5fedc 100644 --- a/rust-runtime/aws-smithy-client/src/test_connection.rs +++ b/rust-runtime/aws-smithy-client/src/test_connection.rs @@ -134,7 +134,7 @@ pub struct ValidateRequest { impl ValidateRequest { pub fn assert_matches(&self, ignore_headers: &[HeaderName]) { let (actual, expected) = (&self.actual, &self.expected); - assert_eq!(actual.uri(), expected.uri()); + assert_eq!(expected.uri(), actual.uri()); for (name, value) in expected.headers() { if !ignore_headers.contains(name) { let actual_header = actual diff --git a/rust-runtime/aws-smithy-http/src/result.rs b/rust-runtime/aws-smithy-http/src/result.rs index 8d60c8d2e..c86e9c5f8 100644 --- a/rust-runtime/aws-smithy-http/src/result.rs +++ b/rust-runtime/aws-smithy-http/src/result.rs @@ -651,6 +651,11 @@ impl ConnectorError { } } + /// Grants ownership of this error's source. + pub fn into_source(self) -> BoxError { + self.source + } + /// Returns metadata about the connection /// /// If a connection was established and provided by the internal connector, a connection will diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/auth.rs b/rust-runtime/aws-smithy-runtime-api/src/client/auth.rs index 79a28bef3..d5e02b6da 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/auth.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/auth.rs @@ -9,6 +9,7 @@ use crate::box_error::BoxError; use crate::client::identity::{Identity, SharedIdentityResolver}; use crate::client::orchestrator::HttpRequest; use crate::client::runtime_components::{GetIdentityResolver, RuntimeComponents}; +use crate::impl_shared_conversions; use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; use aws_smithy_types::type_erasure::TypeErasedBox; use aws_smithy_types::Document; @@ -120,6 +121,12 @@ impl AuthSchemeOptionResolver for SharedAuthSchemeOptionResolver { } } +impl_shared_conversions!( + convert SharedAuthSchemeOptionResolver + from AuthSchemeOptionResolver + using SharedAuthSchemeOptionResolver::new +); + /// An auth scheme. /// /// Auth schemes have unique identifiers (the `scheme_id`), @@ -177,6 +184,8 @@ impl AuthScheme for SharedAuthScheme { } } +impl_shared_conversions!(convert SharedAuthScheme from AuthScheme using SharedAuthScheme::new); + /// Signing implementation for an auth scheme. pub trait Signer: Send + Sync + fmt::Debug { /// Sign the given request with the given identity, components, and config. diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/connectors.rs b/rust-runtime/aws-smithy-runtime-api/src/client/connectors.rs index dd91f634b..9399fa05b 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/connectors.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/connectors.rs @@ -31,6 +31,7 @@ //! The Smithy clients have no knowledge of such concepts. use crate::client::orchestrator::{HttpRequest, HttpResponse}; +use crate::impl_shared_conversions; use aws_smithy_async::future::now_or_later::NowOrLater; use aws_smithy_http::result::ConnectorError; use pin_project_lite::pin_project; @@ -117,3 +118,5 @@ impl HttpConnector for SharedHttpConnector { (*self.0).call(request) } } + +impl_shared_conversions!(convert SharedHttpConnector from HttpConnector using SharedHttpConnector::new); diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/endpoint.rs b/rust-runtime/aws-smithy-runtime-api/src/client/endpoint.rs index a32ed602f..0037b9eb1 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/endpoint.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/endpoint.rs @@ -6,6 +6,7 @@ //! APIs needed to configure endpoint resolution for clients. use crate::client::orchestrator::Future; +use crate::impl_shared_conversions; use aws_smithy_types::config_bag::{Storable, StoreReplace}; use aws_smithy_types::endpoint::Endpoint; use aws_smithy_types::type_erasure::TypeErasedBox; @@ -60,3 +61,5 @@ impl EndpointResolver for SharedEndpointResolver { self.0.resolve_endpoint(params) } } + +impl_shared_conversions!(convert SharedEndpointResolver from EndpointResolver using SharedEndpointResolver::new); diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/identity.rs b/rust-runtime/aws-smithy-runtime-api/src/client/identity.rs index 27c4accf4..ee115487c 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/identity.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/identity.rs @@ -5,6 +5,7 @@ use crate::client::auth::AuthSchemeId; use crate::client::orchestrator::Future; +use crate::impl_shared_conversions; use aws_smithy_types::config_bag::ConfigBag; use std::any::Any; use std::fmt; @@ -48,6 +49,8 @@ impl IdentityResolver for SharedIdentityResolver { } } +impl_shared_conversions!(convert SharedIdentityResolver from IdentityResolver using SharedIdentityResolver::new); + /// An identity resolver paired with an auth scheme ID that it resolves for. #[derive(Clone, Debug)] pub(crate) struct ConfiguredIdentityResolver { diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs index 443af2962..9a26fbb40 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs @@ -19,12 +19,12 @@ use crate::client::runtime_components::RuntimeComponents; use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; use std::fmt; use std::marker::PhantomData; -use std::ops::Deref; use std::sync::Arc; pub mod context; pub mod error; +use crate::impl_shared_conversions; pub use error::InterceptorError; macro_rules! interceptor_trait_fn { @@ -618,18 +618,201 @@ impl SharedInterceptor { } } -impl AsRef for SharedInterceptor { - fn as_ref(&self) -> &(dyn Interceptor + 'static) { - self.interceptor.as_ref() +impl Interceptor for SharedInterceptor { + fn name(&self) -> &'static str { + self.interceptor.name() + } + + fn modify_before_attempt_completion( + &self, + context: &mut FinalizerInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_attempt_completion(context, runtime_components, cfg) + } + + fn modify_before_completion( + &self, + context: &mut FinalizerInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_completion(context, runtime_components, cfg) + } + + fn modify_before_deserialization( + &self, + context: &mut BeforeDeserializationInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_deserialization(context, runtime_components, cfg) + } + + fn modify_before_retry_loop( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_retry_loop(context, runtime_components, cfg) + } + + fn modify_before_serialization( + &self, + context: &mut BeforeSerializationInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_serialization(context, runtime_components, cfg) + } + + fn modify_before_signing( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_signing(context, runtime_components, cfg) + } + + fn modify_before_transmit( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .modify_before_transmit(context, runtime_components, cfg) + } + + fn read_after_attempt( + &self, + context: &FinalizerInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_attempt(context, runtime_components, cfg) + } + + fn read_after_deserialization( + &self, + context: &AfterDeserializationInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_deserialization(context, runtime_components, cfg) + } + + fn read_after_execution( + &self, + context: &FinalizerInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_execution(context, runtime_components, cfg) + } + + fn read_after_serialization( + &self, + context: &BeforeTransmitInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_serialization(context, runtime_components, cfg) + } + + fn read_after_signing( + &self, + context: &BeforeTransmitInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_signing(context, runtime_components, cfg) + } + + fn read_after_transmit( + &self, + context: &BeforeDeserializationInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_after_transmit(context, runtime_components, cfg) + } + + fn read_before_attempt( + &self, + context: &BeforeTransmitInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_before_attempt(context, runtime_components, cfg) + } + + fn read_before_deserialization( + &self, + context: &BeforeDeserializationInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_before_deserialization(context, runtime_components, cfg) + } + + fn read_before_execution( + &self, + context: &BeforeSerializationInterceptorContextRef<'_>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor.read_before_execution(context, cfg) + } + + fn read_before_serialization( + &self, + context: &BeforeSerializationInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_before_serialization(context, runtime_components, cfg) + } + + fn read_before_signing( + &self, + context: &BeforeTransmitInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_before_signing(context, runtime_components, cfg) + } + + fn read_before_transmit( + &self, + context: &BeforeTransmitInterceptorContextRef<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + self.interceptor + .read_before_transmit(context, runtime_components, cfg) } } -impl Deref for SharedInterceptor { - type Target = Arc; - fn deref(&self) -> &Self::Target { - &self.interceptor - } -} +impl_shared_conversions!(convert SharedInterceptor from Interceptor using SharedInterceptor::new); /// Generalized interceptor disabling interface /// diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index 8de224cc8..a03451556 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -25,7 +25,8 @@ use aws_smithy_http::body::SdkBody; use aws_smithy_http::result::{ConnectorError, SdkError}; use aws_smithy_types::config_bag::{Storable, StoreReplace}; use bytes::Bytes; -use std::fmt::Debug; +use std::error::Error as StdError; +use std::fmt; use std::future::Future as StdFuture; use std::pin::Pin; @@ -244,6 +245,35 @@ impl OrchestratorError { } } +impl StdError for OrchestratorError +where + E: StdError + 'static, +{ + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(match &self.kind { + ErrorKind::Connector { source } => source as _, + ErrorKind::Operation { err } => err as _, + ErrorKind::Interceptor { source } => source as _, + ErrorKind::Response { source } => source.as_ref(), + ErrorKind::Timeout { source } => source.as_ref(), + ErrorKind::Other { source } => source.as_ref(), + }) + } +} + +impl fmt::Display for OrchestratorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self.kind { + ErrorKind::Connector { .. } => "connector error", + ErrorKind::Operation { .. } => "operation error", + ErrorKind::Interceptor { .. } => "interceptor error", + ErrorKind::Response { .. } => "response error", + ErrorKind::Timeout { .. } => "timeout", + ErrorKind::Other { .. } => "an unknown error occurred", + }) + } +} + fn convert_dispatch_error( err: BoxError, response: Option, @@ -262,7 +292,7 @@ fn convert_dispatch_error( impl From for OrchestratorError where - E: Debug + std::error::Error + 'static, + E: fmt::Debug + std::error::Error + 'static, { fn from(err: InterceptorError) -> Self { Self::interceptor(err) diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs index 0703e87d8..0fd61b577 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs @@ -98,6 +98,8 @@ impl RetryStrategy for SharedRetryStrategy { } } +impl_shared_conversions!(convert SharedRetryStrategy from RetryStrategy using SharedRetryStrategy::new); + /// Classification result from [`ClassifyRetry`]. #[non_exhaustive] #[derive(Clone, Eq, PartialEq, Debug)] @@ -231,6 +233,7 @@ mod test_util { use crate::box_error::BoxError; use crate::client::runtime_components::RuntimeComponents; +use crate::impl_shared_conversions; use std::sync::Arc; #[cfg(feature = "test-util")] pub use test_util::AlwaysRetry; diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_components.rs b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_components.rs index ac67705ef..4a92ec698 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_components.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_components.rs @@ -19,6 +19,7 @@ use crate::client::endpoint::SharedEndpointResolver; use crate::client::identity::{ConfiguredIdentityResolver, SharedIdentityResolver}; use crate::client::interceptors::SharedInterceptor; use crate::client::retries::{RetryClassifiers, SharedRetryStrategy}; +use crate::shared::IntoShared; use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::time::SharedTimeSource; use std::fmt; @@ -273,17 +274,17 @@ impl RuntimeComponentsBuilder { /// Sets the auth scheme option resolver. pub fn set_auth_scheme_option_resolver( &mut self, - auth_scheme_option_resolver: Option, + auth_scheme_option_resolver: Option>, ) -> &mut Self { self.auth_scheme_option_resolver = - auth_scheme_option_resolver.map(|r| Tracked::new(self.builder_name, r)); + auth_scheme_option_resolver.map(|r| Tracked::new(self.builder_name, r.into_shared())); self } /// Sets the auth scheme option resolver. pub fn with_auth_scheme_option_resolver( mut self, - auth_scheme_option_resolver: Option, + auth_scheme_option_resolver: Option>, ) -> Self { self.set_auth_scheme_option_resolver(auth_scheme_option_resolver); self @@ -295,13 +296,19 @@ impl RuntimeComponentsBuilder { } /// Sets the HTTP connector. - pub fn set_http_connector(&mut self, connector: Option) -> &mut Self { - self.http_connector = connector.map(|c| Tracked::new(self.builder_name, c)); + pub fn set_http_connector( + &mut self, + connector: Option>, + ) -> &mut Self { + self.http_connector = connector.map(|c| Tracked::new(self.builder_name, c.into_shared())); self } /// Sets the HTTP connector. - pub fn with_http_connector(mut self, connector: Option) -> Self { + pub fn with_http_connector( + mut self, + connector: Option>, + ) -> Self { self.set_http_connector(connector); self } @@ -314,16 +321,17 @@ impl RuntimeComponentsBuilder { /// Sets the endpoint resolver. pub fn set_endpoint_resolver( &mut self, - endpoint_resolver: Option, + endpoint_resolver: Option>, ) -> &mut Self { - self.endpoint_resolver = endpoint_resolver.map(|s| Tracked::new(self.builder_name, s)); + self.endpoint_resolver = + endpoint_resolver.map(|s| Tracked::new(self.builder_name, s.into_shared())); self } /// Sets the endpoint resolver. pub fn with_endpoint_resolver( mut self, - endpoint_resolver: Option, + endpoint_resolver: Option>, ) -> Self { self.set_endpoint_resolver(endpoint_resolver); self @@ -335,14 +343,17 @@ impl RuntimeComponentsBuilder { } /// Adds an auth scheme. - pub fn push_auth_scheme(&mut self, auth_scheme: SharedAuthScheme) -> &mut Self { + pub fn push_auth_scheme( + &mut self, + auth_scheme: impl IntoShared, + ) -> &mut Self { self.auth_schemes - .push(Tracked::new(self.builder_name, auth_scheme)); + .push(Tracked::new(self.builder_name, auth_scheme.into_shared())); self } /// Adds an auth scheme. - pub fn with_auth_scheme(mut self, auth_scheme: SharedAuthScheme) -> Self { + pub fn with_auth_scheme(mut self, auth_scheme: impl IntoShared) -> Self { self.push_auth_scheme(auth_scheme); self } @@ -351,11 +362,11 @@ impl RuntimeComponentsBuilder { pub fn push_identity_resolver( &mut self, scheme_id: AuthSchemeId, - identity_resolver: SharedIdentityResolver, + identity_resolver: impl IntoShared, ) -> &mut Self { self.identity_resolvers.push(Tracked::new( self.builder_name, - ConfiguredIdentityResolver::new(scheme_id, identity_resolver), + ConfiguredIdentityResolver::new(scheme_id, identity_resolver.into_shared()), )); self } @@ -364,7 +375,7 @@ impl RuntimeComponentsBuilder { pub fn with_identity_resolver( mut self, scheme_id: AuthSchemeId, - identity_resolver: SharedIdentityResolver, + identity_resolver: impl IntoShared, ) -> Self { self.push_identity_resolver(scheme_id, identity_resolver); self @@ -386,14 +397,17 @@ impl RuntimeComponentsBuilder { } /// Adds an interceptor. - pub fn push_interceptor(&mut self, interceptor: SharedInterceptor) -> &mut Self { + pub fn push_interceptor( + &mut self, + interceptor: impl IntoShared, + ) -> &mut Self { self.interceptors - .push(Tracked::new(self.builder_name, interceptor)); + .push(Tracked::new(self.builder_name, interceptor.into_shared())); self } /// Adds an interceptor. - pub fn with_interceptor(mut self, interceptor: SharedInterceptor) -> Self { + pub fn with_interceptor(mut self, interceptor: impl IntoShared) -> Self { self.push_interceptor(interceptor); self } @@ -444,14 +458,22 @@ impl RuntimeComponentsBuilder { } /// Sets the retry strategy. - pub fn set_retry_strategy(&mut self, retry_strategy: Option) -> &mut Self { - self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s)); + pub fn set_retry_strategy( + &mut self, + retry_strategy: Option>, + ) -> &mut Self { + self.retry_strategy = + retry_strategy.map(|s| Tracked::new(self.builder_name, s.into_shared())); self } /// Sets the retry strategy. - pub fn with_retry_strategy(mut self, retry_strategy: Option) -> Self { - self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s)); + pub fn with_retry_strategy( + mut self, + retry_strategy: Option>, + ) -> Self { + self.retry_strategy = + retry_strategy.map(|s| Tracked::new(self.builder_name, s.into_shared())); self } @@ -617,13 +639,13 @@ impl RuntimeComponentsBuilder { } Self::new("aws_smithy_runtime_api::client::runtime_components::RuntimeComponentBuilder::for_tests") - .with_auth_scheme(SharedAuthScheme::new(FakeAuthScheme)) - .with_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(FakeAuthSchemeOptionResolver))) - .with_endpoint_resolver(Some(SharedEndpointResolver::new(FakeEndpointResolver))) - .with_http_connector(Some(SharedHttpConnector::new(FakeConnector))) - .with_identity_resolver(AuthSchemeId::new("fake"), SharedIdentityResolver::new(FakeIdentityResolver)) + .with_auth_scheme(FakeAuthScheme) + .with_auth_scheme_option_resolver(Some(FakeAuthSchemeOptionResolver)) + .with_endpoint_resolver(Some(FakeEndpointResolver)) + .with_http_connector(Some(FakeConnector)) + .with_identity_resolver(AuthSchemeId::new("fake"), FakeIdentityResolver) .with_retry_classifiers(Some(RetryClassifiers::new())) - .with_retry_strategy(Some(SharedRetryStrategy::new(FakeRetryStrategy))) + .with_retry_strategy(Some(FakeRetryStrategy)) .with_sleep_impl(Some(SharedAsyncSleep::new(FakeSleep))) .with_time_source(Some(SharedTimeSource::new(FakeTimeSource))) } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs index 9fec5519a..34548c0ab 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs @@ -22,6 +22,8 @@ use crate::box_error::BoxError; use crate::client::runtime_components::{ RuntimeComponentsBuilder, EMPTY_RUNTIME_COMPONENTS_BUILDER, }; +use crate::impl_shared_conversions; +use crate::shared::IntoShared; use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer}; use std::borrow::Cow; use std::fmt::Debug; @@ -56,7 +58,7 @@ pub enum Order { /// Runtime plugin trait /// -/// A `RuntimePlugin` is the unit of configuration for augmenting the SDK with new behavior. +/// A `RuntimePlugin` is the unit of configuration for augmenting the client with new behavior. /// /// Runtime plugins can register interceptors, set runtime components, and modify configuration. pub trait RuntimePlugin: Debug + Send + Sync { @@ -113,7 +115,7 @@ pub trait RuntimePlugin: Debug + Send + Sync { pub struct SharedRuntimePlugin(Arc); impl SharedRuntimePlugin { - /// Returns a new [`SharedRuntimePlugin`]. + /// Creates a new [`SharedRuntimePlugin`]. pub fn new(plugin: impl RuntimePlugin + 'static) -> Self { Self(Arc::new(plugin)) } @@ -136,6 +138,8 @@ impl RuntimePlugin for SharedRuntimePlugin { } } +impl_shared_conversions!(convert SharedRuntimePlugin from RuntimePlugin using SharedRuntimePlugin::new); + /// Runtime plugin that simply returns the config and components given at construction time. #[derive(Default, Debug)] pub struct StaticRuntimePlugin { @@ -185,15 +189,16 @@ impl RuntimePlugin for StaticRuntimePlugin { self.runtime_components .as_ref() .map(Cow::Borrowed) - .unwrap_or_else(|| RuntimePlugin::runtime_components(self, _current_components)) + .unwrap_or_else(|| Cow::Borrowed(&EMPTY_RUNTIME_COMPONENTS_BUILDER)) } } macro_rules! insert_plugin { - ($vec:expr, $plugin:ident, $create_rp:expr) => {{ + ($vec:expr, $plugin:expr) => {{ // Insert the plugin in the correct order + let plugin = $plugin; let mut insert_index = 0; - let order = $plugin.order(); + let order = plugin.order(); for (index, other_plugin) in $vec.iter().enumerate() { let other_order = other_plugin.order(); if other_order <= order { @@ -202,7 +207,7 @@ macro_rules! insert_plugin { break; } } - $vec.insert(insert_index, $create_rp); + $vec.insert(insert_index, plugin); }}; } @@ -235,21 +240,13 @@ impl RuntimePlugins { Default::default() } - pub fn with_client_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self { - insert_plugin!( - self.client_plugins, - plugin, - SharedRuntimePlugin::new(plugin) - ); + pub fn with_client_plugin(mut self, plugin: impl IntoShared) -> Self { + insert_plugin!(self.client_plugins, plugin.into_shared()); self } - pub fn with_operation_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self { - insert_plugin!( - self.operation_plugins, - plugin, - SharedRuntimePlugin::new(plugin) - ); + pub fn with_operation_plugin(mut self, plugin: impl IntoShared) -> Self { + insert_plugin!(self.operation_plugins, plugin.into_shared()); self } @@ -274,7 +271,7 @@ mod tests { use crate::client::connectors::{HttpConnector, HttpConnectorFuture, SharedHttpConnector}; use crate::client::orchestrator::HttpRequest; use crate::client::runtime_components::RuntimeComponentsBuilder; - use crate::client::runtime_plugin::Order; + use crate::client::runtime_plugin::{Order, SharedRuntimePlugin}; use aws_smithy_http::body::SdkBody; use aws_smithy_types::config_bag::ConfigBag; use http::HeaderValue; @@ -307,7 +304,7 @@ mod tests { } fn insert_plugin(vec: &mut Vec, plugin: RP) { - insert_plugin!(vec, plugin, plugin); + insert_plugin!(vec, plugin); } let mut vec = Vec::new(); @@ -450,4 +447,25 @@ mod tests { assert_eq!("1", resp.headers().get("rp1").unwrap()); assert_eq!("1", resp.headers().get("rp2").unwrap()); } + + #[test] + fn shared_runtime_plugin_new_specialization() { + #[derive(Debug)] + struct RP; + impl RuntimePlugin for RP {} + + use crate::shared::IntoShared; + let shared1 = SharedRuntimePlugin::new(RP); + let shared2: SharedRuntimePlugin = shared1.clone().into_shared(); + assert_eq!( + "SharedRuntimePlugin(RP)", + format!("{shared1:?}"), + "precondition: RP shows up in the debug format" + ); + assert_eq!( + format!("{shared1:?}"), + format!("{shared2:?}"), + "it should not nest the shared runtime plugins" + ); + } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/ser_de.rs b/rust-runtime/aws-smithy-runtime-api/src/client/ser_de.rs index 5a3a77988..027b33a5f 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/ser_de.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/ser_de.rs @@ -8,6 +8,7 @@ use crate::box_error::BoxError; use crate::client::interceptors::context::{Error, Input, Output}; use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError}; +use crate::impl_shared_conversions; use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; use std::fmt; use std::sync::Arc; @@ -48,6 +49,8 @@ impl Storable for SharedRequestSerializer { type Storer = StoreReplace; } +impl_shared_conversions!(convert SharedRequestSerializer from RequestSerializer using SharedRequestSerializer::new); + /// Deserialization implementation that converts an [`HttpResponse`] into an [`Output`] or [`Error`]. pub trait ResponseDeserializer: Send + Sync + fmt::Debug { /// For streaming requests, deserializes the response headers. @@ -103,3 +106,5 @@ impl ResponseDeserializer for SharedResponseDeserializer { impl Storable for SharedResponseDeserializer { type Storer = StoreReplace; } + +impl_shared_conversions!(convert SharedResponseDeserializer from ResponseDeserializer using SharedResponseDeserializer::new); diff --git a/rust-runtime/aws-smithy-runtime-api/src/lib.rs b/rust-runtime/aws-smithy-runtime-api/src/lib.rs index ebc9225a8..56b65f209 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/lib.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/lib.rs @@ -34,3 +34,5 @@ pub mod client; /// Internal builder macros. Not intended to be used outside of the aws-smithy-runtime crates. #[doc(hidden)] pub mod macros; + +pub mod shared; diff --git a/rust-runtime/aws-smithy-runtime-api/src/shared.rs b/rust-runtime/aws-smithy-runtime-api/src/shared.rs new file mode 100644 index 000000000..c45506c9f --- /dev/null +++ b/rust-runtime/aws-smithy-runtime-api/src/shared.rs @@ -0,0 +1,224 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Conversion traits for converting an unshared type into a shared type. +//! +//! The standard [`From`](std::convert::From)/[`Into`](std::convert::Into) traits can't be +//! used for this purpose due to the blanket implementation of `Into`. +//! +//! This implementation also adds a [`maybe_shared`] method and [`impl_shared_conversions`](crate::impl_shared_conversions) +//! macro to trivially avoid nesting shared types with other shared types. +//! +//! # What is a shared type? +//! +//! A shared type is a new-type around a `Send + Sync` reference counting smart pointer +//! (i.e., an [`Arc`](std::sync::Arc)) around an object-safe trait. Shared types are +//! used to share a trait object among multiple threads/clients/requests. +#![cfg_attr( + feature = "client", + doc = " +For example, [`SharedHttpConnector`](crate::client::connectors::SharedHttpConnector), is +a shared type for the [`HttpConnector`](crate::client::connectors::HttpConnector) trait, +which allows for sharing a single HTTP connector instance (and its connection pool) among multiple clients. +" +)] +//! +//! A shared type implements the [`FromUnshared`] trait, which allows any implementation +//! of the trait it wraps to easily be converted into it. +//! +#![cfg_attr( + feature = "client", + doc = " +To illustrate, let's examine the +[`RuntimePlugin`](crate::client::runtime_plugin::RuntimePlugin)/[`SharedRuntimePlugin`](crate::client::runtime_plugin::SharedRuntimePlugin) +duo. +The following instantiates a concrete implementation of the `RuntimePlugin` trait. +We can do `RuntimePlugin` things on this instance. + +```rust,no_run +use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin; + +let some_plugin = StaticRuntimePlugin::new(); +``` + +We can convert this instance into a shared type in two different ways. + +```rust,no_run +# use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin; +# let some_plugin = StaticRuntimePlugin::new(); +use aws_smithy_runtime_api::client::runtime_plugin::SharedRuntimePlugin; +use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared}; + +// Using the `IntoShared` trait +let shared: SharedRuntimePlugin = some_plugin.into_shared(); + +// Using the `FromUnshared` trait: +# let some_plugin = StaticRuntimePlugin::new(); +let shared = SharedRuntimePlugin::from_unshared(some_plugin); +``` + +The `IntoShared` trait is useful for making functions that take any `RuntimePlugin` impl and convert it to a shared type. +For example, this function will convert the given `plugin` argument into a `SharedRuntimePlugin`. + +```rust,no_run +# use aws_smithy_runtime_api::client::runtime_plugin::{SharedRuntimePlugin, StaticRuntimePlugin}; +# use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared}; +fn take_shared(plugin: impl IntoShared) { + let _plugin: SharedRuntimePlugin = plugin.into_shared(); +} +``` + +This can be called with different types, and even if a `SharedRuntimePlugin` is passed in, it won't nest that +`SharedRuntimePlugin` inside of another `SharedRuntimePlugin`. + +```rust,no_run +# use aws_smithy_runtime_api::client::runtime_plugin::{SharedRuntimePlugin, StaticRuntimePlugin}; +# use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared}; +# fn take_shared(plugin: impl IntoShared) { +# let _plugin: SharedRuntimePlugin = plugin.into_shared(); +# } +// Automatically converts it to `SharedRuntimePlugin(StaticRuntimePlugin)` +take_shared(StaticRuntimePlugin::new()); + +// This is OK. +// It create a `SharedRuntimePlugin(StaticRuntimePlugin))` +// instead of a nested `SharedRuntimePlugin(SharedRuntimePlugin(StaticRuntimePlugin)))` +take_shared(SharedRuntimePlugin::new(StaticRuntimePlugin::new())); +``` +" +)] + +use std::any::{Any, TypeId}; + +/// Like the `From` trait, but for converting to a shared type. +/// +/// See the [module docs](crate::shared) for information about shared types. +pub trait FromUnshared { + /// Creates a shared type from an unshared type. + fn from_unshared(value: Unshared) -> Self; +} + +/// Like the `Into` trait, but for (efficiently) converting into a shared type. +/// +/// If the type is already a shared type, it won't be nested in another shared type. +/// +/// See the [module docs](crate::shared) for information about shared types. +pub trait IntoShared { + /// Creates a shared type from an unshared type. + fn into_shared(self) -> Shared; +} + +impl IntoShared for Unshared +where + Shared: FromUnshared, +{ + fn into_shared(self) -> Shared { + FromUnshared::from_unshared(self) + } +} + +/// Given a `value`, determine if that value is already shared. If it is, return it. Otherwise, wrap it in a shared type. +/// +/// See the [module docs](crate::shared) for information about shared types. +pub fn maybe_shared(value: MaybeShared, ctor: F) -> Shared +where + Shared: 'static, + MaybeShared: IntoShared + 'static, + F: FnOnce(MaybeShared) -> Shared, +{ + // Check if the type is already a shared type + if TypeId::of::() == TypeId::of::() { + // Convince the compiler it is already a shared type and return it + let mut placeholder = Some(value); + let value: Shared = (&mut placeholder as &mut dyn Any) + .downcast_mut::>() + .expect("type checked above") + .take() + .expect("set to Some above"); + value + } else { + (ctor)(value) + } +} + +/// Implements `FromUnshared` for a shared type. +/// +/// See the [`shared` module docs](crate::shared) for information about shared types. +/// +/// # Example +/// ```rust,no_run +/// use aws_smithy_runtime_api::impl_shared_conversions; +/// use std::sync::Arc; +/// +/// trait Thing {} +/// +/// struct Thingamajig; +/// impl Thing for Thingamajig {} +/// +/// struct SharedThing(Arc); +/// impl Thing for SharedThing {} +/// impl SharedThing { +/// fn new(thing: impl Thing + 'static) -> Self { +/// Self(Arc::new(thing)) +/// } +/// } +/// impl_shared_conversions!(convert SharedThing from Thing using SharedThing::new); +/// ``` +#[macro_export] +macro_rules! impl_shared_conversions { + (convert $shared_type:ident from $unshared_trait:ident using $ctor:expr) => { + impl $crate::shared::FromUnshared for $shared_type + where + T: $unshared_trait + 'static, + { + fn from_unshared(value: T) -> Self { + $crate::shared::maybe_shared(value, $ctor) + } + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fmt; + use std::sync::Arc; + + trait Thing: fmt::Debug {} + + #[derive(Debug)] + struct Thingamajig; + impl Thing for Thingamajig {} + + #[derive(Debug)] + struct SharedThing(Arc); + impl Thing for SharedThing {} + impl SharedThing { + fn new(thing: impl Thing + 'static) -> Self { + Self(Arc::new(thing)) + } + } + impl_shared_conversions!(convert SharedThing from Thing using SharedThing::new); + + #[test] + fn test() { + let thing = Thingamajig; + assert_eq!("Thingamajig", format!("{thing:?}"), "precondition"); + + let shared_thing: SharedThing = thing.into_shared(); + assert_eq!( + "SharedThing(Thingamajig)", + format!("{shared_thing:?}"), + "precondition" + ); + + let very_shared_thing: SharedThing = shared_thing.into_shared(); + assert_eq!( + "SharedThing(Thingamajig)", + format!("{very_shared_thing:?}"), + "it should not nest the shared thing in another shared thing" + ); + } +} diff --git a/rust-runtime/aws-smithy-runtime/src/client/interceptors.rs b/rust-runtime/aws-smithy-runtime/src/client/interceptors.rs index c9dd07460..8dff23fcc 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/interceptors.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/interceptors.rs @@ -273,7 +273,7 @@ struct ConditionallyEnabledInterceptor(SharedInterceptor); impl ConditionallyEnabledInterceptor { fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Interceptor> { if self.0.enabled(cfg) { - Some(self.0.as_ref()) + Some(&self.0) } else { None } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs index 5257fa666..ffb3ac921 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs @@ -24,21 +24,22 @@ use tracing::trace; /// An endpoint resolver that uses a static URI. #[derive(Clone, Debug)] pub struct StaticUriEndpointResolver { - endpoint: Uri, + endpoint: String, } impl StaticUriEndpointResolver { /// Create a resolver that resolves to `http://localhost:{port}`. pub fn http_localhost(port: u16) -> Self { Self { - endpoint: Uri::from_str(&format!("http://localhost:{port}")) - .expect("all u16 values are valid ports"), + endpoint: format!("http://localhost:{port}"), } } /// Create a resolver that resolves to the given URI. - pub fn uri(endpoint: Uri) -> Self { - Self { endpoint } + pub fn uri(endpoint: impl Into) -> Self { + Self { + endpoint: endpoint.into(), + } } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs index eb6193287..ac61dc722 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs @@ -30,9 +30,9 @@ use aws_smithy_runtime_api::client::runtime_plugin::{ use aws_smithy_runtime_api::client::ser_de::{ RequestSerializer, ResponseDeserializer, SharedRequestSerializer, SharedResponseDeserializer, }; +use aws_smithy_runtime_api::shared::IntoShared; use aws_smithy_types::config_bag::{ConfigBag, Layer}; use aws_smithy_types::retry::RetryConfig; -use http::Uri; use std::borrow::Cow; use std::fmt; use std::marker::PhantomData; @@ -100,7 +100,7 @@ impl fmt::Debug for FnDeserializer { /// Orchestrates execution of a HTTP request without any modeled input or output. #[doc(hidden)] -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Operation { service_name: Cow<'static, str>, operation_name: Cow<'static, str>, @@ -108,6 +108,18 @@ pub struct Operation { _phantom: PhantomData<(I, O, E)>, } +// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E +impl Clone for Operation { + fn clone(&self) -> Self { + Self { + service_name: self.service_name.clone(), + operation_name: self.operation_name.clone(), + runtime_plugins: self.runtime_plugins.clone(), + _phantom: self._phantom, + } + } +} + impl Operation<(), (), ()> { pub fn builder() -> OperationBuilder { OperationBuilder::new() @@ -177,7 +189,7 @@ impl OperationBuilder { self } - pub fn http_connector(mut self, connector: SharedHttpConnector) -> Self { + pub fn http_connector(mut self, connector: impl IntoShared) -> Self { self.runtime_components.set_http_connector(Some(connector)); self } @@ -186,7 +198,7 @@ impl OperationBuilder { self.config.store_put(EndpointResolverParams::new(())); self.runtime_components .set_endpoint_resolver(Some(SharedEndpointResolver::new( - StaticUriEndpointResolver::uri(Uri::try_from(url).expect("valid URI")), + StaticUriEndpointResolver::uri(url), ))); self } @@ -237,13 +249,13 @@ impl OperationBuilder { self } - pub fn interceptor(mut self, interceptor: SharedInterceptor) -> Self { + pub fn interceptor(mut self, interceptor: impl IntoShared) -> Self { self.runtime_components.push_interceptor(interceptor); self } - pub fn runtime_plugin(mut self, runtime_plugin: SharedRuntimePlugin) -> Self { - self.runtime_plugins.push(runtime_plugin); + pub fn runtime_plugin(mut self, runtime_plugin: impl IntoShared) -> Self { + self.runtime_plugins.push(runtime_plugin.into_shared()); self } @@ -294,26 +306,6 @@ impl OperationBuilder { pub fn build(self) -> Operation { let service_name = self.service_name.expect("service_name required"); let operation_name = self.operation_name.expect("operation_name required"); - assert!( - self.runtime_components.http_connector().is_some(), - "a http_connector is required" - ); - assert!( - self.runtime_components.endpoint_resolver().is_some(), - "a endpoint_resolver is required" - ); - assert!( - self.runtime_components.retry_strategy().is_some(), - "a retry_strategy is required" - ); - assert!( - self.config.load::().is_some(), - "a serializer is required" - ); - assert!( - self.config.load::().is_some(), - "a deserializer is required" - ); let mut runtime_plugins = RuntimePlugins::new().with_client_plugin( StaticRuntimePlugin::new() .with_config(self.config.freeze()) @@ -323,6 +315,39 @@ impl OperationBuilder { runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin); } + #[cfg(debug_assertions)] + { + let mut config = ConfigBag::base(); + let components = runtime_plugins + .apply_client_configuration(&mut config) + .expect("the runtime plugins should succeed"); + + assert!( + components.http_connector().is_some(), + "a http_connector is required" + ); + assert!( + components.endpoint_resolver().is_some(), + "a endpoint_resolver is required" + ); + assert!( + components.retry_strategy().is_some(), + "a retry_strategy is required" + ); + assert!( + config.load::().is_some(), + "a serializer is required" + ); + assert!( + config.load::().is_some(), + "a deserializer is required" + ); + assert!( + config.load::().is_some(), + "endpoint resolver params are required" + ); + } + Operation { service_name, operation_name,