Use the orchestrator client for ECS and IMDS credentials in aws-config (#2997)

This ports the direct uses of the `aws_smithy_client::Client` in
aws_config to the orchestrator.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2023-09-28 18:07:49 -07:00 committed by GitHub
parent 33cd698f7f
commit d800d33e28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1470 additions and 784 deletions

View File

@ -241,3 +241,15 @@ message = "The `futures_core::stream::Stream` trait has been removed from [`Byte
references = ["smithy-rs#2983"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"
[[smithy-rs]]
message = "`StaticUriEndpointResolver`'s `uri` constructor now takes a `String` instead of a `Uri`."
references = ["smithy-rs#2997"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
[[aws-sdk-rust]]
message = "The IMDS Client builder's `build()` method is no longer async."
references = ["smithy-rs#2997"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

View File

@ -9,7 +9,7 @@ license = "Apache-2.0"
repository = "https://github.com/awslabs/smithy-rs"
[features]
client-hyper = ["aws-smithy-client/client-hyper"]
client-hyper = ["aws-smithy-client/client-hyper", "aws-smithy-runtime/connector-hyper"]
rustls = ["aws-smithy-client/rustls", "client-hyper"]
native-tls = []
allow-compilation = [] # our tests use `cargo test --all-features` and native-tls breaks CI
@ -27,7 +27,10 @@ aws-smithy-client = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-client", de
aws-smithy-http = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http" }
aws-smithy-http-tower = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http-tower" }
aws-smithy-json = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-json" }
aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client"] }
aws-smithy-runtime-api = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime-api", features = ["client"] }
aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" }
aws-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-runtime" }
aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" }
hyper = { version = "0.14.26", default-features = false }
time = { version = "0.3.4", features = ["parsing"] }
@ -48,6 +51,7 @@ hex = { version = "0.4.3", optional = true }
zeroize = { version = "1", optional = true }
[dev-dependencies]
aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client", "test-util"] }
futures-util = { version = "0.3.16", default-features = false }
tracing-test = "0.2.1"
tracing-subscriber = { version = "0.3.16", features = ["fmt", "json"] }

View File

@ -12,8 +12,8 @@
async fn main() -> Result<(), Box<dyn std::error::Error>> {
use aws_config::imds::Client;
let imds = Client::builder().build().await?;
let imds = Client::builder().build();
let instance_id = imds.get("/latest/meta-data/instance-id").await?;
println!("current instance id: {}", instance_id);
println!("current instance id: {}", instance_id.as_ref());
Ok(())
}

View File

@ -55,7 +55,7 @@ use aws_credential_types::provider::{self, error::CredentialsError, future, Prov
use aws_smithy_client::erase::boxclone::BoxCloneService;
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_types::error::display::DisplayErrorContext;
use http::uri::{InvalidUri, Scheme};
use http::uri::{InvalidUri, PathAndQuery, Scheme};
use http::{HeaderValue, Uri};
use tower::{Service, ServiceExt};
@ -166,6 +166,15 @@ impl Provider {
Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
Err(err) => return Provider::InvalidConfiguration(err),
};
let path = uri.path().to_string();
let endpoint = {
let mut parts = uri.into_parts();
parts.path_and_query = Some(PathAndQuery::from_static("/"));
Uri::from_parts(parts)
}
.expect("parts will be valid")
.to_string();
let http_provider = HttpCredentialProvider::builder()
.configure(&provider_config)
.connector_settings(
@ -174,7 +183,7 @@ impl Provider {
.read_timeout(DEFAULT_READ_TIMEOUT)
.build(),
)
.build("EcsContainer", uri);
.build("EcsContainer", &endpoint, path);
Provider::Configured(http_provider)
}

View File

@ -8,35 +8,43 @@
//!
//! Future work will stabilize this interface and enable it to be used directly.
use aws_credential_types::provider::{self, error::CredentialsError};
use aws_credential_types::Credentials;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation::{Operation, Request};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use crate::connector::expect_connector;
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use bytes::Bytes;
use aws_credential_types::provider::{self, error::CredentialsError};
use aws_credential_types::Credentials;
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::client::retries::classifier::{
HttpStatusCodeClassifier, SmithyErrorClassifier,
};
use aws_smithy_runtime_api::client::connectors::SharedHttpConnector;
use aws_smithy_runtime_api::client::interceptors::context::{Error, InterceptorContext};
use aws_smithy_runtime_api::client::orchestrator::{
HttpResponse, OrchestratorError, SensitiveOutput,
};
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryClassifiers, RetryReason};
use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin;
use aws_smithy_types::config_bag::Layer;
use aws_smithy_types::retry::{ErrorKind, RetryConfig};
use http::header::{ACCEPT, AUTHORIZATION};
use http::{HeaderValue, Response, Uri};
use http::{HeaderValue, Response};
use std::time::Duration;
use tower::layer::util::Identity;
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
#[derive(Debug)]
struct HttpProviderAuth {
auth: Option<HeaderValue>,
}
#[derive(Debug)]
pub(crate) struct HttpCredentialProvider {
uri: Uri,
client: aws_smithy_client::Client<DynConnector, Identity>,
provider_name: &'static str,
operation: Operation<HttpProviderAuth, Credentials, CredentialsError>,
}
impl HttpCredentialProvider {
@ -45,34 +53,13 @@ impl HttpCredentialProvider {
}
pub(crate) async fn credentials(&self, auth: Option<HeaderValue>) -> provider::Result {
let credentials = self.client.call(self.operation(auth)).await;
let credentials = self.operation.invoke(HttpProviderAuth { auth }).await;
match credentials {
Ok(creds) => Ok(creds),
Err(SdkError::ServiceError(context)) => Err(context.into_err()),
Err(other) => Err(CredentialsError::unhandled(other)),
}
}
fn operation(
&self,
auth: Option<HeaderValue>,
) -> Operation<CredentialsResponseParser, HttpCredentialRetryClassifier> {
let mut http_req = http::Request::builder()
.uri(&self.uri)
.header(ACCEPT, "application/json");
if let Some(auth) = auth {
http_req = http_req.header(AUTHORIZATION, auth);
}
let http_req = http_req.body(SdkBody::empty()).expect("valid request");
Operation::new(
Request::new(http_req),
CredentialsResponseParser {
provider_name: self.provider_name,
},
)
.with_retry_classifier(HttpCredentialRetryClassifier)
}
}
#[derive(Default)]
@ -92,7 +79,12 @@ impl Builder {
self
}
pub(crate) fn build(self, provider_name: &'static str, uri: Uri) -> HttpCredentialProvider {
pub(crate) fn build(
self,
provider_name: &'static str,
endpoint: &str,
path: impl Into<String>,
) -> HttpCredentialProvider {
let provider_config = self.provider_config.unwrap_or_default();
let connector_settings = self.connector_settings.unwrap_or_else(|| {
ConnectorSettings::builder()
@ -104,198 +96,241 @@ impl Builder {
"The HTTP credentials provider",
provider_config.connector(&connector_settings),
);
let mut client_builder = aws_smithy_client::Client::builder()
.connector(connector)
.middleware(Identity::new());
client_builder.set_sleep_impl(provider_config.sleep());
let client = client_builder.build();
HttpCredentialProvider {
uri,
client,
provider_name,
// The following errors are retryable:
// - Socket errors
// - Networking timeouts
// - 5xx errors
// - Non-parseable 200 responses.
let retry_classifiers = RetryClassifiers::new()
.with_classifier(HttpCredentialRetryClassifier)
// Socket errors and network timeouts
.with_classifier(SmithyErrorClassifier::<Error>::new())
// 5xx errors
.with_classifier(HttpStatusCodeClassifier::default());
let mut builder = Operation::builder()
.service_name("HttpCredentialProvider")
.operation_name("LoadCredentials")
.http_connector(SharedHttpConnector::new(DynConnectorAdapter::new(
connector,
)))
.endpoint_url(endpoint)
.no_auth()
.runtime_plugin(StaticRuntimePlugin::new().with_config({
let mut layer = Layer::new("SensitiveOutput");
layer.store_put(SensitiveOutput);
layer.freeze()
}));
if let Some(sleep_impl) = provider_config.sleep() {
builder = builder
.standard_retry(&RetryConfig::standard())
.retry_classifiers(retry_classifiers)
.sleep_impl(sleep_impl);
} else {
builder = builder.no_retry();
}
let path = path.into();
let operation = builder
.serializer(move |input: HttpProviderAuth| {
let mut http_req = http::Request::builder()
.uri(path.clone())
.header(ACCEPT, "application/json");
if let Some(auth) = input.auth {
http_req = http_req.header(AUTHORIZATION, auth);
}
Ok(http_req.body(SdkBody::empty()).expect("valid request"))
})
.deserializer(move |response| parse_response(provider_name, response))
.build();
HttpCredentialProvider { operation }
}
}
#[derive(Clone, Debug)]
struct CredentialsResponseParser {
fn parse_response(
provider_name: &'static str,
}
impl ParseStrictResponse for CredentialsResponseParser {
type Output = provider::Result;
fn parse(&self, response: &Response<Bytes>) -> Self::Output {
if !response.status().is_success() {
return Err(CredentialsError::provider_error(format!(
response: &Response<SdkBody>,
) -> Result<Credentials, OrchestratorError<CredentialsError>> {
if !response.status().is_success() {
return Err(OrchestratorError::operation(
CredentialsError::provider_error(format!(
"Non-success status from HTTP credential provider: {:?}",
response.status()
)));
}
let str_resp =
std::str::from_utf8(response.body().as_ref()).map_err(CredentialsError::unhandled)?;
let json_creds = parse_json_credentials(str_resp).map_err(CredentialsError::unhandled)?;
match json_creds {
JsonCredentials::RefreshableCredentials(RefreshableCredentials {
access_key_id,
secret_access_key,
session_token,
expiration,
}) => Ok(Credentials::new(
access_key_id,
secret_access_key,
Some(session_token.to_string()),
Some(expiration),
self.provider_name,
)),
JsonCredentials::Error { code, message } => Err(CredentialsError::provider_error(
format!("failed to load credentials [{}]: {}", code, message),
)),
}
));
}
fn sensitive(&self) -> bool {
true
let resp_bytes = response.body().bytes().expect("non-streaming deserializer");
let str_resp = std::str::from_utf8(resp_bytes)
.map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
let json_creds = parse_json_credentials(str_resp)
.map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
match json_creds {
JsonCredentials::RefreshableCredentials(RefreshableCredentials {
access_key_id,
secret_access_key,
session_token,
expiration,
}) => Ok(Credentials::new(
access_key_id,
secret_access_key,
Some(session_token.to_string()),
Some(expiration),
provider_name,
)),
JsonCredentials::Error { code, message } => Err(OrchestratorError::operation(
CredentialsError::provider_error(format!(
"failed to load credentials [{}]: {}",
code, message
)),
)),
}
}
#[derive(Clone, Debug)]
struct HttpCredentialRetryClassifier;
impl ClassifyRetry<SdkSuccess<Credentials>, SdkError<CredentialsError>>
for HttpCredentialRetryClassifier
{
fn classify_retry(
&self,
response: Result<&SdkSuccess<Credentials>, &SdkError<CredentialsError>>,
) -> RetryKind {
/* The following errors are retryable:
* - Socket errors
* - Networking timeouts
* - 5xx errors
* - Non-parseable 200 responses.
* */
match response {
Ok(_) => RetryKind::Unnecessary,
// socket errors, networking timeouts
Err(SdkError::DispatchFailure(client_err))
if client_err.is_timeout() || client_err.is_io() =>
{
RetryKind::Error(ErrorKind::TransientError)
impl ClassifyRetry for HttpCredentialRetryClassifier {
fn name(&self) -> &'static str {
"HttpCredentialRetryClassifier"
}
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
let output_or_error = ctx.output_or_error()?;
let error = match output_or_error {
Ok(_) => return None,
Err(err) => err,
};
// Retry non-parseable 200 responses
if let Some((err, status)) = error
.as_operation_error()
.and_then(|err| err.downcast_ref::<CredentialsError>())
.zip(ctx.response().map(HttpResponse::status))
{
if matches!(err, CredentialsError::Unhandled { .. }) && status.is_success() {
return Some(RetryReason::Error(ErrorKind::ServerError));
}
// non-parseable 200s
Err(SdkError::ServiceError(context))
if matches!(context.err(), CredentialsError::Unhandled { .. })
&& context.raw().http().status().is_success() =>
{
RetryKind::Error(ErrorKind::ServerError)
}
// 5xx errors
Err(SdkError::ResponseError(context))
if context.raw().http().status().is_server_error() =>
{
RetryKind::Error(ErrorKind::ServerError)
}
Err(SdkError::ServiceError(context))
if context.raw().http().status().is_server_error() =>
{
RetryKind::Error(ErrorKind::ServerError)
}
Err(_) => RetryKind::UnretryableFailure,
}
None
}
}
#[cfg(test)]
mod test {
use crate::http_credential_provider::{
CredentialsResponseParser, HttpCredentialRetryClassifier,
};
use super::*;
use aws_credential_types::provider::error::CredentialsError;
use aws_credential_types::Credentials;
use aws_smithy_client::test_connection::TestConnection;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation;
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use bytes::Bytes;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use http::{Request, Response, Uri};
use std::time::SystemTime;
fn sdk_resp(
resp: http::Response<&'static str>,
) -> Result<SdkSuccess<Credentials>, SdkError<CredentialsError>> {
let resp = resp.map(|data| Bytes::from_static(data.as_bytes()));
match (CredentialsResponseParser {
provider_name: "test",
})
.parse(&resp)
{
Ok(creds) => Ok(SdkSuccess {
raw: operation::Response::new(resp.map(SdkBody::from)),
parsed: creds,
}),
Err(err) => Err(SdkError::service_error(
err,
operation::Response::new(resp.map(SdkBody::from)),
)),
}
async fn provide_creds(
connector: TestConnection<SdkBody>,
) -> Result<Credentials, CredentialsError> {
let provider_config = ProviderConfig::default().with_http_connector(connector.clone());
let provider = HttpCredentialProvider::builder()
.configure(&provider_config)
.build("test", "http://localhost:1234/", "/some-creds");
provider.credentials(None).await
}
#[test]
fn non_parseable_is_retriable() {
let bad_response = http::Response::builder()
.status(200)
.body("notjson")
.unwrap();
assert_eq!(
HttpCredentialRetryClassifier.classify_retry(sdk_resp(bad_response).as_ref()),
RetryKind::Error(ErrorKind::ServerError)
);
fn successful_req_resp() -> (HttpRequest, HttpResponse) {
(
Request::builder()
.uri(Uri::from_static("http://localhost:1234/some-creds"))
.body(SdkBody::empty())
.unwrap(),
Response::builder()
.status(200)
.body(SdkBody::from(
r#"{
"AccessKeyId" : "MUA...",
"SecretAccessKey" : "/7PC5om....",
"Token" : "AQoDY....=",
"Expiration" : "2016-02-25T06:03:31Z"
}"#,
))
.unwrap(),
)
}
#[test]
fn ok_response_not_retriable() {
let ok_response = http::Response::builder()
.status(200)
.body(
r#" {
"AccessKeyId" : "MUA...",
"SecretAccessKey" : "/7PC5om....",
"Token" : "AQoDY....=",
"Expiration" : "2016-02-25T06:03:31Z"
}"#,
)
.unwrap();
let sdk_result = sdk_resp(ok_response);
#[tokio::test]
async fn successful_response() {
let connector = TestConnection::new(vec![successful_req_resp()]);
let creds = provide_creds(connector.clone()).await.expect("success");
assert_eq!("MUA...", creds.access_key_id());
assert_eq!("/7PC5om....", creds.secret_access_key());
assert_eq!(Some("AQoDY....="), creds.session_token());
assert_eq!(
HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()),
RetryKind::Unnecessary
Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1456380211)),
creds.expiry()
);
assert!(sdk_result.is_ok(), "should be ok: {:?}", sdk_result)
connector.assert_requests_match(&[]);
}
#[test]
fn explicit_error_not_retriable() {
let error_response = http::Response::builder()
.status(400)
.body(r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#)
.unwrap();
let sdk_result = sdk_resp(error_response);
assert_eq!(
HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()),
RetryKind::UnretryableFailure
);
let sdk_error = sdk_result.expect_err("should be error");
assert!(
matches!(
sdk_error,
SdkError::ServiceError(ref context) if matches!(context.err(), CredentialsError::ProviderError { .. })
#[tokio::test]
async fn retry_nonparseable_response() {
let connector = TestConnection::new(vec![
(
Request::builder()
.uri(Uri::from_static("http://localhost:1234/some-creds"))
.body(SdkBody::empty())
.unwrap(),
Response::builder()
.status(200)
.body(SdkBody::from(r#"not json"#))
.unwrap(),
),
"should be provider error: {}",
sdk_error
successful_req_resp(),
]);
let creds = provide_creds(connector.clone()).await.expect("success");
assert_eq!("MUA...", creds.access_key_id());
connector.assert_requests_match(&[]);
}
#[tokio::test]
async fn retry_error_code() {
let connector = TestConnection::new(vec![
(
Request::builder()
.uri(Uri::from_static("http://localhost:1234/some-creds"))
.body(SdkBody::empty())
.unwrap(),
Response::builder()
.status(500)
.body(SdkBody::from(r#"it broke"#))
.unwrap(),
),
successful_req_resp(),
]);
let creds = provide_creds(connector.clone()).await.expect("success");
assert_eq!("MUA...", creds.access_key_id());
connector.assert_requests_match(&[]);
}
#[tokio::test]
async fn explicit_error_not_retriable() {
let connector = TestConnection::new(vec![(
Request::builder()
.uri(Uri::from_static("http://localhost:1234/some-creds"))
.body(SdkBody::empty())
.unwrap(),
Response::builder()
.status(400)
.body(SdkBody::from(
r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#,
))
.unwrap(),
)]);
let err = provide_creds(connector.clone())
.await
.expect_err("it should fail");
assert!(
matches!(err, CredentialsError::ProviderError { .. }),
"should be CredentialsError::ProviderError: {err}",
);
connector.assert_requests_match(&[]);
}
}

View File

@ -9,34 +9,48 @@
use crate::connector::expect_connector;
use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
use crate::imds::client::token::TokenMiddleware;
use crate::imds::client::token::TokenRuntimePlugin;
use crate::provider_config::ProviderConfig;
use crate::PKG_VERSION;
use aws_http::user_agent::{ApiMetadata, AwsUserAgent, UserAgentStage};
use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
use aws_runtime::user_agent::UserAgentInterceptor;
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_client::{erase::DynConnector, SdkSuccess};
use aws_smithy_client::{retry, SdkError};
use aws_smithy_client::SdkError;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_http::operation;
use aws_smithy_http::operation::{Metadata, Operation};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_http_tower::map_request::{
AsyncMapRequestLayer, AsyncMapRequestService, MapRequestLayer, MapRequestService,
use aws_smithy_http::result::ConnectorError;
use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
use aws_smithy_runtime_api::client::connectors::SharedHttpConnector;
use aws_smithy_runtime_api::client::endpoint::{
EndpointResolver, EndpointResolverParams, SharedEndpointResolver,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::interceptors::SharedInterceptor;
use aws_smithy_runtime_api::client::orchestrator::{
Future, HttpResponse, OrchestratorError, SensitiveOutput,
};
use aws_smithy_runtime_api::client::retries::{
ClassifyRetry, RetryClassifiers, RetryReason, SharedRetryStrategy,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
use aws_smithy_types::config_bag::{FrozenLayer, Layer};
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::retry::{ErrorKind, RetryConfig};
use aws_smithy_types::timeout::TimeoutConfig;
use aws_types::os_shim_internal::Env;
use bytes::Bytes;
use http::{Response, Uri};
use http::Uri;
use std::borrow::Cow;
use std::error::Error;
use std::error::Error as _;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::OnceCell;
pub mod error;
mod token;
@ -85,8 +99,7 @@ fn user_agent() -> AwsUserAgent {
/// # async fn docs() {
/// let client = Client::builder()
/// .endpoint(Uri::from_static("http://customidms:456/"))
/// .build()
/// .await;
/// .build();
/// # }
/// ```
///
@ -104,7 +117,7 @@ fn user_agent() -> AwsUserAgent {
/// ```no_run
/// use aws_config::imds::client::{Client, EndpointMode};
/// # async fn docs() {
/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build().await;
/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
/// # }
/// ```
///
@ -123,49 +136,7 @@ fn user_agent() -> AwsUserAgent {
///
#[derive(Clone, Debug)]
pub struct Client {
inner: Arc<ClientInner>,
}
#[derive(Debug)]
struct ClientInner {
endpoint: Uri,
smithy_client: aws_smithy_client::Client<DynConnector, ImdsMiddleware>,
}
/// Client where build is sync, but usage is async
///
/// Building an imds::Client is actually an async operation, however, for credentials and region
/// providers, we want build to always be a synchronous operation. This allows building to be deferred
/// and cached until request time.
#[derive(Debug)]
pub(super) struct LazyClient {
client: OnceCell<Result<Client, BuildError>>,
builder: Builder,
}
impl LazyClient {
pub(super) fn from_ready_client(client: Client) -> Self {
Self {
client: OnceCell::from(Ok(client)),
// the builder will never be used in this case
builder: Builder::default(),
}
}
pub(super) async fn client(&self) -> Result<&Client, &BuildError> {
let builder = &self.builder;
self.client
// the clone will only happen once when we actually construct it for the first time,
// after that, we will use the cache.
.get_or_init(|| async {
let client = builder.clone().build().await;
if let Err(err) = &client {
tracing::warn!(err = %DisplayErrorContext(err), "failed to create IMDS client")
}
client
})
.await
.as_ref()
}
operation: Operation<String, SensitiveString, InnerImdsError>,
}
impl Client {
@ -187,18 +158,16 @@ impl Client {
/// ```no_run
/// use aws_config::imds::client::Client;
/// # async fn docs() {
/// let client = Client::builder().build().await.expect("valid client");
/// let client = Client::builder().build();
/// let ami_id = client
/// .get("/latest/meta-data/ami-id")
/// .await
/// .expect("failure communicating with IMDS");
/// # }
/// ```
pub async fn get(&self, path: &str) -> Result<String, ImdsError> {
let operation = self.make_operation(path)?;
self.inner
.smithy_client
.call(operation)
pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
self.operation
.invoke(path.into())
.await
.map_err(|err| match err {
SdkError::ConstructionFailure(_) if err.source().is_some() => {
@ -213,76 +182,112 @@ impl Client {
InnerImdsError::InvalidUtf8 => {
ImdsError::unexpected("IMDS returned invalid UTF-8")
}
InnerImdsError::BadStatus => {
ImdsError::error_response(context.into_raw().into_parts().0)
}
InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
},
SdkError::TimeoutError(_)
| SdkError::DispatchFailure(_)
| SdkError::ResponseError(_) => ImdsError::io_error(err),
// If the error source is an ImdsError, then we need to directly return that source.
// That way, the IMDS token provider's errors can become the top-level ImdsError.
// There is a unit test that checks the correct error is being extracted.
err @ SdkError::DispatchFailure(_) => match err.into_source() {
Ok(source) => match source.downcast::<ConnectorError>() {
Ok(source) => match source.into_source().downcast::<ImdsError>() {
Ok(source) => *source,
Err(err) => ImdsError::unexpected(err),
},
Err(err) => ImdsError::unexpected(err),
},
Err(err) => ImdsError::unexpected(err),
},
SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
_ => ImdsError::unexpected(err),
})
}
}
/// Creates a aws_smithy_http Operation to for `path`
/// - Convert the path to a URI
/// - Set the base endpoint on the URI
/// - Add a user agent
fn make_operation(
&self,
path: &str,
) -> Result<Operation<ImdsGetResponseHandler, ImdsResponseRetryClassifier>, ImdsError> {
let mut base_uri: Uri = path.parse().map_err(|_| {
ImdsError::unexpected("IMDS path was not a valid URI. Hint: does it begin with `/`?")
})?;
apply_endpoint(&mut base_uri, &self.inner.endpoint, None).map_err(ImdsError::unexpected)?;
let request = http::Request::builder()
.uri(base_uri)
.body(SdkBody::empty())
.expect("valid request");
let mut request = operation::Request::new(request);
request.properties_mut().insert(user_agent());
Ok(Operation::new(request, ImdsGetResponseHandler)
.with_metadata(Metadata::new("get", "imds"))
.with_retry_classifier(ImdsResponseRetryClassifier))
/// New-type around `String` that doesn't emit the string value in the `Debug` impl.
#[derive(Clone)]
pub struct SensitiveString(String);
impl fmt::Debug for SensitiveString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SensitiveString")
.field(&"** redacted **")
.finish()
}
}
/// IMDS Middleware
///
/// The IMDS middleware includes a token-loader & a UserAgent stage
#[derive(Clone, Debug)]
struct ImdsMiddleware {
token_loader: TokenMiddleware,
}
impl<S> tower::Layer<S> for ImdsMiddleware {
type Service = AsyncMapRequestService<MapRequestService<S, UserAgentStage>, TokenMiddleware>;
fn layer(&self, inner: S) -> Self::Service {
AsyncMapRequestLayer::for_mapper(self.token_loader.clone())
.layer(MapRequestLayer::for_mapper(UserAgentStage::new()).layer(inner))
impl AsRef<str> for SensitiveString {
fn as_ref(&self) -> &str {
&self.0
}
}
#[derive(Copy, Clone)]
struct ImdsGetResponseHandler;
impl From<String> for SensitiveString {
fn from(value: String) -> Self {
Self(value)
}
}
impl ParseStrictResponse for ImdsGetResponseHandler {
type Output = Result<String, InnerImdsError>;
impl From<SensitiveString> for String {
fn from(value: SensitiveString) -> Self {
value.0
}
}
fn parse(&self, response: &Response<Bytes>) -> Self::Output {
if response.status().is_success() {
std::str::from_utf8(response.body().as_ref())
.map(|data| data.to_string())
.map_err(|_| InnerImdsError::InvalidUtf8)
} else {
Err(InnerImdsError::BadStatus)
/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
/// sensitive, configures user agent headers, and sets up retries and timeouts.
#[derive(Debug)]
struct ImdsCommonRuntimePlugin {
config: FrozenLayer,
components: RuntimeComponentsBuilder,
}
impl ImdsCommonRuntimePlugin {
fn new(
connector: DynConnector,
endpoint_resolver: ImdsEndpointResolver,
retry_config: &RetryConfig,
timeout_config: TimeoutConfig,
time_source: SharedTimeSource,
sleep_impl: Option<SharedAsyncSleep>,
) -> Self {
let mut layer = Layer::new("ImdsCommonRuntimePlugin");
layer.store_put(AuthSchemeOptionResolverParams::new(()));
layer.store_put(EndpointResolverParams::new(()));
layer.store_put(SensitiveOutput);
layer.store_put(timeout_config);
layer.store_put(user_agent());
Self {
config: layer.freeze(),
components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
.with_http_connector(Some(SharedHttpConnector::new(DynConnectorAdapter::new(
connector,
))))
.with_endpoint_resolver(Some(SharedEndpointResolver::new(endpoint_resolver)))
.with_interceptor(SharedInterceptor::new(UserAgentInterceptor::new()))
.with_retry_classifiers(Some(
RetryClassifiers::new().with_classifier(ImdsResponseRetryClassifier),
))
.with_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new(
retry_config,
))))
.with_time_source(Some(time_source))
.with_sleep_impl(sleep_impl),
}
}
}
fn sensitive(&self) -> bool {
true
impl RuntimePlugin for ImdsCommonRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
Some(self.config.clone())
}
fn runtime_components(
&self,
_current_components: &RuntimeComponentsBuilder,
) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.components)
}
}
@ -415,15 +420,8 @@ impl Builder {
self
}*/
pub(super) fn build_lazy(self) -> LazyClient {
LazyClient {
client: OnceCell::new(),
builder: self,
}
}
/// Build an IMDSv2 Client
pub async fn build(self) -> Result<Client, BuildError> {
pub fn build(self) -> Client {
let config = self.config.unwrap_or_default();
let timeout_config = TimeoutConfig::builder()
.connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
@ -437,34 +435,46 @@ impl Builder {
let endpoint_source = self
.endpoint
.unwrap_or_else(|| EndpointSource::Env(config.clone()));
let endpoint = endpoint_source.endpoint(self.mode_override).await?;
let retry_config = retry::Config::default()
.with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
let token_loader = token::TokenMiddleware::new(
connector.clone(),
config.time_source(),
endpoint.clone(),
self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
retry_config.clone(),
timeout_config.clone(),
config.sleep(),
);
let middleware = ImdsMiddleware { token_loader };
let mut smithy_builder = aws_smithy_client::Client::builder()
.connector(connector.clone())
.middleware(middleware)
.retry_config(retry_config)
.operation_timeout_config(timeout_config.into());
smithy_builder.set_sleep_impl(config.sleep());
let smithy_client = smithy_builder.build();
let client = Client {
inner: Arc::new(ClientInner {
endpoint,
smithy_client,
}),
let endpoint_resolver = ImdsEndpointResolver {
endpoint_source: Arc::new(endpoint_source),
mode_override: self.mode_override,
};
Ok(client)
let retry_config = RetryConfig::standard()
.with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
connector,
endpoint_resolver,
&retry_config,
timeout_config,
config.time_source(),
config.sleep(),
));
let operation = Operation::builder()
.service_name("imds")
.operation_name("get")
.runtime_plugin(common_plugin.clone())
.runtime_plugin(TokenRuntimePlugin::new(
common_plugin,
config.time_source(),
self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
))
.serializer(|path| {
Ok(http::Request::builder()
.uri(path)
.body(SdkBody::empty())
.expect("valid request"))
})
.deserializer(|response| {
if response.status().is_success() {
std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
.map(|data| SensitiveString::from(data.to_string()))
.map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
} else {
Err(OrchestratorError::operation(InnerImdsError::BadStatus))
}
})
.build();
Client { operation }
}
}
@ -531,19 +541,22 @@ impl EndpointSource {
}
}
#[derive(Clone)]
struct ImdsResponseRetryClassifier;
#[derive(Clone, Debug)]
struct ImdsEndpointResolver {
endpoint_source: Arc<EndpointSource>,
mode_override: Option<EndpointMode>,
}
impl ImdsResponseRetryClassifier {
fn classify(response: &operation::Response) -> RetryKind {
let status = response.http().status();
match status {
_ if status.is_server_error() => RetryKind::Error(ErrorKind::ServerError),
// 401 indicates that the token has expired, this is retryable
_ if status.as_u16() == 401 => RetryKind::Error(ErrorKind::ServerError),
// This catch-all includes successful responses that fail to parse. These should not be retried.
_ => RetryKind::UnretryableFailure,
}
impl EndpointResolver for ImdsEndpointResolver {
fn resolve_endpoint(&self, _: &EndpointResolverParams) -> Future<Endpoint> {
let this = self.clone();
Future::new(Box::pin(async move {
this.endpoint_source
.endpoint(this.mode_override)
.await
.map(|uri| Endpoint::builder().url(uri.to_string()).build())
.map_err(|err| err.into())
}))
}
}
@ -556,13 +569,35 @@ impl ImdsResponseRetryClassifier {
/// - 403 (IMDS disabled): **Not Retryable**
/// - 404 (Not found): **Not Retryable**
/// - >=500 (server error): **Retryable**
impl<T, E> ClassifyRetry<SdkSuccess<T>, SdkError<E>> for ImdsResponseRetryClassifier {
fn classify_retry(&self, response: Result<&SdkSuccess<T>, &SdkError<E>>) -> RetryKind {
match response {
Ok(_) => RetryKind::Unnecessary,
Err(SdkError::ResponseError(context)) => Self::classify(context.raw()),
Err(SdkError::ServiceError(context)) => Self::classify(context.raw()),
_ => RetryKind::UnretryableFailure,
#[derive(Clone, Debug)]
struct ImdsResponseRetryClassifier;
impl ImdsResponseRetryClassifier {
fn classify(response: &HttpResponse) -> Option<RetryReason> {
let status = response.status();
match status {
_ if status.is_server_error() => Some(RetryReason::Error(ErrorKind::ServerError)),
// 401 indicates that the token has expired, this is retryable
_ if status.as_u16() == 401 => Some(RetryReason::Error(ErrorKind::ServerError)),
// This catch-all includes successful responses that fail to parse. These should not be retried.
_ => None,
}
}
}
impl ClassifyRetry for ImdsResponseRetryClassifier {
fn name(&self) -> &'static str {
"ImdsResponseRetryClassifier"
}
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
if let Some(response) = ctx.response() {
Self::classify(response)
} else {
// Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
// credentials provider chain to fail to provide credentials.
// Also don't retry non-responses.
None
}
}
}
@ -575,10 +610,15 @@ pub(crate) mod test {
use aws_smithy_async::test_util::instant_time_and_sleep;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::test_connection::{capture_request, TestConnection};
use aws_smithy_client::{SdkError, SdkSuccess};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation;
use aws_smithy_types::retry::RetryKind;
use aws_smithy_http::result::ConnectorError;
use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
use aws_smithy_runtime_api::client::interceptors::context::{
Input, InterceptorContext, Output,
};
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
use aws_smithy_runtime_api::client::retries::ClassifyRetry;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::os_shim_internal::{Env, Fs};
use http::header::USER_AGENT;
use http::Uri;
@ -637,7 +677,7 @@ pub(crate) mod test {
http::Response::builder().status(200).body(body).unwrap()
}
pub(crate) async fn make_client<T>(conn: &TestConnection<T>) -> super::Client
pub(crate) fn make_client<T>(conn: &TestConnection<T>) -> super::Client
where
SdkBody: From<T>,
T: Send + 'static,
@ -650,8 +690,6 @@ pub(crate) mod test {
.with_http_connector(DynConnector::new(conn.clone())),
)
.build()
.await
.expect("valid client")
}
#[tokio::test]
@ -670,13 +708,13 @@ pub(crate) mod test {
imds_response("output2"),
),
]);
let client = make_client(&connection).await;
let client = make_client(&connection);
// load once
let metadata = client.get("/latest/metadata").await.expect("failed");
assert_eq!(metadata, "test-imds-output");
assert_eq!("test-imds-output", metadata.as_ref());
// load again: the cached token should be used
let metadata = client.get("/latest/metadata2").await.expect("failed");
assert_eq!(metadata, "output2");
assert_eq!("output2", metadata.as_ref());
connection.assert_requests_match(&[]);
}
@ -710,17 +748,15 @@ pub(crate) mod test {
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build()
.await
.expect("valid client");
.build();
let resp1 = client.get("/latest/metadata").await.expect("success");
// now the cached credential has expired
time_source.advance(Duration::from_secs(600));
let resp2 = client.get("/latest/metadata").await.expect("success");
connection.assert_requests_match(&[]);
assert_eq!(resp1, "test-imds-output1");
assert_eq!(resp2, "test-imds-output2");
assert_eq!("test-imds-output1", resp1.as_ref());
assert_eq!("test-imds-output2", resp2.as_ref());
}
/// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
@ -761,9 +797,7 @@ pub(crate) mod test {
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build()
.await
.expect("valid client");
.build();
let resp1 = client.get("/latest/metadata").await.expect("success");
// now the cached credential has expired
@ -772,9 +806,9 @@ pub(crate) mod test {
time_source.advance(Duration::from_secs(150));
let resp3 = client.get("/latest/metadata").await.expect("success");
connection.assert_requests_match(&[]);
assert_eq!(resp1, "test-imds-output1");
assert_eq!(resp2, "test-imds-output2");
assert_eq!(resp3, "test-imds-output3");
assert_eq!("test-imds-output1", resp1.as_ref());
assert_eq!("test-imds-output2", resp2.as_ref());
assert_eq!("test-imds-output3", resp3.as_ref());
}
/// 500 error during the GET should be retried
@ -795,8 +829,15 @@ pub(crate) mod test {
imds_response("ok"),
),
]);
let client = make_client(&connection).await;
assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
let client = make_client(&connection);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
connection.assert_requests_match(&[]);
// all requests should have a user agent header
@ -823,8 +864,15 @@ pub(crate) mod test {
imds_response("ok"),
),
]);
let client = make_client(&connection).await;
assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
let client = make_client(&connection);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
connection.assert_requests_match(&[]);
}
@ -850,8 +898,15 @@ pub(crate) mod test {
imds_response("ok"),
),
]);
let client = make_client(&connection).await;
assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
let client = make_client(&connection);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
connection.assert_requests_match(&[]);
}
@ -863,7 +918,7 @@ pub(crate) mod test {
token_request("http://169.254.169.254", 21600),
http::Response::builder().status(403).body("").unwrap(),
)]);
let client = make_client(&connection).await;
let client = make_client(&connection);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "forbidden");
connection.assert_requests_match(&[]);
@ -872,30 +927,18 @@ pub(crate) mod test {
/// Successful responses should classify as `RetryKind::Unnecessary`
#[test]
fn successful_response_properly_classified() {
use aws_smithy_http::retry::ClassifyRetry;
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.set_output_or_error(Ok(Output::doesnt_matter()));
ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
let classifier = ImdsResponseRetryClassifier;
fn response_200() -> operation::Response {
operation::Response::new(imds_response("").map(|_| SdkBody::empty()))
}
let success = SdkSuccess {
raw: response_200(),
parsed: (),
};
assert_eq!(
RetryKind::Unnecessary,
classifier.classify_retry(Ok::<_, &SdkError<()>>(&success))
);
assert_eq!(None, classifier.classify_retry(&ctx));
// Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
let failure = SdkError::<()>::response_error(
io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse"),
response_200(),
);
assert_eq!(
RetryKind::UnretryableFailure,
classifier.classify_retry(Err::<&SdkSuccess<()>, _>(&failure))
);
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
))));
assert_eq!(None, classifier.classify_retry(&ctx));
}
// since tokens are sent as headers, the tokens need to be valid header values
@ -905,7 +948,7 @@ pub(crate) mod test {
token_request("http://169.254.169.254", 21600),
token_response(21600, "replaced").map(|_| vec![1, 0]),
)]);
let client = make_client(&connection).await;
let client = make_client(&connection);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "invalid token");
connection.assert_requests_match(&[]);
@ -926,7 +969,7 @@ pub(crate) mod test {
.unwrap(),
),
]);
let client = make_client(&connection).await;
let client = make_client(&connection);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "invalid UTF-8");
connection.assert_requests_match(&[]);
@ -943,14 +986,20 @@ pub(crate) mod test {
let client = Client::builder()
// 240.* can never be resolved
.endpoint(Uri::from_static("http://240.0.0.0"))
.build()
.await
.expect("valid client");
.build();
let now = SystemTime::now();
let resp = client
.get("/latest/metadata")
.await
.expect_err("240.0.0.0 will never resolve");
match resp {
err @ ImdsError::FailedToLoadToken(_)
if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
other => panic!(
"wrong error, expected construction failure with TimedOutError inside: {}",
DisplayErrorContext(&other)
),
}
let time_elapsed = now.elapsed().unwrap();
assert!(
time_elapsed > Duration::from_secs(1),
@ -962,14 +1011,6 @@ pub(crate) mod test {
"time_elapsed should be less than 2s but was {:?}",
time_elapsed
);
match resp {
err @ ImdsError::FailedToLoadToken(_)
if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
other => panic!(
"wrong error, expected construction failure with TimedOutError inside: {}",
other
),
}
}
#[derive(Debug, Deserialize)]
@ -983,8 +1024,10 @@ pub(crate) mod test {
}
#[tokio::test]
async fn config_tests() -> Result<(), Box<dyn Error>> {
let test_cases = std::fs::read_to_string("test-data/imds-config/imds-tests.json")?;
async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
let _logs = capture_test_logs();
let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
#[derive(Deserialize)]
struct TestCases {
tests: Vec<ImdsConfigTest>,
@ -1014,24 +1057,22 @@ pub(crate) mod test {
imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
}
let imds_client = imds_client.build().await;
let (uri, imds_client) = match (&test_case.result, imds_client) {
(Ok(uri), Ok(client)) => (uri, client),
(Err(test), Ok(_client)) => panic!(
"test should fail: {} but a valid client was made. {}",
test, test_case.docs
),
(Err(substr), Err(err)) => {
assert_full_error_contains!(err, substr);
return;
let imds_client = imds_client.build();
match &test_case.result {
Ok(uri) => {
// this request will fail, we just want to capture the endpoint configuration
let _ = imds_client.get("/hello").await;
assert_eq!(&watcher.expect_request().uri().to_string(), uri);
}
Err(expected) => {
let err = imds_client.get("/hello").await.expect_err("it should fail");
let message = format!("{}", DisplayErrorContext(&err));
assert!(
message.contains(expected),
"{}\nexpected error: {expected}\nactual error: {message}",
test_case.docs
);
}
(Ok(_uri), Err(e)) => panic!(
"a valid client should be made but: {}. {}",
e, test_case.docs
),
};
// this request will fail, we just want to capture the endpoint configuration
let _ = imds_client.get("/hello").await;
assert_eq!(&watcher.expect_request().uri().to_string(), uri);
}
}

View File

@ -8,13 +8,14 @@
use aws_smithy_client::SdkError;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::error::InvalidEndpointError;
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use std::error::Error;
use std::fmt;
/// Error context for [`ImdsError::FailedToLoadToken`]
#[derive(Debug)]
pub struct FailedToLoadToken {
source: SdkError<TokenError>,
source: SdkError<TokenError, HttpResponse>,
}
impl FailedToLoadToken {
@ -23,7 +24,7 @@ impl FailedToLoadToken {
matches!(self.source, SdkError::DispatchFailure(_))
}
pub(crate) fn into_source(self) -> SdkError<TokenError> {
pub(crate) fn into_source(self) -> SdkError<TokenError, HttpResponse> {
self.source
}
}
@ -76,7 +77,7 @@ pub enum ImdsError {
}
impl ImdsError {
pub(super) fn failed_to_load_token(source: SdkError<TokenError>) -> Self {
pub(super) fn failed_to_load_token(source: SdkError<TokenError, HttpResponse>) -> Self {
Self::FailedToLoadToken(FailedToLoadToken { source })
}

View File

@ -15,26 +15,29 @@
//! - Attach the token to the request in the `x-aws-ec2-metadata-token` header
use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
use crate::imds::client::ImdsResponseRetryClassifier;
use aws_credential_types::cache::ExpiringCache;
use aws_http::user_agent::UserAgentStage;
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::retry;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_http::middleware::AsyncMapRequest;
use aws_smithy_http::operation;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::operation::{Metadata, Request};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http_tower::map_request::MapRequestLayer;
use aws_smithy_types::timeout::TimeoutConfig;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
AuthScheme, AuthSchemeEndpointConfig, AuthSchemeId, Signer,
};
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityResolver, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::orchestrator::{
Future, HttpRequest, HttpResponse, OrchestratorError,
};
use aws_smithy_runtime_api::client::runtime_components::{
GetIdentityResolver, RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
use aws_smithy_types::config_bag::ConfigBag;
use http::{HeaderValue, Uri};
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
@ -47,6 +50,7 @@ const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(120);
const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl-seconds";
const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);
/// IMDS Token
#[derive(Clone)]
@ -54,151 +58,214 @@ struct Token {
value: HeaderValue,
expiry: SystemTime,
}
/// Token Middleware
///
/// Token middleware will load/cache a token when required and handle caching/expiry.
///
/// It will attach the token to the incoming request on the `x-aws-ec2-metadata-token` header.
#[derive(Clone)]
pub(super) struct TokenMiddleware {
client: Arc<aws_smithy_client::Client<DynConnector, MapRequestLayer<UserAgentStage>>>,
token_parser: GetTokenResponseHandler,
token: ExpiringCache<Token, ImdsError>,
time_source: SharedTimeSource,
endpoint: Uri,
token_ttl: Duration,
}
impl Debug for TokenMiddleware {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ImdsTokenMiddleware")
impl fmt::Debug for Token {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Token")
.field("value", &"** redacted **")
.field("expiry", &self.expiry)
.finish()
}
}
impl TokenMiddleware {
/// Token Runtime Plugin
///
/// This runtime plugin wires up the necessary components to load/cache a token
/// when required and handle caching/expiry. This token will get attached to the
/// request to IMDS on the `x-aws-ec2-metadata-token` header.
#[derive(Debug)]
pub(super) struct TokenRuntimePlugin {
components: RuntimeComponentsBuilder,
}
impl TokenRuntimePlugin {
pub(super) fn new(
connector: DynConnector,
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
endpoint: Uri,
token_ttl: Duration,
retry_config: retry::Config,
timeout_config: TimeoutConfig,
sleep_impl: Option<SharedAsyncSleep>,
) -> Self {
let mut inner_builder = aws_smithy_client::Client::builder()
.connector(connector)
.middleware(MapRequestLayer::<UserAgentStage>::default())
.retry_config(retry_config)
.operation_timeout_config(timeout_config.into());
inner_builder.set_sleep_impl(sleep_impl);
let inner_client = inner_builder.build();
let client = Arc::new(inner_client);
Self {
client,
token_parser: GetTokenResponseHandler {
time: time_source.clone(),
},
token: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
time_source,
endpoint,
token_ttl,
components: RuntimeComponentsBuilder::new("TokenRuntimePlugin")
.with_auth_scheme(TokenAuthScheme::new())
.with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![
IMDS_TOKEN_AUTH_SCHEME,
])))
.with_identity_resolver(
IMDS_TOKEN_AUTH_SCHEME,
TokenResolver::new(common_plugin, time_source, token_ttl),
),
}
}
async fn add_token(&self, request: Request) -> Result<Request, ImdsError> {
let preloaded_token = self
.token
.yield_or_clear_if_expired(self.time_source.now())
.await;
let token = match preloaded_token {
Some(token) => Ok(token),
None => {
self.token
.get_or_load(|| async move { self.get_token().await })
.await
}
}?;
request.augment(|mut request, _| {
request
.headers_mut()
.insert(X_AWS_EC2_METADATA_TOKEN, token.value);
Ok(request)
})
}
impl RuntimePlugin for TokenRuntimePlugin {
fn runtime_components(
&self,
_current_components: &RuntimeComponentsBuilder,
) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.components)
}
}
#[derive(Debug)]
struct TokenResolverInner {
cache: ExpiringCache<Token, ImdsError>,
refresh: Operation<(), Token, TokenError>,
time_source: SharedTimeSource,
}
#[derive(Clone, Debug)]
struct TokenResolver {
inner: Arc<TokenResolverInner>,
}
impl TokenResolver {
fn new(
common_plugin: SharedRuntimePlugin,
time_source: SharedTimeSource,
token_ttl: Duration,
) -> Self {
Self {
inner: Arc::new(TokenResolverInner {
cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
refresh: Operation::builder()
.service_name("imds")
.operation_name("get-token")
.runtime_plugin(common_plugin)
.no_auth()
.serializer(move |_| {
Ok(http::Request::builder()
.method("PUT")
.uri(Uri::from_static("/latest/api/token"))
.header(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, token_ttl.as_secs())
.body(SdkBody::empty())
.expect("valid HTTP request"))
})
.deserializer({
let time_source = time_source.clone();
move |response| {
let now = time_source.now();
parse_token_response(response, now)
.map_err(OrchestratorError::operation)
}
})
.build(),
time_source,
}),
}
}
async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> {
let mut uri = Uri::from_static("/latest/api/token");
apply_endpoint(&mut uri, &self.endpoint, None).map_err(ImdsError::unexpected)?;
let request = http::Request::builder()
.header(
X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS,
self.token_ttl.as_secs(),
)
.uri(uri)
.method("PUT")
.body(SdkBody::empty())
.expect("valid HTTP request");
let mut request = operation::Request::new(request);
request.properties_mut().insert(super::user_agent());
let operation = Operation::new(request, self.token_parser.clone())
.with_retry_classifier(ImdsResponseRetryClassifier)
.with_metadata(Metadata::new("get-token", "imds"));
let response = self
.client
.call(operation)
self.inner
.refresh
.invoke(())
.await
.map_err(ImdsError::failed_to_load_token)?;
let expiry = response.expiry;
Ok((response, expiry))
.map(|token| {
let expiry = token.expiry;
(token, expiry)
})
.map_err(ImdsError::failed_to_load_token)
}
}
impl AsyncMapRequest for TokenMiddleware {
type Error = ImdsError;
type Future = Pin<Box<dyn Future<Output = Result<Request, Self::Error>> + Send + 'static>>;
fn name(&self) -> &'static str {
"attach_imds_token"
fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Token, TokenError> {
match response.status().as_u16() {
400 => return Err(TokenErrorKind::InvalidParameters.into()),
403 => return Err(TokenErrorKind::Forbidden.into()),
_ => {}
}
fn apply(&self, request: Request) -> Self::Future {
let this = self.clone();
Box::pin(async move { this.add_token(request).await })
}
}
#[derive(Clone)]
struct GetTokenResponseHandler {
time: SharedTimeSource,
}
impl ParseStrictResponse for GetTokenResponseHandler {
type Output = Result<Token, TokenError>;
fn parse(&self, response: &http::Response<bytes::Bytes>) -> Self::Output {
match response.status().as_u16() {
400 => return Err(TokenErrorKind::InvalidParameters.into()),
403 => return Err(TokenErrorKind::Forbidden.into()),
_ => {}
}
let value = HeaderValue::from_maybe_shared(response.body().clone())
let mut value =
HeaderValue::from_bytes(response.body().bytes().expect("non-streaming response"))
.map_err(|_| TokenErrorKind::InvalidToken)?;
let ttl: u64 = response
.headers()
.get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS)
.ok_or(TokenErrorKind::NoTtl)?
.to_str()
.map_err(|_| TokenErrorKind::InvalidTtl)?
.parse()
.map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
Ok(Token {
value,
expiry: self.time.now() + Duration::from_secs(ttl),
})
}
value.set_sensitive(true);
fn sensitive(&self) -> bool {
true
let ttl: u64 = response
.headers()
.get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS)
.ok_or(TokenErrorKind::NoTtl)?
.to_str()
.map_err(|_| TokenErrorKind::InvalidTtl)?
.parse()
.map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
Ok(Token {
value,
expiry: now + Duration::from_secs(ttl),
})
}
impl IdentityResolver for TokenResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> Future<Identity> {
let this = self.clone();
Future::new(Box::pin(async move {
let preloaded_token = this
.inner
.cache
.yield_or_clear_if_expired(this.inner.time_source.now())
.await;
let token = match preloaded_token {
Some(token) => Ok(token),
None => {
this.inner
.cache
.get_or_load(|| {
let this = this.clone();
async move { this.get_token().await }
})
.await
}
}?;
let expiry = token.expiry;
Ok(Identity::new(token, Some(expiry)))
}))
}
}
#[derive(Debug)]
struct TokenAuthScheme {
signer: TokenSigner,
}
impl TokenAuthScheme {
fn new() -> Self {
Self {
signer: TokenSigner,
}
}
}
impl AuthScheme for TokenAuthScheme {
fn scheme_id(&self) -> AuthSchemeId {
IMDS_TOKEN_AUTH_SCHEME
}
fn identity_resolver(
&self,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(IMDS_TOKEN_AUTH_SCHEME)
}
fn signer(&self) -> &dyn Signer {
&self.signer
}
}
#[derive(Debug)]
struct TokenSigner;
impl Signer for TokenSigner {
fn sign_http_request(
&self,
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let token = identity.data::<Token>().expect("correct type");
request
.headers_mut()
.append(X_AWS_EC2_METADATA_TOKEN, token.value.clone());
Ok(())
}
}

View File

@ -9,8 +9,7 @@
//! This credential provider will NOT fallback to IMDSv1. Ensure that IMDSv2 is enabled on your instances.
use super::client::error::ImdsError;
use crate::imds;
use crate::imds::client::LazyClient;
use crate::imds::{self, Client};
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
@ -50,7 +49,7 @@ impl StdError for ImdsCommunicationError {
/// _Note: This credentials provider will NOT fallback to the IMDSv1 flow._
#[derive(Debug)]
pub struct ImdsCredentialsProvider {
client: LazyClient,
client: Client,
env: Env,
profile: Option<String>,
time_source: SharedTimeSource,
@ -110,12 +109,7 @@ impl Builder {
let env = provider_config.env();
let client = self
.imds_override
.map(LazyClient::from_ready_client)
.unwrap_or_else(|| {
imds::Client::builder()
.configure(&provider_config)
.build_lazy()
});
.unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build());
ImdsCredentialsProvider {
client,
env,
@ -156,23 +150,14 @@ impl ImdsCredentialsProvider {
}
}
/// Load an inner IMDS client from the OnceCell
async fn client(&self) -> Result<&imds::Client, CredentialsError> {
self.client.client().await.map_err(|build_error| {
// need to format the build error since we don't own it and it can't be cloned
CredentialsError::invalid_configuration(format!("{}", build_error))
})
}
/// Retrieve the instance profile from IMDS
async fn get_profile_uncached(&self) -> Result<String, CredentialsError> {
match self
.client()
.await?
.client
.get("/latest/meta-data/iam/security-credentials/")
.await
{
Ok(profile) => Ok(profile),
Ok(profile) => Ok(profile.as_ref().into()),
Err(ImdsError::ErrorResponse(context))
if context.response().status().as_u16() == 404 =>
{
@ -223,9 +208,11 @@ impl ImdsCredentialsProvider {
async fn retrieve_credentials(&self) -> provider::Result {
if self.imds_disabled() {
tracing::debug!("IMDS disabled because $AWS_EC2_METADATA_DISABLED was set to `true`");
tracing::debug!(
"IMDS disabled because AWS_EC2_METADATA_DISABLED env var was set to `true`"
);
return Err(CredentialsError::not_loaded(
"IMDS disabled by $AWS_ECS_METADATA_DISABLED",
"IMDS disabled by AWS_ECS_METADATA_DISABLED env var",
));
}
tracing::debug!("loading credentials from IMDS");
@ -235,15 +222,14 @@ impl ImdsCredentialsProvider {
};
tracing::debug!(profile = %profile, "loaded profile");
let credentials = self
.client()
.await?
.get(&format!(
.client
.get(format!(
"/latest/meta-data/iam/security-credentials/{}",
profile
))
.await
.map_err(CredentialsError::provider_error)?;
match parse_json_credentials(&credentials) {
match parse_json_credentials(credentials.as_ref()) {
Ok(JsonCredentials::RefreshableCredentials(RefreshableCredentials {
access_key_id,
secret_access_key,
@ -296,19 +282,16 @@ impl ImdsCredentialsProvider {
#[cfg(test)]
mod test {
use std::time::{Duration, UNIX_EPOCH};
use super::*;
use crate::imds::client::test::{
imds_request, imds_response, make_client, token_request, token_response,
};
use crate::imds::credentials::{
ImdsCredentialsProvider, WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY,
};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_smithy_async::test_util::instant_time_and_sleep;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::test_connection::TestConnection;
use std::time::{Duration, UNIX_EPOCH};
use tracing_test::traced_test;
const TOKEN_A: &str = "token_a";
@ -338,7 +321,7 @@ mod test {
),
]);
let client = ImdsCredentialsProvider::builder()
.imds_client(make_client(&connection).await)
.imds_client(make_client(&connection))
.build();
let creds1 = client.provide_credentials().await.expect("valid creds");
let creds2 = client.provide_credentials().await.expect("valid creds");
@ -376,9 +359,7 @@ mod test {
.with_time_source(time_source);
let client = crate::imds::Client::builder()
.configure(&provider_config)
.build()
.await
.expect("valid client");
.build();
let provider = ImdsCredentialsProvider::builder()
.configure(&provider_config)
.imds_client(client)
@ -422,9 +403,7 @@ mod test {
.with_time_source(time_source);
let client = crate::imds::Client::builder()
.configure(&provider_config)
.build()
.await
.expect("valid client");
.build();
let provider = ImdsCredentialsProvider::builder()
.configure(&provider_config)
.imds_client(client)
@ -443,9 +422,7 @@ mod test {
let client = crate::imds::Client::builder()
// 240.* can never be resolved
.endpoint(http::Uri::from_static("http://240.0.0.0"))
.build()
.await
.expect("valid client");
.build();
let expected = aws_credential_types::Credentials::for_tests();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
@ -463,18 +440,16 @@ mod test {
let client = crate::imds::Client::builder()
// 240.* can never be resolved
.endpoint(http::Uri::from_static("http://240.0.0.0"))
.build()
.await
.expect("valid client");
.build();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
// no fallback credentials provided
.build();
let actual = provider.provide_credentials().await;
assert!(matches!(
actual,
Err(aws_credential_types::provider::error::CredentialsError::CredentialsNotLoaded(_))
));
assert!(
matches!(actual, Err(CredentialsError::CredentialsNotLoaded(_))),
"\nexpected: Err(CredentialsError::CredentialsNotLoaded(_))\nactual: {actual:?}"
);
}
#[tokio::test]
@ -484,9 +459,7 @@ mod test {
let client = crate::imds::Client::builder()
// 240.* can never be resolved
.endpoint(http::Uri::from_static("http://240.0.0.0"))
.build()
.await
.expect("valid client");
.build();
let expected = aws_credential_types::Credentials::for_tests();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
@ -536,7 +509,7 @@ mod test {
),
]);
let provider = ImdsCredentialsProvider::builder()
.imds_client(make_client(&connection).await)
.imds_client(make_client(&connection))
.build();
let creds1 = provider.provide_credentials().await.expect("valid creds");
assert_eq!(creds1.access_key_id(), "ASIARTEST");

View File

@ -8,8 +8,7 @@
//! Load region from IMDS from `/latest/meta-data/placement/region`
//! This provider has a 5 second timeout.
use crate::imds;
use crate::imds::client::LazyClient;
use crate::imds::{self, Client};
use crate::meta::region::{future, ProvideRegion};
use crate::provider_config::ProviderConfig;
use aws_smithy_types::error::display::DisplayErrorContext;
@ -22,7 +21,7 @@ use tracing::Instrument;
/// This provider is included in the default region chain, so it does not need to be used manually.
#[derive(Debug)]
pub struct ImdsRegionProvider {
client: LazyClient,
client: Client,
env: Env,
}
@ -49,11 +48,10 @@ impl ImdsRegionProvider {
tracing::debug!("not using IMDS to load region, IMDS is disabled");
return None;
}
let client = self.client.client().await.ok()?;
match client.get(REGION_PATH).await {
match self.client.get(REGION_PATH).await {
Ok(region) => {
tracing::debug!(region = %region, "loaded region from IMDS");
Some(Region::new(region))
tracing::debug!(region = %region.as_ref(), "loaded region from IMDS");
Some(Region::new(String::from(region)))
}
Err(err) => {
tracing::warn!(err = %DisplayErrorContext(&err), "failed to load region from IMDS");
@ -99,12 +97,7 @@ impl Builder {
let provider_config = self.provider_config.unwrap_or_default();
let client = self
.imds_client_override
.map(LazyClient::from_ready_client)
.unwrap_or_else(|| {
imds::Client::builder()
.configure(&provider_config)
.build_lazy()
});
.unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build());
ImdsRegionProvider {
client,
env: provider_config.env(),

View File

@ -82,25 +82,25 @@ impl Interceptor for UserAgentInterceptor {
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let api_metadata = cfg
.load::<ApiMetadata>()
.ok_or(UserAgentInterceptorError::MissingApiMetadata)?;
// Allow for overriding the user agent by an earlier interceptor (so, for example,
// tests can use `AwsUserAgent::for_tests()`) by attempting to grab one out of the
// config bag before creating one.
let ua: Cow<'_, AwsUserAgent> = cfg
.load::<AwsUserAgent>()
.map(Cow::Borrowed)
.map(Result::<_, UserAgentInterceptorError>::Ok)
.unwrap_or_else(|| {
let api_metadata = cfg
.load::<ApiMetadata>()
.ok_or(UserAgentInterceptorError::MissingApiMetadata)?;
let mut ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata.clone());
let maybe_app_name = cfg.load::<AppName>();
if let Some(app_name) = maybe_app_name {
ua.set_app_name(app_name.clone());
}
Cow::Owned(ua)
});
Ok(Cow::Owned(ua))
})?;
let headers = context.request_mut().headers_mut();
let (user_agent, x_amz_user_agent) = header_values(&ua)?;
@ -250,4 +250,30 @@ mod tests {
"`{error}` should contain message `This is a bug`"
);
}
#[test]
fn test_api_metadata_missing_with_ua_override() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = context();
let mut layer = Layer::new("test");
layer.store_put(AwsUserAgent::for_tests());
let mut config = ConfigBag::of_layers(vec![layer]);
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
interceptor
.modify_before_signing(&mut ctx, &rc, &mut config)
.expect("it should succeed");
let header = expect_header(&context, "user-agent");
assert_eq!(AwsUserAgent::for_tests().ua_header(), header);
assert!(!header.contains("unused"));
assert_eq!(
AwsUserAgent::for_tests().aws_ua_header(),
expect_header(&context, "x-amz-user-agent")
);
}
}

View File

@ -12,7 +12,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.SensitiveIndex
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

View File

@ -10,7 +10,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
@ -72,17 +71,9 @@ sealed class OperationSection(name: String) : Section(name) {
val operationShape: OperationShape,
) : OperationSection("AdditionalInterceptors") {
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
val smithyRuntimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig)
writer.rustTemplate(
"""
.with_interceptor(
#{SharedInterceptor}::new(
#{interceptor}
) as _
)
""",
".with_interceptor(#{interceptor})",
"interceptor" to interceptor,
"SharedInterceptor" to smithyRuntimeApi.resolve("client::interceptors::SharedInterceptor"),
)
}
}

View File

@ -43,10 +43,9 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
writer.rustTemplate(
"""
runtime_components.push_interceptor(#{SharedInterceptor}::new(#{interceptor}) as _);
runtime_components.push_interceptor(#{interceptor});
""",
"interceptor" to interceptor,
"SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::SharedInterceptor"),
)
}

View File

@ -134,7 +134,7 @@ pub struct ValidateRequest {
impl ValidateRequest {
pub fn assert_matches(&self, ignore_headers: &[HeaderName]) {
let (actual, expected) = (&self.actual, &self.expected);
assert_eq!(actual.uri(), expected.uri());
assert_eq!(expected.uri(), actual.uri());
for (name, value) in expected.headers() {
if !ignore_headers.contains(name) {
let actual_header = actual

View File

@ -651,6 +651,11 @@ impl ConnectorError {
}
}
/// Grants ownership of this error's source.
pub fn into_source(self) -> BoxError {
self.source
}
/// Returns metadata about the connection
///
/// If a connection was established and provided by the internal connector, a connection will

View File

@ -9,6 +9,7 @@ use crate::box_error::BoxError;
use crate::client::identity::{Identity, SharedIdentityResolver};
use crate::client::orchestrator::HttpRequest;
use crate::client::runtime_components::{GetIdentityResolver, RuntimeComponents};
use crate::impl_shared_conversions;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::type_erasure::TypeErasedBox;
use aws_smithy_types::Document;
@ -120,6 +121,12 @@ impl AuthSchemeOptionResolver for SharedAuthSchemeOptionResolver {
}
}
impl_shared_conversions!(
convert SharedAuthSchemeOptionResolver
from AuthSchemeOptionResolver
using SharedAuthSchemeOptionResolver::new
);
/// An auth scheme.
///
/// Auth schemes have unique identifiers (the `scheme_id`),
@ -177,6 +184,8 @@ impl AuthScheme for SharedAuthScheme {
}
}
impl_shared_conversions!(convert SharedAuthScheme from AuthScheme using SharedAuthScheme::new);
/// Signing implementation for an auth scheme.
pub trait Signer: Send + Sync + fmt::Debug {
/// Sign the given request with the given identity, components, and config.

View File

@ -31,6 +31,7 @@
//! The Smithy clients have no knowledge of such concepts.
use crate::client::orchestrator::{HttpRequest, HttpResponse};
use crate::impl_shared_conversions;
use aws_smithy_async::future::now_or_later::NowOrLater;
use aws_smithy_http::result::ConnectorError;
use pin_project_lite::pin_project;
@ -117,3 +118,5 @@ impl HttpConnector for SharedHttpConnector {
(*self.0).call(request)
}
}
impl_shared_conversions!(convert SharedHttpConnector from HttpConnector using SharedHttpConnector::new);

View File

@ -6,6 +6,7 @@
//! APIs needed to configure endpoint resolution for clients.
use crate::client::orchestrator::Future;
use crate::impl_shared_conversions;
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::type_erasure::TypeErasedBox;
@ -60,3 +61,5 @@ impl EndpointResolver for SharedEndpointResolver {
self.0.resolve_endpoint(params)
}
}
impl_shared_conversions!(convert SharedEndpointResolver from EndpointResolver using SharedEndpointResolver::new);

View File

@ -5,6 +5,7 @@
use crate::client::auth::AuthSchemeId;
use crate::client::orchestrator::Future;
use crate::impl_shared_conversions;
use aws_smithy_types::config_bag::ConfigBag;
use std::any::Any;
use std::fmt;
@ -48,6 +49,8 @@ impl IdentityResolver for SharedIdentityResolver {
}
}
impl_shared_conversions!(convert SharedIdentityResolver from IdentityResolver using SharedIdentityResolver::new);
/// An identity resolver paired with an auth scheme ID that it resolves for.
#[derive(Clone, Debug)]
pub(crate) struct ConfiguredIdentityResolver {

View File

@ -19,12 +19,12 @@ use crate::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
pub mod context;
pub mod error;
use crate::impl_shared_conversions;
pub use error::InterceptorError;
macro_rules! interceptor_trait_fn {
@ -618,18 +618,201 @@ impl SharedInterceptor {
}
}
impl AsRef<dyn Interceptor> for SharedInterceptor {
fn as_ref(&self) -> &(dyn Interceptor + 'static) {
self.interceptor.as_ref()
impl Interceptor for SharedInterceptor {
fn name(&self) -> &'static str {
self.interceptor.name()
}
fn modify_before_attempt_completion(
&self,
context: &mut FinalizerInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_attempt_completion(context, runtime_components, cfg)
}
fn modify_before_completion(
&self,
context: &mut FinalizerInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_completion(context, runtime_components, cfg)
}
fn modify_before_deserialization(
&self,
context: &mut BeforeDeserializationInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_deserialization(context, runtime_components, cfg)
}
fn modify_before_retry_loop(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_retry_loop(context, runtime_components, cfg)
}
fn modify_before_serialization(
&self,
context: &mut BeforeSerializationInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_serialization(context, runtime_components, cfg)
}
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_signing(context, runtime_components, cfg)
}
fn modify_before_transmit(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.modify_before_transmit(context, runtime_components, cfg)
}
fn read_after_attempt(
&self,
context: &FinalizerInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_attempt(context, runtime_components, cfg)
}
fn read_after_deserialization(
&self,
context: &AfterDeserializationInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_deserialization(context, runtime_components, cfg)
}
fn read_after_execution(
&self,
context: &FinalizerInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_execution(context, runtime_components, cfg)
}
fn read_after_serialization(
&self,
context: &BeforeTransmitInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_serialization(context, runtime_components, cfg)
}
fn read_after_signing(
&self,
context: &BeforeTransmitInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_signing(context, runtime_components, cfg)
}
fn read_after_transmit(
&self,
context: &BeforeDeserializationInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_after_transmit(context, runtime_components, cfg)
}
fn read_before_attempt(
&self,
context: &BeforeTransmitInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_before_attempt(context, runtime_components, cfg)
}
fn read_before_deserialization(
&self,
context: &BeforeDeserializationInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_before_deserialization(context, runtime_components, cfg)
}
fn read_before_execution(
&self,
context: &BeforeSerializationInterceptorContextRef<'_>,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor.read_before_execution(context, cfg)
}
fn read_before_serialization(
&self,
context: &BeforeSerializationInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_before_serialization(context, runtime_components, cfg)
}
fn read_before_signing(
&self,
context: &BeforeTransmitInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_before_signing(context, runtime_components, cfg)
}
fn read_before_transmit(
&self,
context: &BeforeTransmitInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.interceptor
.read_before_transmit(context, runtime_components, cfg)
}
}
impl Deref for SharedInterceptor {
type Target = Arc<dyn Interceptor>;
fn deref(&self) -> &Self::Target {
&self.interceptor
}
}
impl_shared_conversions!(convert SharedInterceptor from Interceptor using SharedInterceptor::new);
/// Generalized interceptor disabling interface
///

View File

@ -25,7 +25,8 @@ use aws_smithy_http::body::SdkBody;
use aws_smithy_http::result::{ConnectorError, SdkError};
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use bytes::Bytes;
use std::fmt::Debug;
use std::error::Error as StdError;
use std::fmt;
use std::future::Future as StdFuture;
use std::pin::Pin;
@ -244,6 +245,35 @@ impl<E> OrchestratorError<E> {
}
}
impl<E> StdError for OrchestratorError<E>
where
E: StdError + 'static,
{
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(match &self.kind {
ErrorKind::Connector { source } => source as _,
ErrorKind::Operation { err } => err as _,
ErrorKind::Interceptor { source } => source as _,
ErrorKind::Response { source } => source.as_ref(),
ErrorKind::Timeout { source } => source.as_ref(),
ErrorKind::Other { source } => source.as_ref(),
})
}
}
impl<E> fmt::Display for OrchestratorError<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self.kind {
ErrorKind::Connector { .. } => "connector error",
ErrorKind::Operation { .. } => "operation error",
ErrorKind::Interceptor { .. } => "interceptor error",
ErrorKind::Response { .. } => "response error",
ErrorKind::Timeout { .. } => "timeout",
ErrorKind::Other { .. } => "an unknown error occurred",
})
}
}
fn convert_dispatch_error<O>(
err: BoxError,
response: Option<HttpResponse>,
@ -262,7 +292,7 @@ fn convert_dispatch_error<O>(
impl<E> From<InterceptorError> for OrchestratorError<E>
where
E: Debug + std::error::Error + 'static,
E: fmt::Debug + std::error::Error + 'static,
{
fn from(err: InterceptorError) -> Self {
Self::interceptor(err)

View File

@ -98,6 +98,8 @@ impl RetryStrategy for SharedRetryStrategy {
}
}
impl_shared_conversions!(convert SharedRetryStrategy from RetryStrategy using SharedRetryStrategy::new);
/// Classification result from [`ClassifyRetry`].
#[non_exhaustive]
#[derive(Clone, Eq, PartialEq, Debug)]
@ -231,6 +233,7 @@ mod test_util {
use crate::box_error::BoxError;
use crate::client::runtime_components::RuntimeComponents;
use crate::impl_shared_conversions;
use std::sync::Arc;
#[cfg(feature = "test-util")]
pub use test_util::AlwaysRetry;

View File

@ -19,6 +19,7 @@ use crate::client::endpoint::SharedEndpointResolver;
use crate::client::identity::{ConfiguredIdentityResolver, SharedIdentityResolver};
use crate::client::interceptors::SharedInterceptor;
use crate::client::retries::{RetryClassifiers, SharedRetryStrategy};
use crate::shared::IntoShared;
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use std::fmt;
@ -273,17 +274,17 @@ impl RuntimeComponentsBuilder {
/// Sets the auth scheme option resolver.
pub fn set_auth_scheme_option_resolver(
&mut self,
auth_scheme_option_resolver: Option<SharedAuthSchemeOptionResolver>,
auth_scheme_option_resolver: Option<impl IntoShared<SharedAuthSchemeOptionResolver>>,
) -> &mut Self {
self.auth_scheme_option_resolver =
auth_scheme_option_resolver.map(|r| Tracked::new(self.builder_name, r));
auth_scheme_option_resolver.map(|r| Tracked::new(self.builder_name, r.into_shared()));
self
}
/// Sets the auth scheme option resolver.
pub fn with_auth_scheme_option_resolver(
mut self,
auth_scheme_option_resolver: Option<SharedAuthSchemeOptionResolver>,
auth_scheme_option_resolver: Option<impl IntoShared<SharedAuthSchemeOptionResolver>>,
) -> Self {
self.set_auth_scheme_option_resolver(auth_scheme_option_resolver);
self
@ -295,13 +296,19 @@ impl RuntimeComponentsBuilder {
}
/// Sets the HTTP connector.
pub fn set_http_connector(&mut self, connector: Option<SharedHttpConnector>) -> &mut Self {
self.http_connector = connector.map(|c| Tracked::new(self.builder_name, c));
pub fn set_http_connector(
&mut self,
connector: Option<impl IntoShared<SharedHttpConnector>>,
) -> &mut Self {
self.http_connector = connector.map(|c| Tracked::new(self.builder_name, c.into_shared()));
self
}
/// Sets the HTTP connector.
pub fn with_http_connector(mut self, connector: Option<SharedHttpConnector>) -> Self {
pub fn with_http_connector(
mut self,
connector: Option<impl IntoShared<SharedHttpConnector>>,
) -> Self {
self.set_http_connector(connector);
self
}
@ -314,16 +321,17 @@ impl RuntimeComponentsBuilder {
/// Sets the endpoint resolver.
pub fn set_endpoint_resolver(
&mut self,
endpoint_resolver: Option<SharedEndpointResolver>,
endpoint_resolver: Option<impl IntoShared<SharedEndpointResolver>>,
) -> &mut Self {
self.endpoint_resolver = endpoint_resolver.map(|s| Tracked::new(self.builder_name, s));
self.endpoint_resolver =
endpoint_resolver.map(|s| Tracked::new(self.builder_name, s.into_shared()));
self
}
/// Sets the endpoint resolver.
pub fn with_endpoint_resolver(
mut self,
endpoint_resolver: Option<SharedEndpointResolver>,
endpoint_resolver: Option<impl IntoShared<SharedEndpointResolver>>,
) -> Self {
self.set_endpoint_resolver(endpoint_resolver);
self
@ -335,14 +343,17 @@ impl RuntimeComponentsBuilder {
}
/// Adds an auth scheme.
pub fn push_auth_scheme(&mut self, auth_scheme: SharedAuthScheme) -> &mut Self {
pub fn push_auth_scheme(
&mut self,
auth_scheme: impl IntoShared<SharedAuthScheme>,
) -> &mut Self {
self.auth_schemes
.push(Tracked::new(self.builder_name, auth_scheme));
.push(Tracked::new(self.builder_name, auth_scheme.into_shared()));
self
}
/// Adds an auth scheme.
pub fn with_auth_scheme(mut self, auth_scheme: SharedAuthScheme) -> Self {
pub fn with_auth_scheme(mut self, auth_scheme: impl IntoShared<SharedAuthScheme>) -> Self {
self.push_auth_scheme(auth_scheme);
self
}
@ -351,11 +362,11 @@ impl RuntimeComponentsBuilder {
pub fn push_identity_resolver(
&mut self,
scheme_id: AuthSchemeId,
identity_resolver: SharedIdentityResolver,
identity_resolver: impl IntoShared<SharedIdentityResolver>,
) -> &mut Self {
self.identity_resolvers.push(Tracked::new(
self.builder_name,
ConfiguredIdentityResolver::new(scheme_id, identity_resolver),
ConfiguredIdentityResolver::new(scheme_id, identity_resolver.into_shared()),
));
self
}
@ -364,7 +375,7 @@ impl RuntimeComponentsBuilder {
pub fn with_identity_resolver(
mut self,
scheme_id: AuthSchemeId,
identity_resolver: SharedIdentityResolver,
identity_resolver: impl IntoShared<SharedIdentityResolver>,
) -> Self {
self.push_identity_resolver(scheme_id, identity_resolver);
self
@ -386,14 +397,17 @@ impl RuntimeComponentsBuilder {
}
/// Adds an interceptor.
pub fn push_interceptor(&mut self, interceptor: SharedInterceptor) -> &mut Self {
pub fn push_interceptor(
&mut self,
interceptor: impl IntoShared<SharedInterceptor>,
) -> &mut Self {
self.interceptors
.push(Tracked::new(self.builder_name, interceptor));
.push(Tracked::new(self.builder_name, interceptor.into_shared()));
self
}
/// Adds an interceptor.
pub fn with_interceptor(mut self, interceptor: SharedInterceptor) -> Self {
pub fn with_interceptor(mut self, interceptor: impl IntoShared<SharedInterceptor>) -> Self {
self.push_interceptor(interceptor);
self
}
@ -444,14 +458,22 @@ impl RuntimeComponentsBuilder {
}
/// Sets the retry strategy.
pub fn set_retry_strategy(&mut self, retry_strategy: Option<SharedRetryStrategy>) -> &mut Self {
self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s));
pub fn set_retry_strategy(
&mut self,
retry_strategy: Option<impl IntoShared<SharedRetryStrategy>>,
) -> &mut Self {
self.retry_strategy =
retry_strategy.map(|s| Tracked::new(self.builder_name, s.into_shared()));
self
}
/// Sets the retry strategy.
pub fn with_retry_strategy(mut self, retry_strategy: Option<SharedRetryStrategy>) -> Self {
self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s));
pub fn with_retry_strategy(
mut self,
retry_strategy: Option<impl IntoShared<SharedRetryStrategy>>,
) -> Self {
self.retry_strategy =
retry_strategy.map(|s| Tracked::new(self.builder_name, s.into_shared()));
self
}
@ -617,13 +639,13 @@ impl RuntimeComponentsBuilder {
}
Self::new("aws_smithy_runtime_api::client::runtime_components::RuntimeComponentBuilder::for_tests")
.with_auth_scheme(SharedAuthScheme::new(FakeAuthScheme))
.with_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(FakeAuthSchemeOptionResolver)))
.with_endpoint_resolver(Some(SharedEndpointResolver::new(FakeEndpointResolver)))
.with_http_connector(Some(SharedHttpConnector::new(FakeConnector)))
.with_identity_resolver(AuthSchemeId::new("fake"), SharedIdentityResolver::new(FakeIdentityResolver))
.with_auth_scheme(FakeAuthScheme)
.with_auth_scheme_option_resolver(Some(FakeAuthSchemeOptionResolver))
.with_endpoint_resolver(Some(FakeEndpointResolver))
.with_http_connector(Some(FakeConnector))
.with_identity_resolver(AuthSchemeId::new("fake"), FakeIdentityResolver)
.with_retry_classifiers(Some(RetryClassifiers::new()))
.with_retry_strategy(Some(SharedRetryStrategy::new(FakeRetryStrategy)))
.with_retry_strategy(Some(FakeRetryStrategy))
.with_sleep_impl(Some(SharedAsyncSleep::new(FakeSleep)))
.with_time_source(Some(SharedTimeSource::new(FakeTimeSource)))
}

View File

@ -22,6 +22,8 @@ use crate::box_error::BoxError;
use crate::client::runtime_components::{
RuntimeComponentsBuilder, EMPTY_RUNTIME_COMPONENTS_BUILDER,
};
use crate::impl_shared_conversions;
use crate::shared::IntoShared;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer};
use std::borrow::Cow;
use std::fmt::Debug;
@ -56,7 +58,7 @@ pub enum Order {
/// Runtime plugin trait
///
/// A `RuntimePlugin` is the unit of configuration for augmenting the SDK with new behavior.
/// A `RuntimePlugin` is the unit of configuration for augmenting the client with new behavior.
///
/// Runtime plugins can register interceptors, set runtime components, and modify configuration.
pub trait RuntimePlugin: Debug + Send + Sync {
@ -113,7 +115,7 @@ pub trait RuntimePlugin: Debug + Send + Sync {
pub struct SharedRuntimePlugin(Arc<dyn RuntimePlugin>);
impl SharedRuntimePlugin {
/// Returns a new [`SharedRuntimePlugin`].
/// Creates a new [`SharedRuntimePlugin`].
pub fn new(plugin: impl RuntimePlugin + 'static) -> Self {
Self(Arc::new(plugin))
}
@ -136,6 +138,8 @@ impl RuntimePlugin for SharedRuntimePlugin {
}
}
impl_shared_conversions!(convert SharedRuntimePlugin from RuntimePlugin using SharedRuntimePlugin::new);
/// Runtime plugin that simply returns the config and components given at construction time.
#[derive(Default, Debug)]
pub struct StaticRuntimePlugin {
@ -185,15 +189,16 @@ impl RuntimePlugin for StaticRuntimePlugin {
self.runtime_components
.as_ref()
.map(Cow::Borrowed)
.unwrap_or_else(|| RuntimePlugin::runtime_components(self, _current_components))
.unwrap_or_else(|| Cow::Borrowed(&EMPTY_RUNTIME_COMPONENTS_BUILDER))
}
}
macro_rules! insert_plugin {
($vec:expr, $plugin:ident, $create_rp:expr) => {{
($vec:expr, $plugin:expr) => {{
// Insert the plugin in the correct order
let plugin = $plugin;
let mut insert_index = 0;
let order = $plugin.order();
let order = plugin.order();
for (index, other_plugin) in $vec.iter().enumerate() {
let other_order = other_plugin.order();
if other_order <= order {
@ -202,7 +207,7 @@ macro_rules! insert_plugin {
break;
}
}
$vec.insert(insert_index, $create_rp);
$vec.insert(insert_index, plugin);
}};
}
@ -235,21 +240,13 @@ impl RuntimePlugins {
Default::default()
}
pub fn with_client_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self {
insert_plugin!(
self.client_plugins,
plugin,
SharedRuntimePlugin::new(plugin)
);
pub fn with_client_plugin(mut self, plugin: impl IntoShared<SharedRuntimePlugin>) -> Self {
insert_plugin!(self.client_plugins, plugin.into_shared());
self
}
pub fn with_operation_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self {
insert_plugin!(
self.operation_plugins,
plugin,
SharedRuntimePlugin::new(plugin)
);
pub fn with_operation_plugin(mut self, plugin: impl IntoShared<SharedRuntimePlugin>) -> Self {
insert_plugin!(self.operation_plugins, plugin.into_shared());
self
}
@ -274,7 +271,7 @@ mod tests {
use crate::client::connectors::{HttpConnector, HttpConnectorFuture, SharedHttpConnector};
use crate::client::orchestrator::HttpRequest;
use crate::client::runtime_components::RuntimeComponentsBuilder;
use crate::client::runtime_plugin::Order;
use crate::client::runtime_plugin::{Order, SharedRuntimePlugin};
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::config_bag::ConfigBag;
use http::HeaderValue;
@ -307,7 +304,7 @@ mod tests {
}
fn insert_plugin(vec: &mut Vec<RP>, plugin: RP) {
insert_plugin!(vec, plugin, plugin);
insert_plugin!(vec, plugin);
}
let mut vec = Vec::new();
@ -450,4 +447,25 @@ mod tests {
assert_eq!("1", resp.headers().get("rp1").unwrap());
assert_eq!("1", resp.headers().get("rp2").unwrap());
}
#[test]
fn shared_runtime_plugin_new_specialization() {
#[derive(Debug)]
struct RP;
impl RuntimePlugin for RP {}
use crate::shared::IntoShared;
let shared1 = SharedRuntimePlugin::new(RP);
let shared2: SharedRuntimePlugin = shared1.clone().into_shared();
assert_eq!(
"SharedRuntimePlugin(RP)",
format!("{shared1:?}"),
"precondition: RP shows up in the debug format"
);
assert_eq!(
format!("{shared1:?}"),
format!("{shared2:?}"),
"it should not nest the shared runtime plugins"
);
}
}

View File

@ -8,6 +8,7 @@
use crate::box_error::BoxError;
use crate::client::interceptors::context::{Error, Input, Output};
use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
use crate::impl_shared_conversions;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use std::fmt;
use std::sync::Arc;
@ -48,6 +49,8 @@ impl Storable for SharedRequestSerializer {
type Storer = StoreReplace<Self>;
}
impl_shared_conversions!(convert SharedRequestSerializer from RequestSerializer using SharedRequestSerializer::new);
/// Deserialization implementation that converts an [`HttpResponse`] into an [`Output`] or [`Error`].
pub trait ResponseDeserializer: Send + Sync + fmt::Debug {
/// For streaming requests, deserializes the response headers.
@ -103,3 +106,5 @@ impl ResponseDeserializer for SharedResponseDeserializer {
impl Storable for SharedResponseDeserializer {
type Storer = StoreReplace<Self>;
}
impl_shared_conversions!(convert SharedResponseDeserializer from ResponseDeserializer using SharedResponseDeserializer::new);

View File

@ -34,3 +34,5 @@ pub mod client;
/// Internal builder macros. Not intended to be used outside of the aws-smithy-runtime crates.
#[doc(hidden)]
pub mod macros;
pub mod shared;

View File

@ -0,0 +1,224 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
//! Conversion traits for converting an unshared type into a shared type.
//!
//! The standard [`From`](std::convert::From)/[`Into`](std::convert::Into) traits can't be
//! used for this purpose due to the blanket implementation of `Into`.
//!
//! This implementation also adds a [`maybe_shared`] method and [`impl_shared_conversions`](crate::impl_shared_conversions)
//! macro to trivially avoid nesting shared types with other shared types.
//!
//! # What is a shared type?
//!
//! A shared type is a new-type around a `Send + Sync` reference counting smart pointer
//! (i.e., an [`Arc`](std::sync::Arc)) around an object-safe trait. Shared types are
//! used to share a trait object among multiple threads/clients/requests.
#![cfg_attr(
feature = "client",
doc = "
For example, [`SharedHttpConnector`](crate::client::connectors::SharedHttpConnector), is
a shared type for the [`HttpConnector`](crate::client::connectors::HttpConnector) trait,
which allows for sharing a single HTTP connector instance (and its connection pool) among multiple clients.
"
)]
//!
//! A shared type implements the [`FromUnshared`] trait, which allows any implementation
//! of the trait it wraps to easily be converted into it.
//!
#![cfg_attr(
feature = "client",
doc = "
To illustrate, let's examine the
[`RuntimePlugin`](crate::client::runtime_plugin::RuntimePlugin)/[`SharedRuntimePlugin`](crate::client::runtime_plugin::SharedRuntimePlugin)
duo.
The following instantiates a concrete implementation of the `RuntimePlugin` trait.
We can do `RuntimePlugin` things on this instance.
```rust,no_run
use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin;
let some_plugin = StaticRuntimePlugin::new();
```
We can convert this instance into a shared type in two different ways.
```rust,no_run
# use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin;
# let some_plugin = StaticRuntimePlugin::new();
use aws_smithy_runtime_api::client::runtime_plugin::SharedRuntimePlugin;
use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared};
// Using the `IntoShared` trait
let shared: SharedRuntimePlugin = some_plugin.into_shared();
// Using the `FromUnshared` trait:
# let some_plugin = StaticRuntimePlugin::new();
let shared = SharedRuntimePlugin::from_unshared(some_plugin);
```
The `IntoShared` trait is useful for making functions that take any `RuntimePlugin` impl and convert it to a shared type.
For example, this function will convert the given `plugin` argument into a `SharedRuntimePlugin`.
```rust,no_run
# use aws_smithy_runtime_api::client::runtime_plugin::{SharedRuntimePlugin, StaticRuntimePlugin};
# use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared};
fn take_shared(plugin: impl IntoShared<SharedRuntimePlugin>) {
let _plugin: SharedRuntimePlugin = plugin.into_shared();
}
```
This can be called with different types, and even if a `SharedRuntimePlugin` is passed in, it won't nest that
`SharedRuntimePlugin` inside of another `SharedRuntimePlugin`.
```rust,no_run
# use aws_smithy_runtime_api::client::runtime_plugin::{SharedRuntimePlugin, StaticRuntimePlugin};
# use aws_smithy_runtime_api::shared::{IntoShared, FromUnshared};
# fn take_shared(plugin: impl IntoShared<SharedRuntimePlugin>) {
# let _plugin: SharedRuntimePlugin = plugin.into_shared();
# }
// Automatically converts it to `SharedRuntimePlugin(StaticRuntimePlugin)`
take_shared(StaticRuntimePlugin::new());
// This is OK.
// It create a `SharedRuntimePlugin(StaticRuntimePlugin))`
// instead of a nested `SharedRuntimePlugin(SharedRuntimePlugin(StaticRuntimePlugin)))`
take_shared(SharedRuntimePlugin::new(StaticRuntimePlugin::new()));
```
"
)]
use std::any::{Any, TypeId};
/// Like the `From` trait, but for converting to a shared type.
///
/// See the [module docs](crate::shared) for information about shared types.
pub trait FromUnshared<Unshared> {
/// Creates a shared type from an unshared type.
fn from_unshared(value: Unshared) -> Self;
}
/// Like the `Into` trait, but for (efficiently) converting into a shared type.
///
/// If the type is already a shared type, it won't be nested in another shared type.
///
/// See the [module docs](crate::shared) for information about shared types.
pub trait IntoShared<Shared> {
/// Creates a shared type from an unshared type.
fn into_shared(self) -> Shared;
}
impl<Unshared, Shared> IntoShared<Shared> for Unshared
where
Shared: FromUnshared<Unshared>,
{
fn into_shared(self) -> Shared {
FromUnshared::from_unshared(self)
}
}
/// Given a `value`, determine if that value is already shared. If it is, return it. Otherwise, wrap it in a shared type.
///
/// See the [module docs](crate::shared) for information about shared types.
pub fn maybe_shared<Shared, MaybeShared, F>(value: MaybeShared, ctor: F) -> Shared
where
Shared: 'static,
MaybeShared: IntoShared<Shared> + 'static,
F: FnOnce(MaybeShared) -> Shared,
{
// Check if the type is already a shared type
if TypeId::of::<MaybeShared>() == TypeId::of::<Shared>() {
// Convince the compiler it is already a shared type and return it
let mut placeholder = Some(value);
let value: Shared = (&mut placeholder as &mut dyn Any)
.downcast_mut::<Option<Shared>>()
.expect("type checked above")
.take()
.expect("set to Some above");
value
} else {
(ctor)(value)
}
}
/// Implements `FromUnshared` for a shared type.
///
/// See the [`shared` module docs](crate::shared) for information about shared types.
///
/// # Example
/// ```rust,no_run
/// use aws_smithy_runtime_api::impl_shared_conversions;
/// use std::sync::Arc;
///
/// trait Thing {}
///
/// struct Thingamajig;
/// impl Thing for Thingamajig {}
///
/// struct SharedThing(Arc<dyn Thing>);
/// impl Thing for SharedThing {}
/// impl SharedThing {
/// fn new(thing: impl Thing + 'static) -> Self {
/// Self(Arc::new(thing))
/// }
/// }
/// impl_shared_conversions!(convert SharedThing from Thing using SharedThing::new);
/// ```
#[macro_export]
macro_rules! impl_shared_conversions {
(convert $shared_type:ident from $unshared_trait:ident using $ctor:expr) => {
impl<T> $crate::shared::FromUnshared<T> for $shared_type
where
T: $unshared_trait + 'static,
{
fn from_unshared(value: T) -> Self {
$crate::shared::maybe_shared(value, $ctor)
}
}
};
}
#[cfg(test)]
mod tests {
use super::*;
use std::fmt;
use std::sync::Arc;
trait Thing: fmt::Debug {}
#[derive(Debug)]
struct Thingamajig;
impl Thing for Thingamajig {}
#[derive(Debug)]
struct SharedThing(Arc<dyn Thing>);
impl Thing for SharedThing {}
impl SharedThing {
fn new(thing: impl Thing + 'static) -> Self {
Self(Arc::new(thing))
}
}
impl_shared_conversions!(convert SharedThing from Thing using SharedThing::new);
#[test]
fn test() {
let thing = Thingamajig;
assert_eq!("Thingamajig", format!("{thing:?}"), "precondition");
let shared_thing: SharedThing = thing.into_shared();
assert_eq!(
"SharedThing(Thingamajig)",
format!("{shared_thing:?}"),
"precondition"
);
let very_shared_thing: SharedThing = shared_thing.into_shared();
assert_eq!(
"SharedThing(Thingamajig)",
format!("{very_shared_thing:?}"),
"it should not nest the shared thing in another shared thing"
);
}
}

View File

@ -273,7 +273,7 @@ struct ConditionallyEnabledInterceptor(SharedInterceptor);
impl ConditionallyEnabledInterceptor {
fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Interceptor> {
if self.0.enabled(cfg) {
Some(self.0.as_ref())
Some(&self.0)
} else {
None
}

View File

@ -24,21 +24,22 @@ use tracing::trace;
/// An endpoint resolver that uses a static URI.
#[derive(Clone, Debug)]
pub struct StaticUriEndpointResolver {
endpoint: Uri,
endpoint: String,
}
impl StaticUriEndpointResolver {
/// Create a resolver that resolves to `http://localhost:{port}`.
pub fn http_localhost(port: u16) -> Self {
Self {
endpoint: Uri::from_str(&format!("http://localhost:{port}"))
.expect("all u16 values are valid ports"),
endpoint: format!("http://localhost:{port}"),
}
}
/// Create a resolver that resolves to the given URI.
pub fn uri(endpoint: Uri) -> Self {
Self { endpoint }
pub fn uri(endpoint: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
}
}
}

View File

@ -30,9 +30,9 @@ use aws_smithy_runtime_api::client::runtime_plugin::{
use aws_smithy_runtime_api::client::ser_de::{
RequestSerializer, ResponseDeserializer, SharedRequestSerializer, SharedResponseDeserializer,
};
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::RetryConfig;
use http::Uri;
use std::borrow::Cow;
use std::fmt;
use std::marker::PhantomData;
@ -100,7 +100,7 @@ impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
/// Orchestrates execution of a HTTP request without any modeled input or output.
#[doc(hidden)]
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct Operation<I, O, E> {
service_name: Cow<'static, str>,
operation_name: Cow<'static, str>,
@ -108,6 +108,18 @@ pub struct Operation<I, O, E> {
_phantom: PhantomData<(I, O, E)>,
}
// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E
impl<I, O, E> Clone for Operation<I, O, E> {
fn clone(&self) -> Self {
Self {
service_name: self.service_name.clone(),
operation_name: self.operation_name.clone(),
runtime_plugins: self.runtime_plugins.clone(),
_phantom: self._phantom,
}
}
}
impl Operation<(), (), ()> {
pub fn builder() -> OperationBuilder {
OperationBuilder::new()
@ -177,7 +189,7 @@ impl<I, O, E> OperationBuilder<I, O, E> {
self
}
pub fn http_connector(mut self, connector: SharedHttpConnector) -> Self {
pub fn http_connector(mut self, connector: impl IntoShared<SharedHttpConnector>) -> Self {
self.runtime_components.set_http_connector(Some(connector));
self
}
@ -186,7 +198,7 @@ impl<I, O, E> OperationBuilder<I, O, E> {
self.config.store_put(EndpointResolverParams::new(()));
self.runtime_components
.set_endpoint_resolver(Some(SharedEndpointResolver::new(
StaticUriEndpointResolver::uri(Uri::try_from(url).expect("valid URI")),
StaticUriEndpointResolver::uri(url),
)));
self
}
@ -237,13 +249,13 @@ impl<I, O, E> OperationBuilder<I, O, E> {
self
}
pub fn interceptor(mut self, interceptor: SharedInterceptor) -> Self {
pub fn interceptor(mut self, interceptor: impl IntoShared<SharedInterceptor>) -> Self {
self.runtime_components.push_interceptor(interceptor);
self
}
pub fn runtime_plugin(mut self, runtime_plugin: SharedRuntimePlugin) -> Self {
self.runtime_plugins.push(runtime_plugin);
pub fn runtime_plugin(mut self, runtime_plugin: impl IntoShared<SharedRuntimePlugin>) -> Self {
self.runtime_plugins.push(runtime_plugin.into_shared());
self
}
@ -294,26 +306,6 @@ impl<I, O, E> OperationBuilder<I, O, E> {
pub fn build(self) -> Operation<I, O, E> {
let service_name = self.service_name.expect("service_name required");
let operation_name = self.operation_name.expect("operation_name required");
assert!(
self.runtime_components.http_connector().is_some(),
"a http_connector is required"
);
assert!(
self.runtime_components.endpoint_resolver().is_some(),
"a endpoint_resolver is required"
);
assert!(
self.runtime_components.retry_strategy().is_some(),
"a retry_strategy is required"
);
assert!(
self.config.load::<SharedRequestSerializer>().is_some(),
"a serializer is required"
);
assert!(
self.config.load::<SharedResponseDeserializer>().is_some(),
"a deserializer is required"
);
let mut runtime_plugins = RuntimePlugins::new().with_client_plugin(
StaticRuntimePlugin::new()
.with_config(self.config.freeze())
@ -323,6 +315,39 @@ impl<I, O, E> OperationBuilder<I, O, E> {
runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
}
#[cfg(debug_assertions)]
{
let mut config = ConfigBag::base();
let components = runtime_plugins
.apply_client_configuration(&mut config)
.expect("the runtime plugins should succeed");
assert!(
components.http_connector().is_some(),
"a http_connector is required"
);
assert!(
components.endpoint_resolver().is_some(),
"a endpoint_resolver is required"
);
assert!(
components.retry_strategy().is_some(),
"a retry_strategy is required"
);
assert!(
config.load::<SharedRequestSerializer>().is_some(),
"a serializer is required"
);
assert!(
config.load::<SharedResponseDeserializer>().is_some(),
"a deserializer is required"
);
assert!(
config.load::<EndpointResolverParams>().is_some(),
"endpoint resolver params are required"
);
}
Operation {
service_name,
operation_name,