Split runtime components out of config in the orchestrator impl (#2832)

This PR moves all the "runtime components", pieces that are core to the
operation of the orchestrator, into a separate `RuntimeComponents` type
for the orchestrator to reference directly.

The reason for this is so that these core components cannot be changed
by interceptors while the orchestrator is executing a request.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2023-07-12 16:59:54 -07:00 committed by GitHub
parent 3ee63a8486
commit c1a1daeee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 2513 additions and 1583 deletions

View File

@ -43,6 +43,7 @@ tracing = "0.1"
aws-credential-types = { path = "../aws-credential-types", features = ["test-util"] }
aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client", features = ["test-util"] }
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http", features = ["rt-tokio"] }
aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api", features = ["test-util"] }
tempfile = "3.6.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async", features = ["test-util"] }

View File

@ -8,6 +8,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use http::header::ACCEPT;
use http::HeaderValue;
@ -20,6 +21,7 @@ impl Interceptor for AcceptHeaderInterceptor {
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
context

View File

@ -17,6 +17,7 @@ use aws_smithy_runtime_api::client::interceptors::context::{
};
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::orchestrator::LoadedRequestBody;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use bytes::Bytes;
use http::header::{HeaderName, HeaderValue};
@ -71,6 +72,7 @@ impl<I: GlacierAccountId + Send + Sync + 'static> Interceptor
fn modify_before_serialization(
&self,
context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let erased_input = context.input_mut();
@ -99,6 +101,7 @@ impl Interceptor for GlacierApiVersionInterceptor {
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
context.request_mut().headers_mut().insert(
@ -117,6 +120,7 @@ impl Interceptor for GlacierTreeHashHeaderInterceptor {
fn modify_before_serialization(
&self,
_context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
// Request the request body to be loaded into memory immediately after serialization
@ -129,6 +133,7 @@ impl Interceptor for GlacierTreeHashHeaderInterceptor {
fn modify_before_retry_loop(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let maybe_loaded_body = cfg.load::<LoadedRequestBody>();
@ -237,6 +242,7 @@ fn compute_hash_tree(mut hashes: Vec<Digest>) -> Digest {
mod account_id_autofill_tests {
use super::*;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::type_erasure::TypedBox;
#[test]
@ -251,13 +257,14 @@ mod account_id_autofill_tests {
}
}
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut cfg = ConfigBag::base();
let mut context =
InterceptorContext::new(TypedBox::new(SomeInput { account_id: None }).erase());
let mut context = BeforeSerializationInterceptorContextMut::from(&mut context);
let interceptor = GlacierAccountIdAutofillInterceptor::<SomeInput>::new();
interceptor
.modify_before_serialization(&mut context, &mut cfg)
.modify_before_serialization(&mut context, &rc, &mut cfg)
.expect("success");
assert_eq!(
DEFAULT_ACCOUNT_ID,
@ -276,10 +283,12 @@ mod account_id_autofill_tests {
mod api_version_tests {
use super::*;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::type_erasure::TypedBox;
#[test]
fn api_version_interceptor() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut cfg = ConfigBag::base();
let mut context = InterceptorContext::new(TypedBox::new("dontcare").erase());
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
@ -287,7 +296,7 @@ mod api_version_tests {
let interceptor = GlacierApiVersionInterceptor::new("some-version");
interceptor
.modify_before_signing(&mut context, &mut cfg)
.modify_before_signing(&mut context, &rc, &mut cfg)
.expect("success");
assert_eq!(

View File

@ -19,6 +19,7 @@ use aws_smithy_runtime_api::client::interceptors::context::{
BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, Input,
};
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
use http::HeaderValue;
use http_body::Body;
@ -81,6 +82,7 @@ where
fn read_before_serialization(
&self,
context: &BeforeSerializationInterceptorContextRef<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let checksum_algorithm = (self.algorithm_provider)(context.input())?;
@ -98,6 +100,7 @@ where
fn modify_before_retry_loop(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let state = cfg

View File

@ -14,6 +14,7 @@ use aws_smithy_runtime_api::client::interceptors::context::{
BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextRef, Input,
};
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
use http::HeaderValue;
use std::{fmt, mem};
@ -58,6 +59,7 @@ where
fn read_before_serialization(
&self,
context: &BeforeSerializationInterceptorContextRef<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let validation_enabled = (self.validation_enabled)(context.input());
@ -72,6 +74,7 @@ where
fn modify_before_deserialization(
&self,
context: &mut BeforeDeserializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let state = cfg

View File

@ -15,16 +15,19 @@ use aws_sigv4::http_request::SignableBody;
use aws_smithy_async::time::{SharedTimeSource, StaticTimeSource};
use aws_smithy_runtime::client::retries::strategy::NeverRetryStrategy;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::interceptors::context::{
BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::{
disable_interceptor, Interceptor, InterceptorRegistrar, SharedInterceptor,
disable_interceptor, Interceptor, SharedInterceptor,
};
use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
use aws_smithy_runtime_api::client::runtime_components::{
RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::retries::DynRetryStrategy;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
use std::borrow::Cow;
/// Interceptor that tells the SigV4 signer to add the signature to query params,
/// and sets the request expiration time from the presigning config.
@ -47,6 +50,7 @@ impl Interceptor for SigV4PresigningInterceptor {
fn modify_before_serialization(
&self,
_context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
cfg.interceptor_state()
@ -55,16 +59,13 @@ impl Interceptor for SigV4PresigningInterceptor {
.omit_default_content_length()
.omit_default_content_type(),
);
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(StaticTimeSource::new(
self.config.start_time(),
)));
Ok(())
}
fn modify_before_signing(
&self,
_context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if let Some(mut config) = cfg.load::<SigV4OperationSigningConfig>().cloned() {
@ -86,16 +87,20 @@ impl Interceptor for SigV4PresigningInterceptor {
/// Runtime plugin that registers the SigV4PresigningInterceptor.
#[derive(Debug)]
pub(crate) struct SigV4PresigningRuntimePlugin {
interceptor: SharedInterceptor,
runtime_components: RuntimeComponentsBuilder,
}
impl SigV4PresigningRuntimePlugin {
pub(crate) fn new(config: PresigningConfig, payload_override: SignableBody<'static>) -> Self {
let time_source = SharedTimeSource::new(StaticTimeSource::new(config.start_time()));
Self {
interceptor: SharedInterceptor::new(SigV4PresigningInterceptor::new(
config,
payload_override,
)),
runtime_components: RuntimeComponentsBuilder::new("SigV4PresigningRuntimePlugin")
.with_interceptor(SharedInterceptor::new(SigV4PresigningInterceptor::new(
config,
payload_override,
)))
.with_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())))
.with_time_source(Some(time_source)),
}
}
}
@ -103,14 +108,13 @@ impl SigV4PresigningRuntimePlugin {
impl RuntimePlugin for SigV4PresigningRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
let mut layer = Layer::new("Presigning");
layer.set_retry_strategy(DynRetryStrategy::new(NeverRetryStrategy::new()));
layer.store_put(disable_interceptor::<InvocationIdInterceptor>("presigning"));
layer.store_put(disable_interceptor::<RequestInfoInterceptor>("presigning"));
layer.store_put(disable_interceptor::<UserAgentInterceptor>("presigning"));
Some(layer.freeze())
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(self.interceptor.clone());
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.runtime_components)
}
}

View File

@ -8,6 +8,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeSerializationInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use std::fmt;
use std::marker::PhantomData;
@ -74,6 +75,7 @@ where
fn modify_before_serialization(
&self,
context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let input: &mut T = context.input_mut().downcast_mut().expect("correct type");

View File

@ -12,11 +12,9 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityResolvers, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::identity::{Identity, SharedIdentityResolver};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::{GetIdentityResolver, RuntimeComponents};
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion};
@ -99,7 +97,7 @@ impl HttpAuthScheme for SigV4HttpAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -319,11 +317,12 @@ impl HttpRequestSigner for SigV4HttpRequestSigner {
request: &mut HttpRequest,
identity: &Identity,
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
runtime_components: &RuntimeComponents,
config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let operation_config =
Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?;
let request_time = config_bag.request_time().unwrap_or_default().now();
let request_time = runtime_components.time_source().unwrap_or_default().now();
let credentials = if let Some(creds) = identity.data::<Credentials>() {
creds
@ -373,7 +372,7 @@ impl HttpRequestSigner for SigV4HttpRequestSigner {
use event_stream::SigV4MessageSigner;
if let Some(signer_sender) = config_bag.load::<DeferredSignerSender>() {
let time_source = config_bag.request_time().unwrap_or_default();
let time_source = runtime_components.time_source().unwrap_or_default();
signer_sender
.send(Box::new(SigV4MessageSigner::new(
_signature,

View File

@ -11,6 +11,7 @@ use http::{HeaderName, HeaderValue};
use std::fmt::Debug;
use uuid::Uuid;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
#[cfg(feature = "test-util")]
pub use test_util::{NoInvocationIdGenerator, PredefinedInvocationIdGenerator};
@ -61,6 +62,7 @@ impl Interceptor for InvocationIdInterceptor {
fn modify_before_retry_loop(
&self,
_ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let id = cfg
@ -77,6 +79,7 @@ impl Interceptor for InvocationIdInterceptor {
fn modify_before_transmit(
&self,
ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let headers = ctx.request_mut().headers_mut();
@ -184,6 +187,7 @@ mod tests {
BeforeTransmitInterceptorContextMut, InterceptorContext,
};
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::type_erasure::TypeErasedBox;
use http::HeaderValue;
@ -197,6 +201,7 @@ mod tests {
#[test]
fn test_id_is_generated_and_set() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
ctx.enter_serialization_phase();
ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
@ -207,10 +212,10 @@ mod tests {
let interceptor = InvocationIdInterceptor::new();
let mut ctx = Into::into(&mut ctx);
interceptor
.modify_before_retry_loop(&mut ctx, &mut cfg)
.modify_before_retry_loop(&mut ctx, &rc, &mut cfg)
.unwrap();
interceptor
.modify_before_transmit(&mut ctx, &mut cfg)
.modify_before_transmit(&mut ctx, &rc, &mut cfg)
.unwrap();
let expected = cfg.load::<InvocationId>().expect("invocation ID was set");

View File

@ -6,6 +6,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_types::os_shim_internal::Env;
use http::HeaderValue;
@ -42,6 +43,7 @@ impl Interceptor for RecursionDetectionInterceptor {
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let request = context.request_mut();
@ -75,6 +77,7 @@ mod tests {
use aws_smithy_http::body::SdkBody;
use aws_smithy_protocol_test::{assert_ok, validate_headers};
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::type_erasure::TypeErasedBox;
use aws_types::os_shim_internal::Env;
use http::HeaderValue;
@ -142,6 +145,7 @@ mod tests {
}
fn check(test_case: TestCase) {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let env = test_case.env();
let mut request = http::Request::builder();
for (name, value) in test_case.request_headers_before() {
@ -157,7 +161,7 @@ mod tests {
let mut ctx = Into::into(&mut context);
RecursionDetectionInterceptor { env }
.modify_before_signing(&mut ctx, &mut config)
.modify_before_signing(&mut ctx, &rc, &mut config)
.expect("interceptor must succeed");
let mutated_request = context.request().expect("request is set");
for name in mutated_request.headers().keys() {

View File

@ -8,6 +8,7 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::date_time::Format;
use aws_smithy_types::retry::RetryConfig;
@ -89,6 +90,7 @@ impl Interceptor for RequestInfoInterceptor {
fn modify_before_transmit(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let mut pairs = RequestPairs::new();
@ -166,6 +168,7 @@ mod tests {
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::RetryConfig;
use aws_smithy_types::timeout::TimeoutConfig;
@ -186,6 +189,7 @@ mod tests {
#[test]
fn test_request_pairs_for_initial_attempt() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = InterceptorContext::new(TypeErasedBox::doesnt_matter());
context.enter_serialization_phase();
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
@ -204,7 +208,7 @@ mod tests {
let interceptor = RequestInfoInterceptor::new();
let mut ctx = (&mut context).into();
interceptor
.modify_before_transmit(&mut ctx, &mut config)
.modify_before_transmit(&mut ctx, &rc, &mut config)
.unwrap();
assert_eq!(

View File

@ -7,6 +7,7 @@ use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_types::app_name::AppName;
use aws_types::os_shim_internal::Env;
@ -74,6 +75,7 @@ impl Interceptor for UserAgentInterceptor {
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let api_metadata = cfg
@ -110,6 +112,7 @@ mod tests {
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::type_erasure::TypeErasedBox;
@ -136,6 +139,7 @@ mod tests {
#[test]
fn test_overridden_ua() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = context();
let mut layer = Layer::new("test");
@ -146,7 +150,7 @@ mod tests {
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
interceptor
.modify_before_signing(&mut ctx, &mut cfg)
.modify_before_signing(&mut ctx, &rc, &mut cfg)
.unwrap();
let header = expect_header(&context, "user-agent");
@ -161,6 +165,7 @@ mod tests {
#[test]
fn test_default_ua() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = context();
let api_metadata = ApiMetadata::new("some-service", "some-version");
@ -171,7 +176,7 @@ mod tests {
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
interceptor
.modify_before_signing(&mut ctx, &mut config)
.modify_before_signing(&mut ctx, &rc, &mut config)
.unwrap();
let expected_ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata);
@ -191,6 +196,7 @@ mod tests {
#[test]
fn test_app_name() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = context();
let api_metadata = ApiMetadata::new("some-service", "some-version");
@ -202,7 +208,7 @@ mod tests {
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
interceptor
.modify_before_signing(&mut ctx, &mut config)
.modify_before_signing(&mut ctx, &rc, &mut config)
.unwrap();
let app_value = "app/my_awesome_app";
@ -221,6 +227,7 @@ mod tests {
#[test]
fn test_api_metadata_missing() {
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut context = context();
let mut config = ConfigBag::base();
@ -231,7 +238,7 @@ mod tests {
"{}",
DisplayErrorContext(
&*interceptor
.modify_before_signing(&mut ctx, &mut config)
.modify_before_signing(&mut ctx, &rc, &mut config)
.expect_err("it should error")
)
);

View File

@ -87,12 +87,11 @@ if (isTestingEnabled.toBoolean()) {
tasks.test {
useJUnitPlatform()
testLogging {
events("passed", "skipped", "failed")
events("failed")
exceptionFormat = TestExceptionFormat.FULL
showCauses = true
showExceptions = true
showStackTraces = true
showStandardStreams = true
}
}

View File

@ -25,6 +25,9 @@ class CustomizableOperationTestHelpers(runtimeConfig: RuntimeConfig) :
"ConfigBagAccessors" to RuntimeType.configBagAccessors(runtimeConfig),
"http" to CargoDependency.Http.toType(),
"InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig),
"StaticRuntimePlugin" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::runtime_plugin::StaticRuntimePlugin"),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig),
"SharedTimeSource" to CargoDependency.smithyAsync(runtimeConfig).withFeature("test-util").toType()
.resolve("time::SharedTimeSource"),
"SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig)
@ -41,13 +44,14 @@ class CustomizableOperationTestHelpers(runtimeConfig: RuntimeConfig) :
"""
##[doc(hidden)]
// This is a temporary method for testing. NEVER use it in production
pub fn request_time_for_tests(mut self, request_time: ::std::time::SystemTime) -> Self {
use #{ConfigBagAccessors};
let interceptor = #{TestParamsSetterInterceptor}::new(move |_: &mut #{BeforeTransmitInterceptorContextMut}<'_>, cfg: &mut #{ConfigBag}| {
cfg.interceptor_state().set_request_time(#{SharedTimeSource}::new(request_time));
});
self.interceptors.push(#{SharedInterceptor}::new(interceptor));
self
pub fn request_time_for_tests(self, request_time: ::std::time::SystemTime) -> Self {
self.runtime_plugin(
#{StaticRuntimePlugin}::new()
.with_runtime_components(
#{RuntimeComponentsBuilder}::new("request_time_for_tests")
.with_time_source(Some(#{SharedTimeSource}::new(request_time)))
)
)
}
##[doc(hidden)]

View File

@ -80,7 +80,7 @@ class CredentialCacheConfig(codegenContext: ClientCodegenContext) : ConfigCustom
"""
/// Returns the credentials cache.
pub fn credentials_cache(&self) -> #{Option}<#{SharedCredentialsCache}> {
self.inner.load::<#{SharedCredentialsCache}>().cloned()
self.config.load::<#{SharedCredentialsCache}>().cloned()
}
""",
*codegenScope,
@ -121,7 +121,7 @@ class CredentialCacheConfig(codegenContext: ClientCodegenContext) : ConfigCustom
"""
/// Sets the credentials cache for this service
pub fn set_credentials_cache(&mut self, credentials_cache: #{Option}<#{CredentialsCache}>) -> &mut Self {
self.inner.store_or_unset(credentials_cache);
self.config.store_or_unset(credentials_cache);
self
}
""",
@ -148,7 +148,7 @@ class CredentialCacheConfig(codegenContext: ClientCodegenContext) : ConfigCustom
if let Some(credentials_provider) = layer.load::<#{SharedCredentialsProvider}>().cloned() {
let cache_config = layer.load::<#{CredentialsCache}>().cloned()
.unwrap_or_else({
let sleep = layer.load::<#{SharedAsyncSleep}>().cloned();
let sleep = self.runtime_components.sleep_impl();
|| match sleep {
Some(sleep) => {
#{CredentialsCache}::lazy_builder()

View File

@ -101,7 +101,7 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
"""
/// Sets the credentials provider for this service
pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self {
self.inner.store_or_unset(credentials_provider);
self.config.store_or_unset(credentials_provider);
self
}
""",
@ -138,9 +138,9 @@ class CredentialsIdentityResolverRegistration(
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.AdditionalConfig -> {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
rustBlockTemplate("if let Some(credentials_cache) = ${section.serviceConfigName}.credentials_cache()") {
section.registerIdentityResolver(this, runtimeConfig) {
section.registerIdentityResolver(this) {
rustTemplate(
"""
#{SIGV4_SCHEME_ID},

View File

@ -55,7 +55,7 @@ class HttpConnectorConfigCustomization(
"""
/// Return an [`HttpConnector`](#{HttpConnector}) to use when making requests, if any.
pub fn http_connector(&self) -> Option<&#{HttpConnector}> {
self.inner.load::<#{HttpConnector}>()
self.config.load::<#{HttpConnector}>()
}
""",
*codegenScope,
@ -161,7 +161,7 @@ class HttpConnectorConfigCustomization(
rustTemplate(
"""
pub fn set_http_connector(&mut self, http_connector: #{Option}<impl #{Into}<#{HttpConnector}>>) -> &mut Self {
http_connector.map(|c| self.inner.store_put(c.into()));
http_connector.map(|c| self.config.store_put(c.into()));
self
}
""",

View File

@ -36,7 +36,7 @@ private class InvocationIdRuntimePluginCustomization(
)
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate("#{InvocationIdInterceptor}::new()", *codegenScope)
}

View File

@ -31,7 +31,7 @@ private class RecursionDetectionRuntimePluginCustomization(
private val codegenContext: ClientCodegenContext,
) : ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rust(
"#T::new()",

View File

@ -181,7 +181,7 @@ class RegionProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomi
"""
/// Returns the AWS region, if it was provided.
pub fn region(&self) -> #{Option}<&#{Region}> {
self.inner.load::<#{Region}>()
self.config.load::<#{Region}>()
}
""",
*codegenScope,
@ -232,7 +232,7 @@ class RegionProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomi
"""
/// Sets the AWS region to use when making requests.
pub fn set_region(&mut self, region: #{Option}<#{Region}>) -> &mut Self {
self.inner.store_or_unset(region);
self.config.store_or_unset(region);
self
}
""",

View File

@ -35,7 +35,7 @@ private class AddRetryInformationHeaderInterceptors(codegenContext: ClientCodege
private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig)
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
// Track the latency between client and server.
section.registerInterceptor(runtimeConfig, this) {
rust(

View File

@ -78,13 +78,13 @@ private class AuthServiceRuntimePluginCustomization(private val codegenContext:
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.AdditionalConfig -> {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
val serviceHasEventStream = codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model)
if (serviceHasEventStream) {
// enable the aws-runtime `sign-eventstream` feature
addDependency(AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType().toSymbol())
}
section.registerHttpAuthScheme(this, runtimeConfig) {
section.registerHttpAuthScheme(this) {
rustTemplate("#{SharedHttpAuthScheme}::new(#{SigV4HttpAuthScheme}::new())", *codegenScope)
}
}

View File

@ -104,7 +104,7 @@ class UserAgentDecorator : ClientCodegenDecorator {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.RegisterInterceptor -> {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
section.registerInterceptor(runtimeConfig, this) {
rust("#T::new()", awsRuntime.resolve("user_agent::UserAgentInterceptor"))
}
@ -182,7 +182,7 @@ class UserAgentDecorator : ClientCodegenDecorator {
/// This _optional_ name is used to identify the application in the user agent that
/// gets sent along with requests.
pub fn set_app_name(&mut self, app_name: #{Option}<#{AppName}>) -> &mut Self {
self.inner.store_or_unset(app_name);
self.config.store_or_unset(app_name);
self
}
""",
@ -228,7 +228,7 @@ class UserAgentDecorator : ClientCodegenDecorator {
/// This _optional_ name is used to identify the application in the user agent that
/// gets sent along with requests.
pub fn app_name(&self) -> #{Option}<&#{AppName}> {
self.inner.load::<#{AppName}>()
self.config.load::<#{AppName}>()
}
""",
*codegenScope,

View File

@ -66,7 +66,7 @@ private class ApiGatewayAddAcceptHeader : OperationCustomization() {
private class ApiGatewayAcceptHeaderInterceptorCustomization(private val codegenContext: ClientCodegenContext) :
ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate(
"#{Interceptor}::default()",

View File

@ -101,7 +101,7 @@ private class GlacierAccountIdCustomization(private val codegenContext: ClientCo
private class GlacierApiVersionCustomization(private val codegenContext: ClientCodegenContext) :
ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
val apiVersion = codegenContext.serviceShape.version
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate(

View File

@ -62,7 +62,8 @@ class TimestreamDecorator : ClientCodegenDecorator {
#{ResolveEndpointError}::from_source("failed to call describe_endpoints", e)
})?;
let endpoint = describe_endpoints.endpoints().unwrap().get(0).unwrap();
let expiry = client.conf().time_source().now() + #{Duration}::from_secs(endpoint.cache_period_in_minutes() as u64 * 60);
let expiry = client.conf().time_source().expect("checked when ep discovery was enabled").now()
+ #{Duration}::from_secs(endpoint.cache_period_in_minutes() as u64 * 60);
Ok((
#{Endpoint}::builder()
.url(format!("https://{}", endpoint.address().unwrap()))
@ -78,7 +79,7 @@ class TimestreamDecorator : ClientCodegenDecorator {
pub async fn enable_endpoint_discovery(self) -> #{Result}<(Self, #{endpoint_discovery}::ReloadEndpoint), #{ResolveEndpointError}> {
let mut new_conf = self.conf().clone();
let sleep = self.conf().sleep_impl().expect("sleep impl must be provided");
let time = self.conf().time_source();
let time = self.conf().time_source().expect("time source must be provided");
let (resolver, reloader) = #{endpoint_discovery}::create_cache(
move || {
let client = self.clone();

View File

@ -0,0 +1,62 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rustsdk.awsSdkIntegrationTest
class SdkCodegenIntegrationTest {
val model = """
namespace test
use aws.api#service
use aws.auth#sigv4
use aws.protocols#restJson1
use smithy.rules#endpointRuleSet
@service(sdkId: "dontcare")
@restJson1
@sigv4(name: "dontcare")
@auth([sigv4])
@endpointRuleSet({
"version": "1.0",
"rules": [{ "type": "endpoint", "conditions": [], "endpoint": { "url": "https://example.com" } }],
"parameters": {
"Region": { "required": false, "type": "String", "builtIn": "AWS::Region" },
}
})
service TestService {
version: "2023-01-01",
operations: [SomeOperation]
}
structure SomeOutput {
someAttribute: Long,
someVal: String
}
@http(uri: "/SomeOperation", method: "GET")
@optionalAuth
operation SomeOperation {
output: SomeOutput
}
""".asSmithyModel()
@Test
fun smokeTestSdkCodegen() {
awsSdkIntegrationTest(
model,
defaultToOrchestrator = true,
) { _, _ -> /* it should compile */ }
}
@Test
fun smokeTestSdkCodegenMiddleware() {
awsSdkIntegrationTest(
model,
defaultToOrchestrator = false,
) { _, _ -> /* it should compile */ }
}
}

View File

@ -77,10 +77,11 @@ internal class CredentialCacheConfigTest {
let client_config = crate::config::Config::builder().build();
let config_override =
crate::config::Config::builder().credentials_provider(#{Credentials}::for_tests());
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config.config().unwrap(),
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
};
client_config.config,
&client_config.runtime_components,
);
// this should cause `panic!`
let _ = sut.config().unwrap();
@ -100,10 +101,11 @@ internal class CredentialCacheConfigTest {
let client_config = crate::config::Config::builder().build();
let config_override = crate::config::Config::builder()
.credentials_cache(#{CredentialsCache}::no_caching());
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config.config().unwrap(),
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
};
client_config.config,
&client_config.runtime_components,
);
// this should cause `panic!`
let _ = sut.config().unwrap();
@ -121,7 +123,7 @@ internal class CredentialCacheConfigTest {
let client_config = crate::config::Config::builder()
.credentials_provider(#{Credentials}::for_tests())
.build();
let client_config_layer = client_config.config().unwrap();
let client_config_layer = client_config.config;
// make sure test credentials are set in the client config level
assert_eq!(#{Credentials}::for_tests(),
@ -143,10 +145,11 @@ internal class CredentialCacheConfigTest {
let config_override = crate::config::Config::builder()
.credentials_cache(#{CredentialsCache}::lazy())
.credentials_provider(credentials.clone());
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config_layer,
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
};
client_config_layer,
&client_config.runtime_components,
);
let sut_layer = sut.config().unwrap();
// make sure `.provide_cached_credentials` returns credentials set through `config_override`
@ -170,10 +173,11 @@ internal class CredentialCacheConfigTest {
let client_config = crate::config::Config::builder().build();
let config_override = crate::config::Config::builder();
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config.config().unwrap(),
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
};
client_config.config,
&client_config.runtime_components,
);
let sut_layer = sut.config().unwrap();
assert!(sut_layer
.load::<#{SharedCredentialsCache}>()

View File

@ -114,6 +114,8 @@ project.registerGenerateSmithyBuildTask(rootProject, pluginName, allCodegenTests
project.registerGenerateCargoWorkspaceTask(rootProject, pluginName, allCodegenTests, workingDirUnderBuildDir)
project.registerGenerateCargoConfigTomlTask(buildDir.resolve(workingDirUnderBuildDir))
tasks["generateSmithyBuild"].inputs.property("smithy.runtime.mode", getSmithyRuntimeMode())
tasks["smithyBuildJar"].dependsOn("generateSmithyBuild")
tasks["assemble"].finalizedBy("generateCargoWorkspace")

View File

@ -77,12 +77,11 @@ if (isTestingEnabled.toBoolean()) {
tasks.test {
useJUnitPlatform()
testLogging {
events("passed", "skipped", "failed")
events("failed")
exceptionFormat = TestExceptionFormat.FULL
showCauses = true
showExceptions = true
showStackTraces = true
showStandardStreams = true
}
}

View File

@ -45,7 +45,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGenerat
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
@ -172,7 +172,7 @@ class ClientCodegenVisitor(
try {
// use an increased max_width to make rustfmt fail less frequently
"cargo fmt -- --config max_width=150".runCommand(fileManifest.baseDir, timeout = settings.codegenConfig.formatTimeoutSeconds.toLong())
} catch (err: CommandFailed) {
} catch (err: CommandError) {
logger.warning("Failed to run cargo fmt: [${service.id}]\n${err.output}")
}

View File

@ -206,7 +206,7 @@ private class ApiKeyConfigCustomization(codegenContext: ClientCodegenContext) :
"""
/// Returns API key used by the client, if it was provided.
pub fn api_key(&self) -> #{Option}<&#{ApiKey}> {
self.inner.load::<#{ApiKey}>()
self.config.load::<#{ApiKey}>()
}
""",
*codegenScope,

View File

@ -20,7 +20,7 @@ class ConnectionPoisoningRuntimePluginCustomization(
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.RegisterInterceptor -> {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
// This interceptor assumes that a compatible Connector is set. Otherwise, connection poisoning
// won't work and an error message will be logged.
section.registerInterceptor(runtimeConfig, this) {

View File

@ -32,7 +32,7 @@ import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
fun codegenScope(runtimeConfig: RuntimeConfig): Array<Pair<String, Any>> {
private fun codegenScope(runtimeConfig: RuntimeConfig): Array<Pair<String, Any>> {
val smithyRuntime =
CargoDependency.smithyRuntime(runtimeConfig).withFeature("http-auth").toType()
val smithyRuntimeApi = CargoDependency.smithyRuntimeApi(runtimeConfig).withFeature("http-auth").toType()
@ -42,7 +42,6 @@ fun codegenScope(runtimeConfig: RuntimeConfig): Array<Pair<String, Any>> {
"AuthSchemeId" to smithyRuntimeApi.resolve("client::auth::AuthSchemeId"),
"ApiKeyAuthScheme" to authHttp.resolve("ApiKeyAuthScheme"),
"ApiKeyLocation" to authHttp.resolve("ApiKeyLocation"),
"ConfigBagAccessors" to RuntimeType.configBagAccessors(runtimeConfig),
"BasicAuthScheme" to authHttp.resolve("BasicAuthScheme"),
"BearerAuthScheme" to authHttp.resolve("BearerAuthScheme"),
"DigestAuthScheme" to authHttp.resolve("DigestAuthScheme"),
@ -163,9 +162,9 @@ private class HttpAuthServiceRuntimePluginCustomization(
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.AdditionalConfig -> {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
fun registerAuthScheme(scheme: Writable) {
section.registerHttpAuthScheme(this, codegenContext.runtimeConfig) {
section.registerHttpAuthScheme(this) {
rustTemplate("#{SharedHttpAuthScheme}::new(#{Scheme})", *codegenScope, "Scheme" to scheme)
}
}
@ -236,8 +235,7 @@ private class HttpAuthConfigCustomization(
/// Sets an API key resolver will be used for authentication.
pub fn api_key_resolver(mut self, api_key_resolver: impl #{IdentityResolver} + 'static) -> Self {
#{ConfigBagAccessors}::push_identity_resolver(
&mut self.inner,
self.runtime_components.push_identity_resolver(
#{HTTP_API_KEY_AUTH_SCHEME_ID},
#{SharedIdentityResolver}::new(api_key_resolver)
);
@ -257,8 +255,7 @@ private class HttpAuthConfigCustomization(
/// Sets a bearer token provider that will be used for HTTP bearer auth.
pub fn bearer_token_resolver(mut self, bearer_token_resolver: impl #{IdentityResolver} + 'static) -> Self {
#{ConfigBagAccessors}::push_identity_resolver(
&mut self.inner,
self.runtime_components.push_identity_resolver(
#{HTTP_BEARER_AUTH_SCHEME_ID},
#{SharedIdentityResolver}::new(bearer_token_resolver)
);
@ -278,8 +275,7 @@ private class HttpAuthConfigCustomization(
/// Sets a login resolver that will be used for HTTP basic auth.
pub fn basic_auth_login_resolver(mut self, basic_auth_resolver: impl #{IdentityResolver} + 'static) -> Self {
#{ConfigBagAccessors}::push_identity_resolver(
&mut self.inner,
self.runtime_components.push_identity_resolver(
#{HTTP_BASIC_AUTH_SCHEME_ID},
#{SharedIdentityResolver}::new(basic_auth_resolver)
);
@ -299,8 +295,7 @@ private class HttpAuthConfigCustomization(
/// Sets a login resolver that will be used for HTTP digest auth.
pub fn digest_auth_login_resolver(mut self, digest_auth_resolver: impl #{IdentityResolver} + 'static) -> Self {
#{ConfigBagAccessors}::push_identity_resolver(
&mut self.inner,
self.runtime_components.push_identity_resolver(
#{HTTP_DIGEST_AUTH_SCHEME_ID},
#{SharedIdentityResolver}::new(digest_auth_resolver)
);

View File

@ -40,7 +40,7 @@ class HttpChecksumRequiredGenerator(
is OperationSection.AdditionalRuntimePlugins -> writable {
section.addOperationRuntimePlugin(this) {
rustTemplate(
"#{HttpChecksumRequiredRuntimePlugin}",
"#{HttpChecksumRequiredRuntimePlugin}::new()",
"HttpChecksumRequiredRuntimePlugin" to
InlineDependency.forRustFile(
RustModule.pubCrate("client_http_checksum_required", parent = ClientRustModule.root),

View File

@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.client.smithy.customizations
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
@ -40,13 +41,52 @@ private class HttpConnectorConfigCustomization(
*preludeScope,
"Connection" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::Connection"),
"ConnectorSettings" to RuntimeType.smithyClient(runtimeConfig).resolve("http_connector::ConnectorSettings"),
"DynConnector" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::connectors::DynConnector"),
"default_connector" to RuntimeType.smithyClient(runtimeConfig).resolve("conns::default_connector"),
"DynConnectorAdapter" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::connectors::adapter::DynConnectorAdapter"),
"HttpConnector" to RuntimeType.smithyClient(runtimeConfig).resolve("http_connector::HttpConnector"),
"Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"),
"SharedAsyncSleep" to RuntimeType.smithyAsync(runtimeConfig).resolve("rt::sleep::SharedAsyncSleep"),
"SharedConnector" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::connectors::SharedConnector"),
"TimeoutConfig" to RuntimeType.smithyTypes(runtimeConfig).resolve("timeout::TimeoutConfig"),
)
private fun setConnectorFn(): RuntimeType = RuntimeType.forInlineFun("set_connector", ClientRustModule.config) {
rustTemplate(
"""
fn set_connector(resolver: &mut #{Resolver}<'_>) {
// Initial configuration needs to set a default if no connector is given, so it
// should always get into the condition below.
//
// Override configuration should set the connector if the override config
// contains a connector, sleep impl, or a timeout config since these are all
// incorporated into the final connector.
let must_set_connector = resolver.is_initial()
|| resolver.is_latest_set::<#{HttpConnector}>()
|| resolver.latest_sleep_impl().is_some()
|| resolver.is_latest_set::<#{TimeoutConfig}>();
if must_set_connector {
let sleep_impl = resolver.sleep_impl();
let timeout_config = resolver.resolve_config::<#{TimeoutConfig}>()
.cloned()
.unwrap_or_else(#{TimeoutConfig}::disabled);
let connector_settings = #{ConnectorSettings}::from_timeout_config(&timeout_config);
let http_connector = resolver.resolve_config::<#{HttpConnector}>();
// TODO(enableNewSmithyRuntimeCleanup): Replace the tower-based DynConnector and remove DynConnectorAdapter when deleting the middleware implementation
let connector =
http_connector
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
.or_else(|| #{default_connector}(&connector_settings, sleep_impl))
.map(|c| #{SharedConnector}::new(#{DynConnectorAdapter}::new(c)));
resolver.runtime_components_mut().set_connector(connector);
}
}
""",
*codegenScope,
)
}
override fun section(section: ServiceConfig): Writable {
return when (section) {
is ServiceConfig.ConfigStruct -> writable {
@ -59,9 +99,15 @@ private class HttpConnectorConfigCustomization(
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
// TODO(enableNewSmithyRuntimeCleanup): Remove this function
/// Return an [`HttpConnector`](#{HttpConnector}) to use when making requests, if any.
pub fn http_connector(&self) -> Option<&#{HttpConnector}> {
self.inner.load::<#{HttpConnector}>()
self.config.load::<#{HttpConnector}>()
}
/// Return the [`SharedConnector`](#{SharedConnector}) to use when making requests, if any.
pub fn connector(&self) -> Option<#{SharedConnector}> {
self.runtime_components.connector()
}
""",
*codegenScope,
@ -164,12 +210,11 @@ private class HttpConnectorConfigCustomization(
""",
*codegenScope,
)
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
pub fn set_http_connector(&mut self, http_connector: Option<impl Into<#{HttpConnector}>>) -> &mut Self {
http_connector.map(|c| self.inner.store_put(c.into()));
http_connector.map(|c| self.config.store_put(c.into()));
self
}
""",
@ -191,26 +236,8 @@ private class HttpConnectorConfigCustomization(
is ServiceConfig.BuilderBuild -> writable {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
let sleep_impl = layer.load::<#{SharedAsyncSleep}>().cloned();
let timeout_config = layer.load::<#{TimeoutConfig}>().cloned().unwrap_or_else(#{TimeoutConfig}::disabled);
let connector_settings = #{ConnectorSettings}::from_timeout_config(&timeout_config);
if let Some(connector) = layer.load::<#{HttpConnector}>()
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
.or_else(|| #{default_connector}(&connector_settings, sleep_impl)) {
let connector: #{DynConnector} = #{DynConnector}::new(#{DynConnectorAdapter}::new(
// TODO(enableNewSmithyRuntimeCleanup): Replace the tower-based DynConnector and remove DynConnectorAdapter when deleting the middleware implementation
connector
));
#{ConfigBagAccessors}::set_connector(&mut layer, connector);
}
""",
*codegenScope,
"ConfigBagAccessors" to RuntimeType.configBagAccessors(runtimeConfig),
"default_connector" to RuntimeType.smithyClient(runtimeConfig).resolve("conns::default_connector"),
"#{set_connector}(&mut resolver);",
"set_connector" to setConnectorFn(),
)
} else {
rust("http_connector: self.http_connector,")
@ -220,38 +247,8 @@ private class HttpConnectorConfigCustomization(
is ServiceConfig.OperationConfigOverride -> writable {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
if let #{Some}(http_connector) =
layer.load::<#{HttpConnector}>()
{
let sleep_impl = layer
.load::<#{SharedAsyncSleep}>()
.or_else(|| {
self.client_config
.load::<#{SharedAsyncSleep}>()
})
.cloned();
let timeout_config = layer
.load::<#{TimeoutConfig}>()
.or_else(|| {
self.client_config
.load::<#{TimeoutConfig}>()
})
.expect("timeout config should be set either in `config_override` or in the client config");
let connector_settings =
#{ConnectorSettings}::from_timeout_config(
timeout_config,
);
if let #{Some}(conn) = http_connector.connector(&connector_settings, sleep_impl) {
let connection: #{DynConnector} = #{DynConnector}::new(#{DynConnectorAdapter}::new(
// TODO(enableNewSmithyRuntimeCleanup): Replace the tower-based DynConnector and remove DynConnectorAdapter when deleting the middleware implementation
conn
));
layer.set_connector(connection);
}
}
""",
*codegenScope,
"#{set_connector}(&mut resolver);",
"set_connector" to setConnectorFn(),
)
}
}

View File

@ -1,32 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.customizations
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
class IdentityConfigCustomization(private val codegenContext: ClientCodegenContext) : ConfigCustomization() {
override fun section(section: ServiceConfig): Writable = writable {
if (section is ServiceConfig.ConfigImpl) {
rustTemplate(
"""
/// Returns the identity resolvers.
pub fn identity_resolvers(&self) -> #{IdentityResolvers} {
#{ConfigBagAccessors}::identity_resolvers(&self.inner)
}
""",
"IdentityResolvers" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig)
.resolve("client::identity::IdentityResolvers"),
"ConfigBagAccessors" to RuntimeType.configBagAccessors(codegenContext.runtimeConfig),
)
}
}
}

View File

@ -32,23 +32,12 @@ class InterceptorConfigCustomization(codegenContext: ClientCodegenContext) : Con
override fun section(section: ServiceConfig) =
writable {
when (section) {
ServiceConfig.ConfigStruct -> rustTemplate(
"pub(crate) interceptors: Vec<#{SharedInterceptor}>,",
*codegenScope,
)
ServiceConfig.BuilderStruct ->
rustTemplate(
"interceptors: Vec<#{SharedInterceptor}>,",
*codegenScope,
)
ServiceConfig.ConfigImpl -> rustTemplate(
"""
#{maybe_hide_orchestrator_code}
/// Returns interceptors currently registered by the user.
pub fn interceptors(&self) -> impl Iterator<Item = &#{SharedInterceptor}> + '_ {
self.interceptors.iter()
pub fn interceptors(&self) -> impl Iterator<Item = #{SharedInterceptor}> + '_ {
self.runtime_components.interceptors()
}
""",
*codegenScope,
@ -103,7 +92,7 @@ class InterceptorConfigCustomization(codegenContext: ClientCodegenContext) : Con
/// ## }
/// ```
pub fn interceptor(mut self, interceptor: impl #{Interceptor} + Send + Sync + 'static) -> Self {
self.add_interceptor(#{SharedInterceptor}::new(interceptor));
self.push_interceptor(#{SharedInterceptor}::new(interceptor));
self
}
@ -146,7 +135,7 @@ class InterceptorConfigCustomization(codegenContext: ClientCodegenContext) : Con
/// Ok(())
/// }
/// }
/// builder.add_interceptor(SharedInterceptor::new(UriModifierInterceptor));
/// builder.push_interceptor(SharedInterceptor::new(UriModifierInterceptor));
/// }
///
/// let mut builder = Config::builder();
@ -155,29 +144,21 @@ class InterceptorConfigCustomization(codegenContext: ClientCodegenContext) : Con
/// ## }
/// ## }
/// ```
pub fn add_interceptor(&mut self, interceptor: #{SharedInterceptor}) -> &mut Self {
self.interceptors.push(interceptor);
pub fn push_interceptor(&mut self, interceptor: #{SharedInterceptor}) -> &mut Self {
self.runtime_components.push_interceptor(interceptor);
self
}
#{maybe_hide_orchestrator_code}
/// Set [`SharedInterceptor`](#{SharedInterceptor})s for the builder.
pub fn set_interceptors(&mut self, interceptors: impl IntoIterator<Item = #{SharedInterceptor}>) -> &mut Self {
self.interceptors = interceptors.into_iter().collect();
self.runtime_components.set_interceptors(interceptors.into_iter());
self
}
""",
*codegenScope,
)
is ServiceConfig.RuntimePluginInterceptors -> rust(
"""
${section.interceptors}.extend(${section.interceptorsField}.interceptors.iter().cloned());
""",
)
is ServiceConfig.BuilderBuildExtras -> rust("interceptors: self.interceptors,")
else -> emptySection
}
}

View File

@ -30,21 +30,21 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
private val moduleUseName = codegenContext.moduleUseName()
private val codegenScope = arrayOf(
*preludeScope,
"DynRetryStrategy" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::retries::DynRetryStrategy"),
"ClientRateLimiter" to retries.resolve("ClientRateLimiter"),
"ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"),
"debug" to RuntimeType.Tracing.resolve("debug"),
"RetryConfig" to retryConfig.resolve("RetryConfig"),
"RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"),
"RetryPartition" to retries.resolve("RetryPartition"),
"SharedAsyncSleep" to sleepModule.resolve("SharedAsyncSleep"),
"SharedRetryStrategy" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::retries::SharedRetryStrategy"),
"SharedTimeSource" to RuntimeType.smithyAsync(runtimeConfig).resolve("time::SharedTimeSource"),
"Sleep" to sleepModule.resolve("Sleep"),
"StandardRetryStrategy" to retries.resolve("strategy::StandardRetryStrategy"),
"SystemTime" to RuntimeType.std.resolve("time::SystemTime"),
"TimeoutConfig" to timeoutModule.resolve("TimeoutConfig"),
"RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"),
"TokenBucket" to retries.resolve("TokenBucket"),
"ClientRateLimiter" to retries.resolve("ClientRateLimiter"),
"SharedTimeSource" to RuntimeType.smithyAsync(runtimeConfig).resolve("time::SharedTimeSource"),
"ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"),
"TokenBucketPartition" to retries.resolve("TokenBucketPartition"),
"RetryPartition" to retries.resolve("RetryPartition"),
"debug" to RuntimeType.Tracing.resolve("debug"),
)
override fun section(section: ServiceConfig) =
@ -69,17 +69,17 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
"""
/// Return a reference to the retry configuration contained in this config, if any.
pub fn retry_config(&self) -> #{Option}<&#{RetryConfig}> {
self.inner.load::<#{RetryConfig}>()
self.config.load::<#{RetryConfig}>()
}
/// Return a cloned shared async sleep implementation from this config, if any.
pub fn sleep_impl(&self) -> #{Option}<#{SharedAsyncSleep}> {
self.inner.load::<#{SharedAsyncSleep}>().cloned()
self.runtime_components.sleep_impl()
}
/// Return a reference to the timeout configuration contained in this config, if any.
pub fn timeout_config(&self) -> #{Option}<&#{TimeoutConfig}> {
self.inner.load::<#{TimeoutConfig}>()
self.config.load::<#{TimeoutConfig}>()
}
##[doc(hidden)]
@ -88,7 +88,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
/// WARNING: This method is unstable and may be removed at any time. Do not rely on this
/// method for anything!
pub fn retry_partition(&self) -> #{Option}<&#{RetryPartition}> {
self.inner.load::<#{RetryPartition}>()
self.config.load::<#{RetryPartition}>()
}
""",
*codegenScope,
@ -171,7 +171,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
rustTemplate(
"""
pub fn set_retry_config(&mut self, retry_config: #{Option}<#{RetryConfig}>) -> &mut Self {
retry_config.map(|r| self.inner.store_put(r));
retry_config.map(|r| self.config.store_put(r));
self
}
""",
@ -249,7 +249,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
rustTemplate(
"""
pub fn set_sleep_impl(&mut self, sleep_impl: #{Option}<#{SharedAsyncSleep}>) -> &mut Self {
sleep_impl.map(|s| self.inner.store_put(s));
self.runtime_components.set_sleep_impl(sleep_impl);
self
}
""",
@ -317,7 +317,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
rustTemplate(
"""
pub fn set_timeout_config(&mut self, timeout_config: #{Option}<#{TimeoutConfig}>) -> &mut Self {
timeout_config.map(|t| self.inner.store_put(t));
timeout_config.map(|t| self.config.store_put(t));
self
}
""",
@ -357,7 +357,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
/// also share things like token buckets and client rate limiters. By default, all clients
/// for the same service will share a partition.
pub fn set_retry_partition(&mut self, retry_partition: #{Option}<#{RetryPartition}>) -> &mut Self {
retry_partition.map(|r| self.inner.store_put(r));
retry_partition.map(|r| self.config.store_put(r));
self
}
""",
@ -377,7 +377,7 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
}
if retry_config.mode() == #{RetryMode}::Adaptive {
if let #{Some}(time_source) = layer.load::<#{SharedTimeSource}>().cloned() {
if let #{Some}(time_source) = self.runtime_components.time_source() {
let seconds_since_unix_epoch = time_source
.now()
.duration_since(#{SystemTime}::UNIX_EPOCH)
@ -395,13 +395,16 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
let token_bucket_partition = #{TokenBucketPartition}::new(retry_partition);
let token_bucket = TOKEN_BUCKET.get_or_init(token_bucket_partition, #{TokenBucket}::default);
layer.store_put(token_bucket);
layer.set_retry_strategy(#{DynRetryStrategy}::new(#{StandardRetryStrategy}::new(&retry_config)));
// TODO(enableNewSmithyRuntimeCleanup): Should not need to provide a default once smithy-rs##2770
// is resolved
if layer.load::<#{TimeoutConfig}>().is_none() {
layer.store_put(#{TimeoutConfig}::disabled());
}
self.runtime_components.set_retry_strategy(#{Some}(
#{SharedRetryStrategy}::new(#{StandardRetryStrategy}::new(&retry_config)))
);
""",
*codegenScope,
)
@ -420,24 +423,6 @@ class ResiliencyConfigCustomization(private val codegenContext: ClientCodegenCon
}
}
is ServiceConfig.OperationConfigOverride -> {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
if let #{Some}(retry_config) = layer
.load::<#{RetryConfig}>()
.cloned()
{
layer.set_retry_strategy(
#{DynRetryStrategy}::new(#{StandardRetryStrategy}::new(&retry_config))
);
}
""",
*codegenScope,
)
}
}
else -> emptySection
}
}

View File

@ -42,16 +42,16 @@ class TimeSourceCustomization(codegenContext: ClientCodegenContext) : ConfigCust
is ServiceConfig.ConfigImpl -> {
rust("/// Return time source used for this service.")
rustBlockTemplate(
"pub fn time_source(&self) -> #{SharedTimeSource}",
"pub fn time_source(&self) -> #{Option}<#{SharedTimeSource}>",
*codegenScope,
) {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""self.inner.load::<#{SharedTimeSource}>().expect("time source should be set").clone()""",
"""self.runtime_components.time_source()""",
*codegenScope,
)
} else {
rust("self.time_source.clone()")
rustTemplate("#{Some}(self.time_source.clone())", *codegenScope)
}
}
}
@ -88,7 +88,7 @@ class TimeSourceCustomization(codegenContext: ClientCodegenContext) : ConfigCust
&mut self,
time_source: #{Option}<#{SharedTimeSource}>,
) -> &mut Self {
self.inner.store_or_unset(time_source);
self.runtime_components.set_time_source(time_source);
self
}
""",
@ -114,7 +114,11 @@ class TimeSourceCustomization(codegenContext: ClientCodegenContext) : ConfigCust
ServiceConfig.BuilderBuild -> {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"layer.store_put(layer.load::<#{SharedTimeSource}>().cloned().unwrap_or_default());",
"""
if self.runtime_components.time_source().is_none() {
self.runtime_components.set_time_source(#{Default}::default());
}
""",
*codegenScope,
)
} else {

View File

@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.Endpoint
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpChecksumRequiredGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVersionListCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdempotencyTokenGenerator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdentityConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.InterceptorConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.MetadataCustomization
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization
@ -66,8 +65,7 @@ class RequiredCustomizations : ClientCodegenDecorator {
baseCustomizations +
ResiliencyConfigCustomization(codegenContext) +
InterceptorConfigCustomization(codegenContext) +
TimeSourceCustomization(codegenContext) +
IdentityConfigCustomization(codegenContext)
TimeSourceCustomization(codegenContext)
} else {
baseCustomizations +
ResiliencyConfigCustomization(codegenContext) +

View File

@ -28,18 +28,20 @@ internal class EndpointConfigCustomization(
private val runtimeMode = codegenContext.smithyRuntimeMode
private val types = Types(runtimeConfig)
private val codegenScope = arrayOf(
*preludeScope,
"DefaultEndpointResolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::endpoints::DefaultEndpointResolver"),
"OldSharedEndpointResolver" to types.sharedEndpointResolver,
"Params" to typesGenerator.paramsStruct(),
"Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"),
"SharedEndpointResolver" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::SharedEndpointResolver"),
"SmithyResolver" to types.resolveEndpoint,
)
override fun section(section: ServiceConfig): Writable {
return writable {
val sharedEndpointResolver = "#{SharedEndpointResolver}<#{Params}>"
val sharedEndpointResolver = "#{OldSharedEndpointResolver}<#{Params}>"
val resolverTrait = "#{SmithyResolver}<#{Params}>"
val codegenScope = arrayOf(
*preludeScope,
"DefaultEndpointResolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::endpoints::DefaultEndpointResolver"),
"DynEndpointResolver" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::DynEndpointResolver"),
"SharedEndpointResolver" to types.sharedEndpointResolver,
"SmithyResolver" to types.resolveEndpoint,
"Params" to typesGenerator.paramsStruct(),
)
when (section) {
is ServiceConfig.ConfigStruct -> {
if (runtimeMode.defaultToMiddleware) {
@ -55,8 +57,8 @@ internal class EndpointConfigCustomization(
rustTemplate(
"""
/// Returns the endpoint resolver.
pub fn endpoint_resolver(&self) -> $sharedEndpointResolver {
self.inner.load::<$sharedEndpointResolver>().expect("endpoint resolver should be set").clone()
pub fn endpoint_resolver(&self) -> #{SharedEndpointResolver} {
self.runtime_components.endpoint_resolver().expect("resolver defaulted if not set")
}
""",
*codegenScope,
@ -128,7 +130,7 @@ internal class EndpointConfigCustomization(
/// Sets the endpoint resolver to use when making requests.
$defaultResolverDocs
pub fn endpoint_resolver(mut self, endpoint_resolver: impl $resolverTrait + 'static) -> Self {
self.set_endpoint_resolver(#{Some}(#{SharedEndpointResolver}::new(endpoint_resolver)));
self.set_endpoint_resolver(#{Some}(#{OldSharedEndpointResolver}::new(endpoint_resolver)));
self
}
@ -144,7 +146,7 @@ internal class EndpointConfigCustomization(
rustTemplate(
"""
pub fn set_endpoint_resolver(&mut self, endpoint_resolver: #{Option}<$sharedEndpointResolver>) -> &mut Self {
self.inner.store_or_unset(endpoint_resolver);
self.config.store_or_unset(endpoint_resolver);
self
}
""",
@ -164,96 +166,29 @@ internal class EndpointConfigCustomization(
}
ServiceConfig.BuilderBuild -> {
val defaultResolver = typesGenerator.defaultResolver()
if (defaultResolver != null) {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
// TODO(enableNewSmithyRuntimeCleanup): Simplify the endpoint resolvers
"""
let endpoint_resolver = #{DynEndpointResolver}::new(
#{DefaultEndpointResolver}::<#{Params}>::new(
layer.load::<$sharedEndpointResolver>().cloned().unwrap_or_else(||
#{SharedEndpointResolver}::new(#{DefaultResolver}::new())
)
)
);
layer.set_endpoint_resolver(endpoint_resolver);
""",
*codegenScope,
"DefaultResolver" to defaultResolver,
)
} else {
rustTemplate(
"""
endpoint_resolver: self.endpoint_resolver.unwrap_or_else(||
#{SharedEndpointResolver}::new(#{DefaultResolver}::new())
),
""",
*codegenScope,
"DefaultResolver" to defaultResolver,
)
}
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"#{set_endpoint_resolver}(&mut resolver);",
"set_endpoint_resolver" to setEndpointResolverFn(),
)
} else {
val alwaysFailsResolver =
RuntimeType.forInlineFun("MissingResolver", ClientRustModule.endpoint(codegenContext)) {
rustTemplate(
"""
##[derive(Debug)]
pub(crate) struct MissingResolver;
impl<T> #{ResolveEndpoint}<T> for MissingResolver {
fn resolve_endpoint(&self, _params: &T) -> #{Result} {
Err(#{ResolveEndpointError}::message("an endpoint resolver must be provided."))
}
}
""",
"ResolveEndpoint" to types.resolveEndpoint,
"ResolveEndpointError" to types.resolveEndpointError,
"Result" to types.smithyHttpEndpointModule.resolve("Result"),
)
}
// To keep this diff under control, rather than `.expect` here, insert a resolver that will
// always fail. In the future, this will be changed to an `expect()`
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
let endpoint_resolver = #{DynEndpointResolver}::new(
#{DefaultEndpointResolver}::<#{Params}>::new(
layer.load::<$sharedEndpointResolver>().cloned().unwrap_or_else(||
#{SharedEndpointResolver}::new(#{FailingResolver})
).clone()
)
);
layer.set_endpoint_resolver(endpoint_resolver);
""",
*codegenScope,
"FailingResolver" to alwaysFailsResolver,
)
} else {
rustTemplate(
"""
endpoint_resolver: self.endpoint_resolver.unwrap_or_else(||#{SharedEndpointResolver}::new(#{FailingResolver})),
""",
*codegenScope,
"FailingResolver" to alwaysFailsResolver,
)
}
rustTemplate(
"""
endpoint_resolver: self.endpoint_resolver.unwrap_or_else(||
#{OldSharedEndpointResolver}::new(#{DefaultResolver}::new())
),
""",
*codegenScope,
"DefaultResolver" to defaultResolver(),
)
}
}
is ServiceConfig.OperationConfigOverride -> {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
if let #{Some}(resolver) = layer
.load::<$sharedEndpointResolver>()
.cloned()
{
let endpoint_resolver = #{DynEndpointResolver}::new(
#{DefaultEndpointResolver}::<#{Params}>::new(resolver));
layer.set_endpoint_resolver(endpoint_resolver);
}
""",
*codegenScope,
"#{set_endpoint_resolver}(&mut resolver);",
"set_endpoint_resolver" to setEndpointResolverFn(),
)
}
}
@ -262,4 +197,58 @@ internal class EndpointConfigCustomization(
}
}
}
private fun defaultResolver(): RuntimeType {
// For now, fallback to a default endpoint resolver that always fails. In the future,
// the endpoint resolver will be required (so that it can be unwrapped).
return typesGenerator.defaultResolver() ?: RuntimeType.forInlineFun(
"MissingResolver",
ClientRustModule.endpoint(codegenContext),
) {
rustTemplate(
"""
##[derive(Debug)]
pub(crate) struct MissingResolver;
impl MissingResolver {
pub(crate) fn new() -> Self { Self }
}
impl<T> #{ResolveEndpoint}<T> for MissingResolver {
fn resolve_endpoint(&self, _params: &T) -> #{Result} {
Err(#{ResolveEndpointError}::message("an endpoint resolver must be provided."))
}
}
""",
"ResolveEndpoint" to types.resolveEndpoint,
"ResolveEndpointError" to types.resolveEndpointError,
"Result" to types.smithyHttpEndpointModule.resolve("Result"),
)
}
}
private fun setEndpointResolverFn(): RuntimeType = RuntimeType.forInlineFun("set_endpoint_resolver", ClientRustModule.config) {
// TODO(enableNewSmithyRuntimeCleanup): Simplify the endpoint resolvers
rustTemplate(
"""
fn set_endpoint_resolver(resolver: &mut #{Resolver}<'_>) {
let endpoint_resolver = if resolver.is_initial() {
Some(resolver.resolve_config::<#{OldSharedEndpointResolver}<#{Params}>>().cloned().unwrap_or_else(||
#{OldSharedEndpointResolver}::new(#{DefaultResolver}::new())
))
} else if resolver.is_latest_set::<#{OldSharedEndpointResolver}<#{Params}>>() {
resolver.resolve_config::<#{OldSharedEndpointResolver}<#{Params}>>().cloned()
} else {
None
};
if let Some(endpoint_resolver) = endpoint_resolver {
let shared = #{SharedEndpointResolver}::new(
#{DefaultEndpointResolver}::<#{Params}>::new(endpoint_resolver)
);
resolver.runtime_components_mut().set_endpoint_resolver(#{Some}(shared));
}
}
""",
*codegenScope,
"DefaultResolver" to defaultResolver(),
)
}
}

View File

@ -23,12 +23,15 @@ class ConfigOverrideRuntimePluginGenerator(
val smithyTypes = RuntimeType.smithyTypes(rc)
arrayOf(
*RuntimeType.preludeScope,
"Cow" to RuntimeType.Cow,
"CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"),
"ConfigBagAccessors" to runtimeApi.resolve("client::config_bag_accessors::ConfigBagAccessors"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
"Resolver" to RuntimeType.smithyRuntime(rc).resolve("client::config_override::Resolver"),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc),
"RuntimePlugin" to RuntimeType.runtimePlugin(rc),
)
}
@ -41,31 +44,40 @@ class ConfigOverrideRuntimePluginGenerator(
/// In the case of default values requested, they will be obtained from `client_config`.
##[derive(Debug)]
pub(crate) struct ConfigOverrideRuntimePlugin {
pub(crate) config_override: Builder,
pub(crate) client_config: #{FrozenLayer},
pub(crate) config: #{FrozenLayer},
pub(crate) components: #{RuntimeComponentsBuilder},
}
impl ConfigOverrideRuntimePlugin {
pub(crate) fn new(
config_override: Builder,
initial_config: #{FrozenLayer},
initial_components: &#{RuntimeComponentsBuilder}
) -> Self {
let mut layer = #{Layer}::from(config_override.config)
.with_name("$moduleUseName::config::ConfigOverrideRuntimePlugin");
let mut components = config_override.runtime_components;
let mut resolver = #{Resolver}::overrid(initial_config, initial_components, &mut layer, &mut components);
#{config}
let _ = resolver;
Self {
config: layer.freeze(),
components,
}
}
}
impl #{RuntimePlugin} for ConfigOverrideRuntimePlugin {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
use #{ConfigBagAccessors};
##[allow(unused_mut)]
let layer: #{Layer} = self
.config_override
.inner
.clone()
.into();
let mut layer = layer.with_name("$moduleUseName::config::ConfigOverrideRuntimePlugin");
#{config}
#{Some}(layer.freeze())
Some(self.config.clone())
}
fn interceptors(&self, _interceptors: &mut #{InterceptorRegistrar}) {
#{interceptors}
fn runtime_components(&self) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
#{Cow}::Borrowed(&self.components)
}
}
""",
*codegenScope,
"config" to writable {
@ -74,9 +86,6 @@ class ConfigOverrideRuntimePluginGenerator(
ServiceConfig.OperationConfigOverride("layer"),
)
},
"interceptors" to writable {
writeCustomizations(customizations, ServiceConfig.RuntimePluginInterceptors("_interceptors", "self.config_override"))
},
)
}
}

View File

@ -107,14 +107,17 @@ sealed class OperationSection(name: String) : Section(name) {
data class AdditionalInterceptors(
override val customizations: List<OperationCustomization>,
val interceptorRegistrarName: String,
val operationShape: OperationShape,
) : OperationSection("AdditionalInterceptors") {
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
val smithyRuntimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig)
writer.rustTemplate(
"""
$interceptorRegistrarName.register(#{SharedInterceptor}::new(#{interceptor}) as _);
.with_interceptor(
#{SharedInterceptor}::new(
#{interceptor}
) as _
)
""",
"interceptor" to interceptor,
"SharedInterceptor" to smithyRuntimeApi.resolve("client::interceptors::SharedInterceptor"),

View File

@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.rustlang.isNotEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
@ -118,6 +119,12 @@ open class OperationGenerator(
"SdkError" to RuntimeType.sdkError(runtimeConfig),
)
if (codegenContext.smithyRuntimeMode.generateOrchestrator) {
val additionalPlugins = writable {
writeCustomizations(
operationCustomizations,
OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape),
)
}
rustTemplate(
"""
pub(crate) async fn orchestrate(
@ -159,14 +166,16 @@ open class OperationGenerator(
config_override: #{Option}<crate::config::Builder>,
) -> #{RuntimePlugins} {
let mut runtime_plugins = client_runtime_plugins.with_operation_plugin(Self::new());
#{additional_runtime_plugins}
if let Some(config_override) = config_override {
runtime_plugins = runtime_plugins.with_operation_plugin(crate::config::ConfigOverrideRuntimePlugin {
config_override,
client_config: #{RuntimePlugin}::config(client_config).expect("frozen layer should exist in client config"),
})
for plugin in config_override.runtime_plugins.iter().cloned() {
runtime_plugins = runtime_plugins.with_operation_plugin(plugin);
}
runtime_plugins = runtime_plugins.with_operation_plugin(
crate::config::ConfigOverrideRuntimePlugin::new(config_override, client_config.config.clone(), &client_config.runtime_components)
);
}
runtime_plugins
#{additional_runtime_plugins}
}
""",
*codegenScope,
@ -179,10 +188,15 @@ open class OperationGenerator(
"StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"),
"invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::invoke_with_stop_point"),
"additional_runtime_plugins" to writable {
writeCustomizations(
operationCustomizations,
OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape),
)
if (additionalPlugins.isNotEmpty()) {
rustTemplate(
"""
runtime_plugins = runtime_plugins
#{additional_runtime_plugins};
""",
"additional_runtime_plugins" to additionalPlugins,
)
}
},
)
}

View File

@ -34,17 +34,19 @@ class OperationRuntimePluginGenerator(
val runtimeApi = RuntimeType.smithyRuntimeApi(rc)
val smithyTypes = RuntimeType.smithyTypes(rc)
arrayOf(
*preludeScope,
"AuthOptionResolverParams" to runtimeApi.resolve("client::auth::AuthOptionResolverParams"),
"BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig),
"ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig),
"ConfigBagAccessors" to RuntimeType.configBagAccessors(codegenContext.runtimeConfig),
"DynAuthOptionResolver" to runtimeApi.resolve("client::auth::DynAuthOptionResolver"),
"Cow" to RuntimeType.Cow,
"SharedAuthOptionResolver" to runtimeApi.resolve("client::auth::SharedAuthOptionResolver"),
"DynResponseDeserializer" to runtimeApi.resolve("client::orchestrator::DynResponseDeserializer"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"RetryClassifiers" to runtimeApi.resolve("client::retries::RetryClassifiers"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
"RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(codegenContext.runtimeConfig),
"SharedRequestSerializer" to runtimeApi.resolve("client::orchestrator::SharedRequestSerializer"),
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
"StaticAuthOptionResolverParams" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolverParams"),
@ -68,22 +70,25 @@ class OperationRuntimePluginGenerator(
cfg.set_request_serializer(#{SharedRequestSerializer}::new(${operationStructName}RequestSerializer));
cfg.set_response_deserializer(#{DynResponseDeserializer}::new(${operationStructName}ResponseDeserializer));
// Retry classifiers are operation-specific because they need to downcast operation-specific error types.
let retry_classifiers = #{RetryClassifiers}::new()
#{retry_classifier_customizations};
cfg.set_retry_classifiers(retry_classifiers);
${"" /* TODO(IdentityAndAuth): Resolve auth parameters from input for services that need this */}
cfg.set_auth_option_resolver_params(#{AuthOptionResolverParams}::new(#{StaticAuthOptionResolverParams}::new()));
#{auth_options}
#{additional_config}
Some(cfg.freeze())
}
fn interceptors(&self, _interceptors: &mut #{InterceptorRegistrar}) {
#{interceptors}
fn runtime_components(&self) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
// Retry classifiers are operation-specific because they need to downcast operation-specific error types.
let retry_classifiers = #{RetryClassifiers}::new()
#{retry_classifier_customizations};
#{Cow}::Owned(
#{RuntimeComponentsBuilder}::new(${operationShape.id.name.dq()})
.with_retry_classifiers(Some(retry_classifiers))
#{auth_options}
#{interceptors}
)
}
}
@ -117,7 +122,7 @@ class OperationRuntimePluginGenerator(
"interceptors" to writable {
writeCustomizations(
customizations,
OperationSection.AdditionalInterceptors(customizations, "_interceptors", operationShape),
OperationSection.AdditionalInterceptors(customizations, operationShape),
)
},
)
@ -135,8 +140,12 @@ class OperationRuntimePluginGenerator(
option.schemeShapeId to option
}
withBlockTemplate(
"cfg.set_auth_option_resolver(#{DynAuthOptionResolver}::new(#{StaticAuthOptionResolver}::new(vec![",
"])));",
"""
.with_auth_option_resolver(#{Some}(
#{SharedAuthOptionResolver}::new(
#{StaticAuthOptionResolver}::new(vec![
""",
"]))))",
*codegenScope,
) {
var noSupportedAuthSchemes = true

View File

@ -54,7 +54,6 @@ class ServiceGenerator(
ServiceRuntimePluginGenerator(codegenContext)
.render(this, decorator.serviceRuntimePluginCustomizations(codegenContext, emptyList()))
serviceConfigGenerator.renderRuntimePluginImplForSelf(this)
ConfigOverrideRuntimePluginGenerator(codegenContext)
.render(this, decorator.configCustomizations(codegenContext, listOf()))
}

View File

@ -36,45 +36,37 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
fun putConfigValue(writer: RustWriter, value: Writable) {
writer.rust("$newLayerName.store_put(#T);", value)
}
fun registerHttpAuthScheme(writer: RustWriter, runtimeConfig: RuntimeConfig, authScheme: Writable) {
writer.rustTemplate(
"""
#{ConfigBagAccessors}::push_http_auth_scheme(
&mut $newLayerName,
#{auth_scheme}
);
""",
"ConfigBagAccessors" to RuntimeType.configBagAccessors(runtimeConfig),
"auth_scheme" to authScheme,
)
}
fun registerIdentityResolver(writer: RustWriter, runtimeConfig: RuntimeConfig, identityResolver: Writable) {
writer.rustTemplate(
"""
#{ConfigBagAccessors}::push_identity_resolver(
&mut $newLayerName,
#{identity_resolver}
);
""",
"ConfigBagAccessors" to RuntimeType.configBagAccessors(runtimeConfig),
"identity_resolver" to identityResolver,
)
}
}
data class RegisterInterceptor(val interceptorRegistrarName: String) : ServiceRuntimePluginSection("RegisterInterceptor") {
data class RegisterRuntimeComponents(val serviceConfigName: String) : ServiceRuntimePluginSection("RegisterRuntimeComponents") {
/** Generates the code to register an interceptor */
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
writer.rustTemplate(
"""
$interceptorRegistrarName.register(#{SharedInterceptor}::new(#{interceptor}) as _);
runtime_components.push_interceptor(#{SharedInterceptor}::new(#{interceptor}) as _);
""",
"interceptor" to interceptor,
"SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::SharedInterceptor"),
)
}
fun registerHttpAuthScheme(writer: RustWriter, authScheme: Writable) {
writer.rustTemplate(
"""
runtime_components.push_http_auth_scheme(#{auth_scheme});
""",
"auth_scheme" to authScheme,
)
}
fun registerIdentityResolver(writer: RustWriter, identityResolver: Writable) {
writer.rustTemplate(
"""
runtime_components.push_identity_resolver(#{identity_resolver});
""",
"identity_resolver" to identityResolver,
)
}
}
}
typealias ServiceRuntimePluginCustomization = NamedCustomization<ServiceRuntimePluginSection>
@ -92,12 +84,11 @@ class ServiceRuntimePluginGenerator(
*preludeScope,
"Arc" to RuntimeType.Arc,
"BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig),
"ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig),
"Cow" to RuntimeType.Cow,
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"ConfigBagAccessors" to RuntimeType.configBagAccessors(rc),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc),
"RuntimePlugin" to RuntimeType.runtimePlugin(rc),
)
}
@ -113,15 +104,15 @@ class ServiceRuntimePluginGenerator(
##[derive(Debug)]
pub(crate) struct ServiceRuntimePlugin {
config: #{Option}<#{FrozenLayer}>,
runtime_components: #{RuntimeComponentsBuilder},
}
impl ServiceRuntimePlugin {
pub fn new(_service_config: crate::config::Config) -> Self {
Self {
config: {
#{config}
},
}
let config = { #{config} };
let mut runtime_components = #{RuntimeComponentsBuilder}::new("ServiceRuntimePlugin");
#{runtime_components}
Self { config, runtime_components }
}
}
@ -130,9 +121,8 @@ class ServiceRuntimePluginGenerator(
self.config.clone()
}
fn interceptors(&self, interceptors: &mut #{InterceptorRegistrar}) {
let _interceptors = interceptors;
#{additional_interceptors}
fn runtime_components(&self) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
#{Cow}::Borrowed(&self.runtime_components)
}
}
@ -155,8 +145,8 @@ class ServiceRuntimePluginGenerator(
rust("None")
}
},
"additional_interceptors" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterInterceptor("_interceptors"))
"runtime_components" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterRuntimeComponents("_service_config"))
},
"declare_singletons" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.DeclareSingletons())

View File

@ -156,6 +156,7 @@ class CustomizableOperationGenerator(
"MutateRequestInterceptor" to RuntimeType.smithyRuntime(runtimeConfig)
.resolve("client::interceptors::MutateRequestInterceptor"),
"RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig),
"SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig),
"SendResult" to ClientRustModule.Client.customize.toType()
.resolve("internal::SendResult"),
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
@ -188,6 +189,7 @@ class CustomizableOperationGenerator(
pub(crate) customizable_send: #{Box}<dyn #{CustomizableSend}<T, E>>,
pub(crate) config_override: #{Option}<crate::config::Builder>,
pub(crate) interceptors: Vec<#{SharedInterceptor}>,
pub(crate) runtime_plugins: Vec<#{SharedRuntimePlugin}>,
}
impl<T, E> CustomizableOperation<T, E> {
@ -202,6 +204,13 @@ class CustomizableOperationGenerator(
self
}
/// Adds a runtime plugin.
##[allow(unused)]
pub(crate) fn runtime_plugin(mut self, runtime_plugin: impl #{RuntimePlugin} + 'static) -> Self {
self.runtime_plugins.push(#{SharedRuntimePlugin}::new(runtime_plugin));
self
}
/// Allows for customizing the operation's request.
pub fn map_request<F, MapE>(mut self, f: F) -> Self
where
@ -264,7 +273,10 @@ class CustomizableOperationGenerator(
};
self.interceptors.into_iter().for_each(|interceptor| {
config_override.add_interceptor(interceptor);
config_override.push_interceptor(interceptor);
});
self.runtime_plugins.into_iter().for_each(|plugin| {
config_override.push_runtime_plugin(plugin);
});
(self.customizable_send)(config_override).await

View File

@ -543,6 +543,7 @@ class FluentClientGenerator(
}),
config_override: None,
interceptors: vec![],
runtime_plugins: vec![],
}
}
""",
@ -652,18 +653,30 @@ private fun baseClientRuntimePluginsFn(runtimeConfig: RuntimeConfig): RuntimeTyp
rustTemplate(
"""
pub(crate) fn base_client_runtime_plugins(
config: crate::Config,
mut config: crate::Config,
) -> #{RuntimePlugins} {
#{RuntimePlugins}::new()
.with_client_plugin(config.clone())
let mut configured_plugins = #{Vec}::new();
::std::mem::swap(&mut config.runtime_plugins, &mut configured_plugins);
let mut plugins = #{RuntimePlugins}::new()
.with_client_plugin(
#{StaticRuntimePlugin}::new()
.with_config(config.config.clone())
.with_runtime_components(config.runtime_components.clone())
)
.with_client_plugin(crate::config::ServiceRuntimePlugin::new(config))
.with_client_plugin(#{NoAuthRuntimePlugin}::new())
.with_client_plugin(#{NoAuthRuntimePlugin}::new());
for plugin in configured_plugins {
plugins = plugins.with_client_plugin(plugin);
}
plugins
}
""",
*preludeScope,
"RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig),
"NoAuthRuntimePlugin" to RuntimeType.smithyRuntime(runtimeConfig)
.resolve("client::auth::no_auth::NoAuthRuntimePlugin"),
"StaticRuntimePlugin" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::runtime_plugin::StaticRuntimePlugin"),
)
}

View File

@ -42,7 +42,7 @@ class IdempotencyTokenProviderCustomization(codegenContext: ClientCodegenContext
/// If a random token provider was configured,
/// a newly-randomized token provider will be returned.
pub fn idempotency_token_provider(&self) -> #{IdempotencyTokenProvider} {
self.inner.load::<#{IdempotencyTokenProvider}>().expect("the idempotency provider should be set").clone()
self.config.load::<#{IdempotencyTokenProvider}>().expect("the idempotency provider should be set").clone()
}
""",
*codegenScope,
@ -85,7 +85,7 @@ class IdempotencyTokenProviderCustomization(codegenContext: ClientCodegenContext
"""
/// Sets the idempotency token provider to use for service calls that require tokens.
pub fn set_idempotency_token_provider(&mut self, idempotency_token_provider: #{Option}<#{IdempotencyTokenProvider}>) -> &mut Self {
self.inner.store_or_unset(idempotency_token_provider);
self.config.store_or_unset(idempotency_token_provider);
self
}
""",
@ -108,7 +108,11 @@ class IdempotencyTokenProviderCustomization(codegenContext: ClientCodegenContext
ServiceConfig.BuilderBuild -> writable {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"layer.store_put(layer.load::<#{IdempotencyTokenProvider}>().cloned().unwrap_or_else(#{default_provider}));",
"""
if !resolver.is_set::<#{IdempotencyTokenProvider}>() {
resolver.config_mut().store_put(#{default_provider}());
}
""",
*codegenScope,
)
} else {

View File

@ -29,7 +29,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
@ -101,11 +100,6 @@ sealed class ServiceConfig(name: String) : Section(name) {
*/
data class OperationConfigOverride(val cfg: String) : ServiceConfig("ToRuntimePlugin")
/**
* A section for appending additional runtime plugins, stored in [interceptorsField], to [interceptors]
*/
data class RuntimePluginInterceptors(val interceptors: String, val interceptorsField: String) : ServiceConfig("ToRuntimePluginInterceptors")
/**
* A section for extra functionality that needs to be defined with the config module
*/
@ -234,7 +228,7 @@ fun standardConfigParam(param: ConfigParam, codegenContext: ClientCodegenContext
rustTemplate(
"""
pub fn set_${param.name}(&mut self, ${param.name}: Option<#{T}>) -> &mut Self {
self.inner.store_or_unset(${param.name}.map(#{newtype}));
self.config.store_or_unset(${param.name}.map(#{newtype}));
self
}
""",
@ -261,8 +255,6 @@ fun standardConfigParam(param: ConfigParam, codegenContext: ClientCodegenContext
}
}
is ServiceConfig.OperationConfigOverride -> emptySection
else -> emptySection
}
}
@ -312,18 +304,20 @@ class ServiceConfigGenerator(
}
}
private val runtimeApi = RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig)
private val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig)
val codegenScope = arrayOf(
*preludeScope,
"BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig),
"CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"),
"ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig),
"ConfigBagAccessors" to RuntimeType.configBagAccessors(codegenContext.runtimeConfig),
"Cow" to RuntimeType.Cow,
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
*preludeScope,
"Resolver" to RuntimeType.smithyRuntime(codegenContext.runtimeConfig).resolve("client::config_override::Resolver"),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(codegenContext.runtimeConfig),
"RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig),
"SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(codegenContext.runtimeConfig),
)
private val moduleUseName = codegenContext.moduleUseName()
private val runtimeMode = codegenContext.smithyRuntimeMode
@ -334,10 +328,17 @@ class ServiceConfigGenerator(
it.section(ServiceConfig.ConfigStructAdditionalDocs)(writer)
}
Attribute(Attribute.derive(RuntimeType.Clone)).render(writer)
if (runtimeMode.generateOrchestrator) {
Attribute(Attribute.derive(RuntimeType.Debug)).render(writer)
}
writer.rustBlock("pub struct Config") {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"inner: #{FrozenLayer},",
"""
pub(crate) config: #{FrozenLayer},
pub(crate) runtime_components: #{RuntimeComponentsBuilder},
pub(crate) runtime_plugins: #{Vec}<#{SharedRuntimePlugin}>,
""",
*codegenScope,
)
}
@ -346,16 +347,18 @@ class ServiceConfigGenerator(
}
}
// Custom implementation for Debug so we don't need to enforce Debug down the chain
writer.rustBlock("impl std::fmt::Debug for Config") {
writer.rustTemplate(
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Config");
config.finish()
}
""",
)
if (runtimeMode.defaultToMiddleware) {
// Custom implementation for Debug so we don't need to enforce Debug down the chain
writer.rustBlock("impl std::fmt::Debug for Config") {
writer.rustTemplate(
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Config");
config.finish()
}
""",
)
}
}
writer.rustBlock("impl Config") {
@ -372,10 +375,17 @@ class ServiceConfigGenerator(
writer.docs("Builder for creating a `Config`.")
writer.raw("#[derive(Clone, Default)]")
if (runtimeMode.defaultToOrchestrator) {
Attribute(Attribute.derive(RuntimeType.Debug)).render(writer)
}
writer.rustBlock("pub struct Builder") {
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"inner: #{CloneableLayer},",
"""
pub(crate) config: #{CloneableLayer},
pub(crate) runtime_components: #{RuntimeComponentsBuilder},
pub(crate) runtime_plugins: #{Vec}<#{SharedRuntimePlugin}>,
""",
*codegenScope,
)
}
@ -384,16 +394,18 @@ class ServiceConfigGenerator(
}
}
// Custom implementation for Debug so we don't need to enforce Debug down the chain
writer.rustBlock("impl std::fmt::Debug for Builder") {
writer.rustTemplate(
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Builder");
config.finish()
}
""",
)
if (runtimeMode.defaultToMiddleware) {
// Custom implementation for Debug so we don't need to enforce Debug down the chain
writer.rustBlock("impl std::fmt::Debug for Builder") {
writer.rustTemplate(
"""
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut config = f.debug_struct("Builder");
config.finish()
}
""",
)
}
}
writer.rustBlock("impl Builder") {
@ -403,6 +415,27 @@ class ServiceConfigGenerator(
it.section(ServiceConfig.BuilderImpl)(this)
}
if (runtimeMode.defaultToOrchestrator) {
rustTemplate(
"""
/// Adds a runtime plugin to the config.
##[allow(unused)]
pub(crate) fn runtime_plugin(mut self, plugin: impl #{RuntimePlugin} + 'static) -> Self {
self.push_runtime_plugin(#{SharedRuntimePlugin}::new(plugin));
self
}
/// Adds a runtime plugin to the config.
##[allow(unused)]
pub(crate) fn push_runtime_plugin(&mut self, plugin: #{SharedRuntimePlugin}) -> &mut Self {
self.runtime_plugins.push(plugin);
self
}
""",
*codegenScope,
)
}
val testUtilOnly =
Attribute(Attribute.cfg(Attribute.any(Attribute.feature(TestUtilFeature.name), writable("test"))))
@ -427,16 +460,13 @@ class ServiceConfigGenerator(
rustBlock("pub fn build(mut self) -> Config") {
rustTemplate(
"""
##[allow(unused_imports)]
use #{ConfigBagAccessors};
// The builder is being turned into a service config. While doing so, we'd like to avoid
// requiring that items created and stored _during_ the build method be `Clone`, since they
// will soon be part of a `FrozenLayer` owned by the service config. So we will convert the
// current `CloneableLayer` into a `Layer` that does not impose the `Clone` requirement.
let layer: #{Layer} = self
.inner
.into();
let mut layer = layer.with_name("$moduleUseName::config::config");
let mut layer = #{Layer}::from(self.config).with_name("$moduleUseName::config::Config");
##[allow(unused)]
let mut resolver = #{Resolver}::initial(&mut layer, &mut self.runtime_components);
""",
*codegenScope,
)
@ -447,7 +477,13 @@ class ServiceConfigGenerator(
customizations.forEach {
it.section(ServiceConfig.BuilderBuildExtras)(this)
}
rust("inner: layer.freeze(),")
rust(
"""
config: layer.freeze(),
runtime_components: self.runtime_components,
runtime_plugins: self.runtime_plugins,
""",
)
}
}
} else {
@ -464,26 +500,4 @@ class ServiceConfigGenerator(
}
}
}
fun renderRuntimePluginImplForSelf(writer: RustWriter) {
writer.rustTemplate(
"""
impl #{RuntimePlugin} for Config {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
#{Some}(self.inner.clone())
}
fn interceptors(&self, _interceptors: &mut #{InterceptorRegistrar}) {
#{interceptors}
}
}
""",
*codegenScope,
"config" to writable { writeCustomizations(customizations, ServiceConfig.OperationConfigOverride("cfg")) },
"interceptors" to writable {
writeCustomizations(customizations, ServiceConfig.RuntimePluginInterceptors("_interceptors", "self"))
},
)
}
}

View File

@ -41,7 +41,7 @@ fun stubConfigCustomization(name: String, codegenContext: ClientCodegenContext):
"""
##[allow(missing_docs)]
pub fn $name(&self) -> u64 {
self.inner.load::<#{T}>().map(|u| u.0).unwrap()
self.config.load::<#{T}>().map(|u| u.0).unwrap()
}
""",
"T" to configParamNewtype(
@ -71,7 +71,7 @@ fun stubConfigCustomization(name: String, codegenContext: ClientCodegenContext):
"""
/// docs!
pub fn $name(mut self, $name: u64) -> Self {
self.inner.store_put(#{T}($name));
self.config.store_put(#{T}($name));
self
}
""",
@ -129,9 +129,6 @@ fun stubConfigProject(codegenContext: ClientCodegenContext, customization: Confi
val generator = ServiceConfigGenerator(codegenContext, customizations = customizations.toList())
project.withModule(ClientRustModule.config) {
generator.render(this)
if (codegenContext.smithyRuntimeMode.defaultToOrchestrator) {
generator.renderRuntimePluginImplForSelf(this)
}
unitTest(
"config_send_sync",
"""

View File

@ -48,6 +48,8 @@ class MetadataCustomizationTest {
"Interceptor" to RuntimeType.interceptor(runtimeConfig),
"Metadata" to RuntimeType.operationModule(runtimeConfig).resolve("Metadata"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
"RuntimeComponents" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::runtime_components::RuntimeComponents"),
)
rustCrate.testModule {
addDependency(CargoDependency.Tokio.withFeature("test-util").toDevDependency())
@ -64,6 +66,7 @@ class MetadataCustomizationTest {
fn modify_before_signing(
&self,
_context: &mut #{BeforeTransmitInterceptorContextMut}<'_>,
_runtime_components: &#{RuntimeComponents},
cfg: &mut #{ConfigBag},
) -> #{Result}<(), #{BoxError}> {
let metadata = cfg

View File

@ -47,16 +47,14 @@ class ClientContextConfigCustomizationTest {
use #{RuntimePlugin};
let conf = crate::Config::builder().a_string_param("hello!").a_bool_param(true).build();
assert_eq!(
conf.config()
.unwrap()
conf.config
.load::<crate::config::AStringParam>()
.map(|u| u.0.clone())
.unwrap(),
"hello!"
);
assert_eq!(
conf.config()
.unwrap()
conf.config
.load::<crate::config::ABoolParam>()
.map(|u| u.0),
Some(true)
@ -82,16 +80,14 @@ class ClientContextConfigCustomizationTest {
use #{RuntimePlugin};
let conf = crate::Config::builder().a_string_param("hello!").build();
assert_eq!(
conf.config()
.unwrap()
conf.config
.load::<crate::config::AStringParam>()
.map(|u| u.0.clone())
.unwrap(),
"hello!"
);
assert_eq!(
conf.config()
.unwrap()
conf.config
.load::<crate::config::ABoolParam>()
.map(|u| u.0),
None,

View File

@ -15,7 +15,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.runWithWarnings
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
/**
* End-to-end test of endpoint resolvers, attaching a real resolver to a fully generated service
@ -162,7 +162,7 @@ class EndpointsDecoratorTest {
}
}
// the model has an intentionally failing test—ensure it fails
val failure = shouldThrow<CommandFailed> { "cargo test".runWithWarnings(testDir) }
val failure = shouldThrow<CommandError> { "cargo test".runWithWarnings(testDir) }
failure.output shouldContain "endpoint::test::test_1"
failure.output shouldContain "https://failingtest.com"
"cargo clippy".runWithWarnings(testDir)

View File

@ -56,19 +56,20 @@ internal class ConfigOverrideRuntimePluginGeneratorTest {
tokioTest("test_operation_overrides_endpoint_resolver") {
rustTemplate(
"""
use #{ConfigBagAccessors};
use #{RuntimePlugin};
use ::aws_smithy_runtime_api::client::orchestrator::EndpointResolver;
let expected_url = "http://localhost:1234/";
let client_config = crate::config::Config::builder().build();
let config_override =
crate::config::Config::builder().endpoint_resolver(expected_url);
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config.config().unwrap(),
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override,
};
let sut_layer = sut.config().unwrap();
let endpoint_resolver = sut_layer.endpoint_resolver();
client_config.config,
&client_config.runtime_components,
);
let sut_components = sut.runtime_components();
let endpoint_resolver = sut_components.endpoint_resolver().unwrap();
let endpoint = endpoint_resolver
.resolve_endpoint(&#{EndpointResolverParams}::new(crate::config::endpoint::Params {}))
.await
@ -186,6 +187,7 @@ internal class ConfigOverrideRuntimePluginGeneratorTest {
.resolve("client::request_attempts::RequestAttempts"),
"RetryClassifiers" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::retries::RetryClassifiers"),
"RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig),
"RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig),
"ShouldAttempt" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::retries::ShouldAttempt"),
@ -195,46 +197,65 @@ internal class ConfigOverrideRuntimePluginGeneratorTest {
unitTest("test_operation_overrides_retry_strategy") {
rustTemplate(
"""
use #{ConfigBagAccessors};
use #{RuntimePlugin};
use ::aws_smithy_runtime_api::client::retries::RetryStrategy;
let client_config = crate::config::Config::builder()
.retry_config(#{RetryConfig}::standard().with_max_attempts(3))
.build();
let client_config_layer = client_config.config().unwrap();
let mut ctx = #{InterceptorContext}::new(#{TypeErasedBox}::new(()));
ctx.set_output_or_error(#{Err}(#{OrchestratorError}::other("doesn't matter")));
let mut layer = #{Layer}::new("test");
layer.store_put(#{RequestAttempts}::new(1));
layer.set_retry_classifiers(
#{RetryClassifiers}::new().with_classifier(#{AlwaysRetry}(#{ErrorKind}::TransientError)),
);
let mut cfg = #{ConfigBag}::of_layers(vec![layer]);
let client_config_layer = client_config.config;
cfg.push_shared_layer(client_config_layer.clone());
let retry = cfg.retry_strategy().unwrap();
let retry_classifiers_component = #{RuntimeComponentsBuilder}::new("retry_classifier")
.with_retry_classifiers(#{Some}(
#{RetryClassifiers}::new().with_classifier(#{AlwaysRetry}(#{ErrorKind}::TransientError)),
));
// Emulate the merging of runtime components from runtime plugins that the orchestrator does
let runtime_components = #{RuntimeComponentsBuilder}::for_tests()
.merge_from(&client_config.runtime_components)
.merge_from(&retry_classifiers_component)
.build()
.unwrap();
let retry = runtime_components.retry_strategy();
assert!(matches!(
retry.should_attempt_retry(&ctx, &cfg).unwrap(),
retry.should_attempt_retry(&ctx, &runtime_components, &cfg).unwrap(),
#{ShouldAttempt}::YesAfterDelay(_)
));
// sets `max_attempts` to 1 implicitly by using `disabled()`, forcing it to run out of
// attempts with respect to `RequestAttempts` set to 1 above
let config_override = crate::config::Config::builder()
let config_override_builder = crate::config::Config::builder()
.retry_config(#{RetryConfig}::disabled());
let sut = crate::config::ConfigOverrideRuntimePlugin {
client_config: client_config_layer,
config_override,
};
let config_override = config_override_builder.clone().build();
let sut = crate::config::ConfigOverrideRuntimePlugin::new(
config_override_builder,
client_config_layer,
&client_config.runtime_components,
);
let sut_layer = sut.config().unwrap();
cfg.push_shared_layer(sut_layer);
let retry = cfg.retry_strategy().unwrap();
// Emulate the merging of runtime components from runtime plugins that the orchestrator does
let runtime_components = #{RuntimeComponentsBuilder}::for_tests()
.merge_from(&client_config.runtime_components)
.merge_from(&retry_classifiers_component)
.merge_from(&config_override.runtime_components)
.build()
.unwrap();
let retry = runtime_components.retry_strategy();
assert!(matches!(
retry.should_attempt_retry(&ctx, &cfg).unwrap(),
retry.should_attempt_retry(&ctx, &runtime_components, &cfg).unwrap(),
#{ShouldAttempt}::No
));
""",

View File

@ -106,7 +106,7 @@ internal class ServiceConfigGeneratorTest {
"""
##[allow(missing_docs)]
pub fn config_field(&self) -> u64 {
self.inner.load::<#{T}>().map(|u| u.0).unwrap()
self.config.load::<#{T}>().map(|u| u.0).unwrap()
}
""",
"T" to configParamNewtype(
@ -137,7 +137,7 @@ internal class ServiceConfigGeneratorTest {
"""
##[allow(missing_docs)]
pub fn config_field(mut self, config_field: u64) -> Self {
self.inner.store_put(#{T}(config_field));
self.config.store_put(#{T}(config_field));
self
}
""",

View File

@ -31,7 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGenerat
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.outputShape
import java.nio.file.Path
@ -257,7 +257,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `test incorrect response parsing`() {
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?Hi=Hello%20there&required")
@ -273,7 +273,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `test invalid body`() {
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?Hi=Hello%20there&required")
@ -290,7 +290,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `test invalid url parameter`() {
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?Hi=INCORRECT&required")
@ -306,7 +306,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `test forbidden url parameter`() {
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?goodbye&Hi=Hello%20there&required")
@ -323,7 +323,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `test required url parameter`() {
// Hard coded implementation for this 1 test
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?Hi=Hello%20there")
@ -340,7 +340,7 @@ class ProtocolTestGeneratorTest {
@Test
fun `invalid header`() {
val err = assertThrows<CommandFailed> {
val err = assertThrows<CommandError> {
testService(
"""
.uri("/?Hi=Hello%20there&required")

View File

@ -338,8 +338,14 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
smithyTypes(runtimeConfig).resolve("config_bag::ConfigBag")
fun configBagAccessors(runtimeConfig: RuntimeConfig): RuntimeType =
smithyRuntimeApi(runtimeConfig).resolve("client::config_bag_accessors::ConfigBagAccessors")
fun runtimeComponentsBuilder(runtimeConfig: RuntimeConfig) =
smithyRuntimeApi(runtimeConfig).resolve("client::runtime_components::RuntimeComponentsBuilder")
fun runtimePlugins(runtimeConfig: RuntimeConfig): RuntimeType =
smithyRuntimeApi(runtimeConfig).resolve("client::runtime_plugin::RuntimePlugins")
fun runtimePlugin(runtimeConfig: RuntimeConfig) =
smithyRuntimeApi(runtimeConfig).resolve("client::runtime_plugin::RuntimePlugin")
fun sharedRuntimePlugin(runtimeConfig: RuntimeConfig) =
smithyRuntimeApi(runtimeConfig).resolve("client::runtime_plugin::SharedRuntimePlugin")
fun boxError(runtimeConfig: RuntimeConfig): RuntimeType =
smithyRuntimeApi(runtimeConfig).resolve("box_error::BoxError")
fun interceptor(runtimeConfig: RuntimeConfig): RuntimeType =
@ -468,8 +474,5 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
fun idempotencyToken(runtimeConfig: RuntimeConfig) =
forInlineDependency(InlineDependency.idempotencyToken(runtimeConfig))
fun runtimePlugin(runtimeConfig: RuntimeConfig) =
RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::runtime_plugin::RuntimePlugin")
}
}

View File

@ -32,7 +32,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNullIfEmpty
@ -389,7 +389,7 @@ fun RustWriter.compileAndTest(
println("Test sources for debugging: file://${testModule.absolutePath}")
}
return testOutput
} catch (e: CommandFailed) {
} catch (e: CommandError) {
if (!expectFailure) {
println("Test sources for debugging: file://${testModule.absolutePath}")
}

View File

@ -9,7 +9,7 @@ import java.io.IOException
import java.nio.file.Path
import java.util.concurrent.TimeUnit
data class CommandFailed(val output: String) : Exception("Command Failed\n$output")
data class CommandError(val output: String) : Exception("Command Error\n$output")
fun String.runCommand(workdir: Path? = null, environment: Map<String, String> = mapOf(), timeout: Long = 3600): String {
val parts = this.split("\\s".toRegex())
@ -30,13 +30,13 @@ fun String.runCommand(workdir: Path? = null, environment: Map<String, String> =
val output = "$stdErr\n$stdOut"
return when (proc.exitValue()) {
0 -> output
else -> throw CommandFailed("Command Failed\n$output")
else -> throw CommandError("Command Error\n$output")
}
} catch (_: IllegalThreadStateException) {
throw CommandFailed("Timeout")
throw CommandError("Timeout")
} catch (err: IOException) {
throw CommandFailed("$this was not a valid command.\n Hint: is everything installed?\n$err")
throw CommandError("$this was not a valid command.\n Hint: is everything installed?\n$err")
} catch (other: Exception) {
throw CommandFailed("Unexpected exception thrown when executing subprocess:\n$other")
throw CommandError("Unexpected exception thrown when executing subprocess:\n$other")
}
}

View File

@ -17,7 +17,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.lookup
class RecursiveShapesIntegrationTest {
@ -61,7 +61,7 @@ class RecursiveShapesIntegrationTest {
project
}
val unmodifiedProject = check(model)
val output = assertThrows<CommandFailed> {
val output = assertThrows<CommandError> {
unmodifiedProject.compileAndTest(expectFailure = true)
}
// THIS IS A LOAD-BEARING shouldContain! If the compiler error changes then this will break!

View File

@ -11,7 +11,7 @@ import org.junit.jupiter.api.Test
internal class ExecKtTest {
@Test
fun `missing command throws CommandFailed`() {
shouldThrow<CommandFailed> {
shouldThrow<CommandError> {
"notaprogram run".runCommand()
}
}

View File

@ -45,7 +45,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGenerat
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@ -266,7 +266,7 @@ open class ServerCodegenVisitor(
fileManifest.baseDir,
timeout = settings.codegenConfig.formatTimeoutSeconds.toLong(),
)
} catch (err: CommandFailed) {
} catch (err: CommandError) {
logger.info(
"[rust-server-codegen] Failed to run cargo fmt: [${service.id}]\n${err.output}",
)

View File

@ -20,7 +20,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.CommandError
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator
@ -237,7 +237,7 @@ class ConstrainedStringGeneratorTest {
).render()
}
assertThrows<CommandFailed> {
assertThrows<CommandError> {
project.compileAndTest()
}
}

View File

@ -12,7 +12,6 @@ rt-tokio = ["tokio/time"]
test-util = []
[dependencies]
aws-smithy-types = { path = "../aws-smithy-types" }
pin-project-lite = "0.2"
tokio = { version = "1.23.1", features = ["sync"] }
tokio-stream = { version = "0.1.5", default-features = false }

View File

@ -6,7 +6,6 @@
//! Provides an [`AsyncSleep`] trait that returns a future that sleeps for a given duration,
//! and implementations of `AsyncSleep` for different async runtimes.
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
@ -69,10 +68,6 @@ impl AsyncSleep for SharedAsyncSleep {
}
}
impl Storable for SharedAsyncSleep {
type Storer = StoreReplace<SharedAsyncSleep>;
}
#[cfg(feature = "rt-tokio")]
/// Returns a default sleep implementation based on the features enabled
pub fn default_async_sleep() -> Option<SharedAsyncSleep> {

View File

@ -4,7 +4,6 @@
*/
//! Time source abstraction to support WASM and testing
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use std::fmt::Debug;
use std::sync::Arc;
use std::time::SystemTime;
@ -87,7 +86,3 @@ impl TimeSource for SharedTimeSource {
self.0.now()
}
}
impl Storable for SharedTimeSource {
type Storer = StoreReplace<SharedTimeSource>;
}

View File

@ -23,7 +23,7 @@ allowed_external_types = [
"tokio::io::async_write::AsyncWrite",
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `test-utils` feature
# TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `test-util` feature
"bytes::bytes::Bytes",
"serde::ser::Serialize",
"serde::de::Deserialize",

View File

@ -3,6 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
pub mod runtime_components;
/// Client orchestrator configuration accessors for the [`ConfigBag`](aws_smithy_types::config_bag::ConfigBag).
pub mod config_bag_accessors;

View File

@ -4,9 +4,10 @@
*/
use crate::box_error::BoxError;
use crate::client::identity::{Identity, IdentityResolvers, SharedIdentityResolver};
use crate::client::identity::{Identity, SharedIdentityResolver};
use crate::client::orchestrator::HttpRequest;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreAppend, StoreReplace};
use crate::client::runtime_components::{GetIdentityResolver, RuntimeComponents};
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::type_erasure::{TypeErasedBox, TypedBox};
use aws_smithy_types::Document;
use std::borrow::Cow;
@ -66,20 +67,16 @@ pub trait AuthOptionResolver: Send + Sync + fmt::Debug {
) -> Result<Cow<'_, [AuthSchemeId]>, BoxError>;
}
#[derive(Debug)]
pub struct DynAuthOptionResolver(Box<dyn AuthOptionResolver>);
#[derive(Clone, Debug)]
pub struct SharedAuthOptionResolver(Arc<dyn AuthOptionResolver>);
impl DynAuthOptionResolver {
impl SharedAuthOptionResolver {
pub fn new(auth_option_resolver: impl AuthOptionResolver + 'static) -> Self {
Self(Box::new(auth_option_resolver))
Self(Arc::new(auth_option_resolver))
}
}
impl Storable for DynAuthOptionResolver {
type Storer = StoreReplace<Self>;
}
impl AuthOptionResolver for DynAuthOptionResolver {
impl AuthOptionResolver for SharedAuthOptionResolver {
fn resolve_auth_options(
&self,
params: &AuthOptionResolverParams,
@ -93,7 +90,7 @@ pub trait HttpAuthScheme: Send + Sync + fmt::Debug {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver>;
fn request_signer(&self) -> &dyn HttpRequestSigner;
@ -117,7 +114,7 @@ impl HttpAuthScheme for SharedHttpAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
self.0.identity_resolver(identity_resolvers)
}
@ -127,10 +124,6 @@ impl HttpAuthScheme for SharedHttpAuthScheme {
}
}
impl Storable for SharedHttpAuthScheme {
type Storer = StoreAppend<Self>;
}
pub trait HttpRequestSigner: Send + Sync + fmt::Debug {
/// Return a signed version of the given request using the given identity.
///
@ -140,6 +133,7 @@ pub trait HttpRequestSigner: Send + Sync + fmt::Debug {
request: &mut HttpRequest,
identity: &Identity,
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
runtime_components: &RuntimeComponents,
config_bag: &ConfigBag,
) -> Result<(), BoxError>;
}
@ -166,52 +160,3 @@ impl<'a> AuthSchemeEndpointConfig<'a> {
self.0
}
}
#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
#[test]
fn test_shared_http_auth_scheme_configuration() {
#[derive(Debug)]
struct TestHttpAuthScheme(&'static str);
impl HttpAuthScheme for TestHttpAuthScheme {
fn scheme_id(&self) -> AuthSchemeId {
AuthSchemeId::new(self.0)
}
fn identity_resolver(&self, _: &IdentityResolvers) -> Option<SharedIdentityResolver> {
unreachable!("this shouldn't get called in this test")
}
fn request_signer(&self) -> &dyn HttpRequestSigner {
unreachable!("this shouldn't get called in this test")
}
}
let mut config_bag = ConfigBag::base();
let mut layer = Layer::new("first");
layer.store_append(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_1")));
config_bag.push_layer(layer);
let mut layer = Layer::new("second");
layer.store_append(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_2")));
layer.store_append(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_3")));
config_bag.push_layer(layer);
let auth_schemes = config_bag.load::<SharedHttpAuthScheme>();
let encountered_scheme_ids: Vec<AuthSchemeId> =
auth_schemes.map(|s| s.scheme_id()).collect();
assert_eq!(
vec![
AuthSchemeId::new("scheme_3"),
AuthSchemeId::new("scheme_2"),
AuthSchemeId::new("scheme_1")
],
encountered_scheme_ids
);
}
}

View File

@ -3,23 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
use crate::client::auth::{
AuthOptionResolver, AuthOptionResolverParams, AuthSchemeId, DynAuthOptionResolver,
SharedHttpAuthScheme,
};
use crate::client::connectors::{Connector, DynConnector};
use crate::client::identity::{
ConfiguredIdentityResolver, IdentityResolvers, SharedIdentityResolver,
};
use crate::client::auth::AuthOptionResolverParams;
use crate::client::orchestrator::{
DynEndpointResolver, DynResponseDeserializer, EndpointResolver, EndpointResolverParams,
LoadedRequestBody, ResponseDeserializer, SharedRequestSerializer, NOT_NEEDED,
DynResponseDeserializer, EndpointResolverParams, LoadedRequestBody, ResponseDeserializer,
SharedRequestSerializer, NOT_NEEDED,
};
use crate::client::retries::{DynRetryStrategy, RetryClassifiers, RetryStrategy};
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::{SharedTimeSource, TimeSource};
use aws_smithy_types::config_bag::{AppendItemIter, CloneableLayer, ConfigBag, FrozenLayer, Layer};
use std::fmt::Debug;
use aws_smithy_types::config_bag::{CloneableLayer, ConfigBag, FrozenLayer, Layer};
// Place traits in a private module so that they can be used in the public API without being a part of the public API.
mod internal {
@ -60,51 +49,6 @@ mod internal {
}
}
pub trait CloneableSettable {
fn store_put<T>(&mut self, value: T)
where
T: Storable<Storer = StoreReplace<T>> + Clone;
fn store_append<T>(&mut self, item: T)
where
T: Storable<Storer = StoreAppend<T>> + Clone;
}
impl<S> CloneableSettable for S
where
S: Settable,
{
fn store_put<T>(&mut self, value: T)
where
T: Storable<Storer = StoreReplace<T>> + Clone,
{
Settable::store_put(self, value);
}
fn store_append<T>(&mut self, item: T)
where
T: Storable<Storer = StoreAppend<T>> + Clone,
{
Settable::store_append(self, item);
}
}
impl CloneableSettable for CloneableLayer {
fn store_put<T>(&mut self, value: T)
where
T: Storable<Storer = StoreReplace<T>> + Clone,
{
CloneableLayer::store_put(self, value);
}
fn store_append<T>(&mut self, item: T)
where
T: Storable<Storer = StoreAppend<T>> + Clone,
{
CloneableLayer::store_append(self, item);
}
}
pub trait Gettable {
fn load<T: Storable>(&self) -> <T::Storer as Store>::ReturnedType<'_>;
}
@ -134,7 +78,7 @@ mod internal {
}
}
use internal::{CloneableSettable, Gettable, Settable};
use internal::{Gettable, Settable};
pub trait ConfigBagAccessors {
fn auth_option_resolver_params(&self) -> &AuthOptionResolverParams
@ -153,21 +97,6 @@ pub trait ConfigBagAccessors {
self.store_put::<AuthOptionResolverParams>(auth_option_resolver_params);
}
fn auth_option_resolver(&self) -> &dyn AuthOptionResolver
where
Self: Gettable,
{
self.load::<DynAuthOptionResolver>()
.expect("an auth option resolver must be set")
}
fn set_auth_option_resolver(&mut self, auth_option_resolver: DynAuthOptionResolver)
where
Self: Settable,
{
self.store_put::<DynAuthOptionResolver>(auth_option_resolver);
}
fn endpoint_resolver_params(&self) -> &EndpointResolverParams
where
Self: Gettable,
@ -183,73 +112,6 @@ pub trait ConfigBagAccessors {
self.store_put::<EndpointResolverParams>(endpoint_resolver_params);
}
fn endpoint_resolver(&self) -> &dyn EndpointResolver
where
Self: Gettable,
{
self.load::<DynEndpointResolver>()
.expect("an endpoint resolver must be set")
}
fn set_endpoint_resolver(&mut self, endpoint_resolver: DynEndpointResolver)
where
Self: Settable,
{
self.store_put::<DynEndpointResolver>(endpoint_resolver);
}
/// Returns the configured identity resolvers.
fn identity_resolvers(&self) -> IdentityResolvers
where
Self: Gettable,
{
IdentityResolvers::new(self.load::<ConfiguredIdentityResolver>())
}
/// Adds an identity resolver to the config.
fn push_identity_resolver(
&mut self,
auth_scheme_id: AuthSchemeId,
identity_resolver: SharedIdentityResolver,
) where
Self: CloneableSettable,
{
self.store_append::<ConfiguredIdentityResolver>(ConfiguredIdentityResolver::new(
auth_scheme_id,
identity_resolver,
));
}
fn connector(&self) -> &dyn Connector
where
Self: Gettable,
{
self.load::<DynConnector>().expect("missing connector")
}
fn set_connector(&mut self, connection: DynConnector)
where
Self: Settable,
{
self.store_put::<DynConnector>(connection);
}
/// Returns the configured HTTP auth schemes.
fn http_auth_schemes(&self) -> HttpAuthSchemes<'_>
where
Self: Gettable,
{
HttpAuthSchemes::new(self.load::<SharedHttpAuthScheme>())
}
/// Adds a HTTP auth scheme to the config.
fn push_http_auth_scheme(&mut self, auth_scheme: SharedHttpAuthScheme)
where
Self: Settable,
{
self.store_append::<SharedHttpAuthScheme>(auth_scheme);
}
fn request_serializer(&self) -> SharedRequestSerializer
where
Self: Gettable,
@ -279,63 +141,6 @@ pub trait ConfigBagAccessors {
self.store_put::<DynResponseDeserializer>(response_deserializer);
}
fn retry_classifiers(&self) -> &RetryClassifiers
where
Self: Gettable,
{
self.load::<RetryClassifiers>()
.expect("retry classifiers must be set")
}
fn set_retry_classifiers(&mut self, retry_classifiers: RetryClassifiers)
where
Self: Settable,
{
self.store_put::<RetryClassifiers>(retry_classifiers);
}
fn retry_strategy(&self) -> Option<&dyn RetryStrategy>
where
Self: Gettable,
{
self.load::<DynRetryStrategy>().map(|rs| rs as _)
}
fn set_retry_strategy(&mut self, retry_strategy: DynRetryStrategy)
where
Self: Settable,
{
self.store_put::<DynRetryStrategy>(retry_strategy);
}
fn request_time(&self) -> Option<SharedTimeSource>
where
Self: Gettable,
{
self.load::<SharedTimeSource>().cloned()
}
fn set_request_time(&mut self, time_source: impl TimeSource + 'static)
where
Self: Settable,
{
self.store_put::<SharedTimeSource>(SharedTimeSource::new(time_source));
}
fn sleep_impl(&self) -> Option<SharedAsyncSleep>
where
Self: Gettable,
{
self.load::<SharedAsyncSleep>().cloned()
}
fn set_sleep_impl(&mut self, async_sleep: Option<SharedAsyncSleep>)
where
Self: Settable,
{
if let Some(sleep_impl) = async_sleep {
self.store_put::<SharedAsyncSleep>(sleep_impl);
} else {
self.unset::<SharedAsyncSleep>();
}
}
fn loaded_request_body(&self) -> &LoadedRequestBody
where
Self: Gettable,
@ -354,90 +159,3 @@ impl ConfigBagAccessors for ConfigBag {}
impl ConfigBagAccessors for FrozenLayer {}
impl ConfigBagAccessors for CloneableLayer {}
impl ConfigBagAccessors for Layer {}
/// Accessor for HTTP auth schemes.
#[derive(Debug)]
pub struct HttpAuthSchemes<'a> {
inner: AppendItemIter<'a, SharedHttpAuthScheme>,
}
impl<'a> HttpAuthSchemes<'a> {
pub(crate) fn new(inner: AppendItemIter<'a, SharedHttpAuthScheme>) -> Self {
Self { inner }
}
/// Returns the HTTP auth scheme with the given ID, if there is one.
pub fn scheme(mut self, scheme_id: AuthSchemeId) -> Option<SharedHttpAuthScheme> {
use crate::client::auth::HttpAuthScheme;
self.inner
.find(|&scheme| scheme.scheme_id() == scheme_id)
.cloned()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::auth::{HttpAuthScheme, HttpRequestSigner};
use crate::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
#[test]
fn test_shared_http_auth_scheme_configuration() {
#[derive(Debug)]
struct TestHttpAuthScheme(&'static str);
impl HttpAuthScheme for TestHttpAuthScheme {
fn scheme_id(&self) -> AuthSchemeId {
AuthSchemeId::new(self.0)
}
fn identity_resolver(&self, _: &IdentityResolvers) -> Option<SharedIdentityResolver> {
unreachable!("this shouldn't get called in this test")
}
fn request_signer(&self) -> &dyn HttpRequestSigner {
unreachable!("this shouldn't get called in this test")
}
}
let mut config_bag = ConfigBag::base();
let mut layer = Layer::new("first");
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_1")));
config_bag.push_layer(layer);
let mut layer = Layer::new("second");
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_2")));
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(TestHttpAuthScheme("scheme_3")));
config_bag.push_layer(layer);
assert!(config_bag
.http_auth_schemes()
.scheme(AuthSchemeId::new("does-not-exist"))
.is_none());
assert_eq!(
AuthSchemeId::new("scheme_1"),
config_bag
.http_auth_schemes()
.scheme(AuthSchemeId::new("scheme_1"))
.unwrap()
.scheme_id()
);
assert_eq!(
AuthSchemeId::new("scheme_2"),
config_bag
.http_auth_schemes()
.scheme(AuthSchemeId::new("scheme_2"))
.unwrap()
.scheme_id()
);
assert_eq!(
AuthSchemeId::new("scheme_3"),
config_bag
.http_auth_schemes()
.scheme(AuthSchemeId::new("scheme_3"))
.unwrap()
.scheme_id()
);
}
}

View File

@ -4,28 +4,24 @@
*/
use crate::client::orchestrator::{BoxFuture, HttpRequest, HttpResponse};
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use std::fmt;
use std::sync::Arc;
pub trait Connector: Send + Sync + fmt::Debug {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse>;
}
#[derive(Debug)]
pub struct DynConnector(Box<dyn Connector>);
#[derive(Clone, Debug)]
pub struct SharedConnector(Arc<dyn Connector>);
impl DynConnector {
impl SharedConnector {
pub fn new(connection: impl Connector + 'static) -> Self {
Self(Box::new(connection))
Self(Arc::new(connection))
}
}
impl Connector for DynConnector {
impl Connector for SharedConnector {
fn call(&self, request: HttpRequest) -> BoxFuture<HttpResponse> {
(*self.0).call(request)
}
}
impl Storable for DynConnector {
type Storer = StoreReplace<Self>;
}

View File

@ -5,7 +5,7 @@
use crate::client::auth::AuthSchemeId;
use crate::client::orchestrator::Future;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreAppend, StoreReplace};
use aws_smithy_types::config_bag::ConfigBag;
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
@ -67,36 +67,6 @@ impl ConfiguredIdentityResolver {
}
}
impl Storable for ConfiguredIdentityResolver {
type Storer = StoreAppend<Self>;
}
#[derive(Clone, Debug, Default)]
pub struct IdentityResolvers {
identity_resolvers: Vec<ConfiguredIdentityResolver>,
}
impl Storable for IdentityResolvers {
type Storer = StoreReplace<IdentityResolvers>;
}
impl IdentityResolvers {
pub(crate) fn new<'a>(resolvers: impl Iterator<Item = &'a ConfiguredIdentityResolver>) -> Self {
let identity_resolvers: Vec<_> = resolvers.cloned().collect();
if identity_resolvers.is_empty() {
tracing::warn!("no identity resolvers available for this request");
}
Self { identity_resolvers }
}
pub fn identity_resolver(&self, scheme_id: AuthSchemeId) -> Option<SharedIdentityResolver> {
self.identity_resolvers
.iter()
.find(|pair| pair.scheme_id() == scheme_id)
.map(|pair| pair.identity_resolver())
}
}
#[derive(Clone)]
pub struct Identity {
data: Arc<dyn Any + Send + Sync>,

View File

@ -11,7 +11,8 @@ use crate::client::interceptors::context::{
BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
FinalizerInterceptorContextRef, InterceptorContext,
};
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreAppend, StoreReplace};
use crate::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::error::display::DisplayErrorContext;
use context::{Error, Input, Output};
use std::fmt;
@ -27,16 +28,28 @@ pub use error::InterceptorError;
macro_rules! interceptor_trait_fn {
($name:ident, $phase:ident, $docs:tt) => {
#[doc = $docs]
fn $name(&self, context: &$phase<'_>, cfg: &mut ConfigBag) -> Result<(), BoxError> {
fn $name(
&self,
context: &$phase<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
};
(mut $name:ident, $phase:ident, $docs:tt) => {
#[doc = $docs]
fn $name(&self, context: &mut $phase<'_>, cfg: &mut ConfigBag) -> Result<(), BoxError> {
fn $name(
&self,
context: &mut $phase<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
@ -54,29 +67,32 @@ macro_rules! interceptor_trait_fn {
/// to read in-flight request or response messages, or "read/write" hooks, which make it possible
/// to modify in-flight request or output messages.
pub trait Interceptor: fmt::Debug {
interceptor_trait_fn!(
read_before_execution,
BeforeSerializationInterceptorContextRef,
"
A hook called at the start of an execution, before the SDK
does anything else.
**When:** This will **ALWAYS** be called once per execution. The duration
between invocation of this hook and `after_execution` is very close
to full duration of the execution.
**Available Information:** The [InterceptorContext::input()] is
**ALWAYS** available. Other information **WILL NOT** be available.
**Error Behavior:** Errors raised by this hook will be stored
until all interceptors have had their `before_execution` invoked.
Other hooks will then be skipped and execution will jump to
`modify_before_completion` with the raised error as the
[InterceptorContext::output_or_error()]. If multiple
`before_execution` methods raise errors, the latest
will be used and earlier ones will be logged and dropped.
"
);
/// A hook called at the start of an execution, before the SDK
/// does anything else.
///
/// **When:** This will **ALWAYS** be called once per execution. The duration
/// between invocation of this hook and `after_execution` is very close
/// to full duration of the execution.
///
/// **Available Information:** The [InterceptorContext::input()] is
/// **ALWAYS** available. Other information **WILL NOT** be available.
///
/// **Error Behavior:** Errors raised by this hook will be stored
/// until all interceptors have had their `before_execution` invoked.
/// Other hooks will then be skipped and execution will jump to
/// `modify_before_completion` with the raised error as the
/// [InterceptorContext::output_or_error()]. If multiple
/// `before_execution` methods raise errors, the latest
/// will be used and earlier ones will be logged and dropped.
fn read_before_execution(
&self,
context: &BeforeSerializationInterceptorContextRef<'_>,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _cfg = cfg;
Ok(())
}
interceptor_trait_fn!(
mut modify_before_serialization,
@ -96,7 +112,6 @@ pub trait Interceptor: fmt::Debug {
later hooks. Other information **WILL NOT** be available.
**Error Behavior:** If errors are raised by this hook,
execution will jump to `modify_before_completion` with the raised
error as the [InterceptorContext::output_or_error()].
@ -478,9 +493,11 @@ pub trait Interceptor: fmt::Debug {
fn modify_before_attempt_completion(
&self,
context: &mut FinalizerInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
@ -510,9 +527,11 @@ pub trait Interceptor: fmt::Debug {
fn read_after_attempt(
&self,
context: &FinalizerInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
@ -540,9 +559,11 @@ pub trait Interceptor: fmt::Debug {
fn modify_before_completion(
&self,
context: &mut FinalizerInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
@ -568,9 +589,11 @@ pub trait Interceptor: fmt::Debug {
fn read_after_execution(
&self,
context: &FinalizerInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let _ctx = context;
let _rc = runtime_components;
let _cfg = cfg;
Ok(())
}
@ -607,18 +630,6 @@ impl SharedInterceptor {
}
}
/// A interceptor wrapper to conditionally enable the interceptor based on [`DisableInterceptor`]
struct ConditionallyEnabledInterceptor<'a>(&'a SharedInterceptor);
impl ConditionallyEnabledInterceptor<'_> {
fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Interceptor> {
if self.0.enabled(cfg) {
Some(self.0.as_ref())
} else {
None
}
}
}
impl AsRef<dyn Interceptor> for SharedInterceptor {
fn as_ref(&self) -> &(dyn Interceptor + 'static) {
self.interceptor.as_ref()
@ -632,45 +643,29 @@ impl Deref for SharedInterceptor {
}
}
impl Storable for SharedInterceptor {
type Storer = StoreAppend<SharedInterceptor>;
}
/// Collection of [`SharedInterceptor`] that allows for only registration
#[derive(Debug, Clone, Default)]
pub struct InterceptorRegistrar {
interceptors: Vec<SharedInterceptor>,
}
impl InterceptorRegistrar {
/// Register an interceptor with this `InterceptorRegistrar`.
///
/// When this `InterceptorRegistrar` is passed to an orchestrator, the orchestrator will run the
/// registered interceptor for all the "hooks" that it implements.
pub fn register(&mut self, interceptor: SharedInterceptor) {
self.interceptors.push(interceptor);
}
}
impl Extend<SharedInterceptor> for InterceptorRegistrar {
fn extend<T: IntoIterator<Item = SharedInterceptor>>(&mut self, iter: T) {
for interceptor in iter {
self.register(interceptor);
/// A interceptor wrapper to conditionally enable the interceptor based on [`DisableInterceptor`]
struct ConditionallyEnabledInterceptor(SharedInterceptor);
impl ConditionallyEnabledInterceptor {
fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Interceptor> {
if self.0.enabled(cfg) {
Some(self.0.as_ref())
} else {
None
}
}
}
#[derive(Debug, Clone, Default)]
pub struct Interceptors {
client_interceptors: InterceptorRegistrar,
operation_interceptors: InterceptorRegistrar,
#[derive(Debug)]
pub struct Interceptors<I> {
interceptors: I,
}
macro_rules! interceptor_impl_fn {
(mut $interceptor:ident) => {
pub fn $interceptor(
&self,
self,
ctx: &mut InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!(concat!(
@ -680,9 +675,11 @@ macro_rules! interceptor_impl_fn {
));
let mut result: Result<(), BoxError> = Ok(());
let mut ctx = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.$interceptor(&mut ctx, cfg) {
if let Err(new_error) =
interceptor.$interceptor(&mut ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
@ -695,15 +692,17 @@ macro_rules! interceptor_impl_fn {
};
(ref $interceptor:ident) => {
pub fn $interceptor(
&self,
self,
ctx: &InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
let mut result: Result<(), BoxError> = Ok(());
let ctx = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.$interceptor(&ctx, cfg) {
if let Err(new_error) = interceptor.$interceptor(&ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
@ -742,68 +741,31 @@ pub fn disable_interceptor<T: Interceptor>(cause: &'static str) -> DisableInterc
}
}
impl Interceptors {
pub fn new() -> Self {
Self::default()
impl<I> Interceptors<I>
where
I: Iterator<Item = SharedInterceptor>,
{
pub fn new(interceptors: I) -> Self {
Self { interceptors }
}
fn interceptors(&self) -> impl Iterator<Item = ConditionallyEnabledInterceptor<'_>> {
self.client_interceptors()
.chain(self.operation_interceptors())
fn into_iter(self) -> impl Iterator<Item = ConditionallyEnabledInterceptor> {
self.interceptors.map(ConditionallyEnabledInterceptor)
}
fn client_interceptors(&self) -> impl Iterator<Item = ConditionallyEnabledInterceptor<'_>> {
self.client_interceptors
.interceptors
.iter()
.map(ConditionallyEnabledInterceptor)
}
fn operation_interceptors(&self) -> impl Iterator<Item = ConditionallyEnabledInterceptor<'_>> {
self.operation_interceptors
.interceptors
.iter()
.map(ConditionallyEnabledInterceptor)
}
pub fn client_interceptors_mut(&mut self) -> &mut InterceptorRegistrar {
&mut self.client_interceptors
}
pub fn operation_interceptors_mut(&mut self) -> &mut InterceptorRegistrar {
&mut self.operation_interceptors
}
pub fn client_read_before_execution(
&self,
pub fn read_before_execution(
self,
operation: bool,
ctx: &InterceptorContext<Input, Output, Error>,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `client_read_before_execution` interceptors");
tracing::trace!(
"running {} `read_before_execution` interceptors",
if operation { "operation" } else { "client" }
);
let mut result: Result<(), BoxError> = Ok(());
let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into();
for interceptor in self.client_interceptors() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) {
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
result = Err(new_error);
}
}
}
result.map_err(InterceptorError::read_before_execution)
}
pub fn operation_read_before_execution(
&self,
ctx: &InterceptorContext<Input, Output, Error>,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `operation_read_before_execution` interceptors");
let mut result: Result<(), BoxError> = Ok(());
let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into();
for interceptor in self.operation_interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) {
if let Err(last_error) = result {
@ -832,16 +794,18 @@ impl Interceptors {
interceptor_impl_fn!(ref read_after_deserialization);
pub fn modify_before_attempt_completion(
&self,
self,
ctx: &mut InterceptorContext<Input, Output, Error>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `modify_before_attempt_completion` interceptors");
let mut result: Result<(), BoxError> = Ok(());
let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.modify_before_attempt_completion(&mut ctx, cfg)
if let Err(new_error) =
interceptor.modify_before_attempt_completion(&mut ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
@ -854,16 +818,19 @@ impl Interceptors {
}
pub fn read_after_attempt(
&self,
self,
ctx: &InterceptorContext<Input, Output, Error>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `read_after_attempt` interceptors");
let mut result: Result<(), BoxError> = Ok(());
let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.read_after_attempt(&ctx, cfg) {
if let Err(new_error) =
interceptor.read_after_attempt(&ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
@ -875,16 +842,19 @@ impl Interceptors {
}
pub fn modify_before_completion(
&self,
self,
ctx: &mut InterceptorContext<Input, Output, Error>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `modify_before_completion` interceptors");
let mut result: Result<(), BoxError> = Ok(());
let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.modify_before_completion(&mut ctx, cfg) {
if let Err(new_error) =
interceptor.modify_before_completion(&mut ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
@ -896,16 +866,19 @@ impl Interceptors {
}
pub fn read_after_execution(
&self,
self,
ctx: &InterceptorContext<Input, Output, Error>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), InterceptorError> {
tracing::trace!("running `read_after_execution` interceptors");
let mut result: Result<(), BoxError> = Ok(());
let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
for interceptor in self.interceptors() {
for interceptor in self.into_iter() {
if let Some(interceptor) = interceptor.if_enabled(cfg) {
if let Err(new_error) = interceptor.read_after_execution(&ctx, cfg) {
if let Err(new_error) =
interceptor.read_after_execution(&ctx, runtime_components, cfg)
{
if let Err(last_error) = result {
tracing::debug!("{}", DisplayErrorContext(&*last_error));
}
@ -917,35 +890,20 @@ impl Interceptors {
}
}
#[cfg(test)]
#[cfg(all(test, feature = "test-util"))]
mod tests {
use crate::client::interceptors::context::Input;
use crate::client::interceptors::{
disable_interceptor, BeforeTransmitInterceptorContextRef, BoxError, Interceptor,
InterceptorContext, InterceptorRegistrar, Interceptors, SharedInterceptor,
InterceptorContext, Interceptors, SharedInterceptor,
};
use crate::client::runtime_components::{RuntimeComponents, RuntimeComponentsBuilder};
use aws_smithy_types::config_bag::ConfigBag;
#[derive(Debug)]
struct TestInterceptor;
impl Interceptor for TestInterceptor {}
#[test]
fn register_interceptor() {
let mut registrar = InterceptorRegistrar::default();
registrar.register(SharedInterceptor::new(TestInterceptor));
assert_eq!(1, registrar.interceptors.len());
}
#[test]
fn bulk_register_interceptors() {
let mut registrar = InterceptorRegistrar::default();
let number_of_interceptors = 3;
let interceptors = vec![SharedInterceptor::new(TestInterceptor); number_of_interceptors];
registrar.extend(interceptors);
assert_eq!(number_of_interceptors, registrar.interceptors.len());
}
#[test]
fn test_disable_interceptors() {
#[derive(Debug)]
@ -954,42 +912,43 @@ mod tests {
fn read_before_transmit(
&self,
_context: &BeforeTransmitInterceptorContextRef<'_>,
_rc: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
Err("boom".into())
}
}
let mut interceptors = Interceptors::new();
let interceptors_vec = vec![
SharedInterceptor::new(PanicInterceptor),
SharedInterceptor::new(TestInterceptor),
];
interceptors
.client_interceptors_mut()
.extend(interceptors_vec);
let rc = RuntimeComponentsBuilder::for_tests()
.with_interceptor(SharedInterceptor::new(PanicInterceptor))
.with_interceptor(SharedInterceptor::new(TestInterceptor))
.build()
.unwrap();
let mut cfg = ConfigBag::base();
let interceptors = Interceptors::new(rc.interceptors());
assert_eq!(
interceptors
.interceptors()
.into_iter()
.filter(|i| i.if_enabled(&cfg).is_some())
.count(),
2
);
interceptors
.read_before_transmit(&InterceptorContext::new(Input::new(5)), &mut cfg)
Interceptors::new(rc.interceptors())
.read_before_transmit(&InterceptorContext::new(Input::new(5)), &rc, &mut cfg)
.expect_err("interceptor returns error");
cfg.interceptor_state()
.store_put(disable_interceptor::<PanicInterceptor>("test"));
assert_eq!(
interceptors
.interceptors()
Interceptors::new(rc.interceptors())
.into_iter()
.filter(|i| i.if_enabled(&cfg).is_some())
.count(),
1
);
// shouldn't error because interceptors won't run
interceptors
.read_before_transmit(&InterceptorContext::new(Input::new(5)), &mut cfg)
Interceptors::new(rc.interceptors())
.read_before_transmit(&InterceptorContext::new(Input::new(5)), &rc, &mut cfg)
.expect("interceptor is now disabled");
}
}

View File

@ -114,25 +114,21 @@ pub trait EndpointResolver: Send + Sync + fmt::Debug {
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Future<Endpoint>;
}
#[derive(Debug)]
pub struct DynEndpointResolver(Box<dyn EndpointResolver>);
#[derive(Clone, Debug)]
pub struct SharedEndpointResolver(Arc<dyn EndpointResolver>);
impl DynEndpointResolver {
impl SharedEndpointResolver {
pub fn new(endpoint_resolver: impl EndpointResolver + 'static) -> Self {
Self(Box::new(endpoint_resolver))
Self(Arc::new(endpoint_resolver))
}
}
impl EndpointResolver for DynEndpointResolver {
impl EndpointResolver for SharedEndpointResolver {
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Future<Endpoint> {
self.0.resolve_endpoint(params)
}
}
impl Storable for DynEndpointResolver {
type Storer = StoreReplace<Self>;
}
/// Informs the orchestrator on whether or not the request body needs to be loaded into memory before transmit.
///
/// This enum gets placed into the `ConfigBag` to change the orchestrator behavior.

View File

@ -4,7 +4,7 @@
*/
use crate::client::interceptors::context::InterceptorContext;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::config_bag::ConfigBag;
use std::fmt::Debug;
use std::time::Duration;
use tracing::trace;
@ -30,42 +30,50 @@ impl ShouldAttempt {
}
pub trait RetryStrategy: Send + Sync + Debug {
fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError>;
fn should_attempt_initial_request(
&self,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError>;
fn should_attempt_retry(
&self,
context: &InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError>;
}
#[derive(Debug)]
pub struct DynRetryStrategy(Box<dyn RetryStrategy>);
#[derive(Clone, Debug)]
pub struct SharedRetryStrategy(Arc<dyn RetryStrategy>);
impl DynRetryStrategy {
impl SharedRetryStrategy {
pub fn new(retry_strategy: impl RetryStrategy + 'static) -> Self {
Self(Box::new(retry_strategy))
Self(Arc::new(retry_strategy))
}
}
impl RetryStrategy for DynRetryStrategy {
fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
self.0.should_attempt_initial_request(cfg)
impl RetryStrategy for SharedRetryStrategy {
fn should_attempt_initial_request(
&self,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
self.0
.should_attempt_initial_request(runtime_components, cfg)
}
fn should_attempt_retry(
&self,
context: &InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
self.0.should_attempt_retry(context, cfg)
self.0
.should_attempt_retry(context, runtime_components, cfg)
}
}
impl Storable for DynRetryStrategy {
type Storer = StoreReplace<Self>;
}
#[non_exhaustive]
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum RetryReason {
@ -83,9 +91,9 @@ pub trait ClassifyRetry: Send + Sync + Debug {
fn name(&self) -> &'static str;
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct RetryClassifiers {
inner: Vec<Box<dyn ClassifyRetry>>,
inner: Vec<Arc<dyn ClassifyRetry>>,
}
impl RetryClassifiers {
@ -98,8 +106,7 @@ impl RetryClassifiers {
}
pub fn with_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
self.inner.push(Box::new(retry_classifier));
self.inner.push(Arc::new(retry_classifier));
self
}
@ -107,10 +114,6 @@ impl RetryClassifiers {
// pub fn map_classifiers(mut self, fun: Fn() -> RetryClassifiers)
}
impl Storable for RetryClassifiers {
type Storer = StoreReplace<Self>;
}
impl ClassifyRetry for RetryClassifiers {
fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
// return the first non-None result
@ -160,5 +163,7 @@ mod test_util {
}
use crate::box_error::BoxError;
use crate::client::runtime_components::RuntimeComponents;
use std::sync::Arc;
#[cfg(feature = "test-util")]
pub use test_util::AlwaysRetry;

View File

@ -0,0 +1,786 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use crate::client::auth::{
AuthSchemeId, HttpAuthScheme, SharedAuthOptionResolver, SharedHttpAuthScheme,
};
use crate::client::connectors::SharedConnector;
use crate::client::identity::{ConfiguredIdentityResolver, SharedIdentityResolver};
use crate::client::interceptors::SharedInterceptor;
use crate::client::orchestrator::SharedEndpointResolver;
use crate::client::retries::{RetryClassifiers, SharedRetryStrategy};
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use std::fmt;
pub(crate) static EMPTY_RUNTIME_COMPONENTS_BUILDER: RuntimeComponentsBuilder =
RuntimeComponentsBuilder::new("empty");
/// Internal to `declare_runtime_components!`.
///
/// Merges a field from one builder into another.
macro_rules! merge {
(Option $other:ident . $name:ident => $self:ident) => {
$self.$name = $other.$name.clone().or($self.$name.take());
};
(Vec $other:ident . $name:ident => $self:ident) => {
if !$other.$name.is_empty() {
$self.$name.extend($other.$name.iter().cloned());
}
};
}
/// Internal to `declare_runtime_components!`.
///
/// This is used when creating the builder's `build` method
/// to populate each individual field value. The `required`/`atLeastOneRequired`
/// validations are performed here.
macro_rules! builder_field_value {
(Option $self:ident . $name:ident) => {
$self.$name
};
(Option $self:ident . $name:ident required) => {
$self.$name.ok_or(BuildError(concat!(
"the `",
stringify!($name),
"` runtime component is required"
)))?
};
(Vec $self:ident . $name:ident) => {
$self.$name
};
(Vec $self:ident . $name:ident atLeastOneRequired) => {{
if $self.$name.is_empty() {
return Err(BuildError(concat!(
"at least one `",
stringify!($name),
"` runtime component is required"
)));
}
$self.$name
}};
}
/// Internal to `declare_runtime_components!`.
///
/// Converts the field type from `Option<T>` or `Vec<T>` into `Option<Tracked<T>>` or `Vec<Tracked<T>>` respectively.
/// Also removes the `Option` wrapper for required fields in the non-builder struct.
macro_rules! runtime_component_field_type {
(Option $inner_type:ident) => {
Option<Tracked<$inner_type>>
};
(Option $inner_type:ident required) => {
Tracked<$inner_type>
};
(Vec $inner_type:ident) => {
Vec<Tracked<$inner_type>>
};
(Vec $inner_type:ident atLeastOneRequired) => {
Vec<Tracked<$inner_type>>
};
}
/// Internal to `declare_runtime_components!`.
///
/// Converts an `$outer_type` into an empty instantiation for that type.
/// This is needed since `Default::default()` can't be used in a `const` function,
/// and `RuntimeComponentsBuilder::new()` is `const`.
macro_rules! empty_builder_value {
(Option) => {
None
};
(Vec) => {
Vec::new()
};
}
/// Macro to define the structs for both `RuntimeComponents` and `RuntimeComponentsBuilder`.
///
/// This is a macro in order to keep the fields consistent between the two, and to automatically
/// update the `merge_from` and `build` methods when new components are added.
///
/// It also facilitates unit testing since the overall mechanism can be unit tested with different
/// fields that are easy to check in tests (testing with real components makes it hard
/// to tell that the correct component was selected when merging builders).
///
/// # Example usage
///
/// The two identifiers after "fields for" become the names of the struct and builder respectively.
/// Following that, all the fields are specified. Fields MUST be wrapped in `Option` or `Vec`.
/// To make a field required in the non-builder struct, add `#[required]` for `Option` fields, or
/// `#[atLeastOneRequired]` for `Vec` fields.
///
/// ```no_compile
/// declare_runtime_components! {
/// fields for TestRc and TestRcBuilder {
/// some_optional_string: Option<String>,
///
/// some_optional_vec: Vec<String>,
///
/// #[required]
/// some_required_string: Option<String>,
///
/// #[atLeastOneRequired]
/// some_required_vec: Vec<String>,
/// }
/// }
/// ```
macro_rules! declare_runtime_components {
(fields for $rc_name:ident and $builder_name:ident {
$($(#[$option:ident])? $field_name:ident : $outer_type:ident<$inner_type:ident> ,)+
}) => {
/// Components that can only be set in runtime plugins that the orchestrator uses directly to call an operation.
#[derive(Clone, Debug)]
pub struct $rc_name {
$($field_name: runtime_component_field_type!($outer_type $inner_type $($option)?),)+
}
#[derive(Clone, Debug, Default)]
pub struct $builder_name {
builder_name: &'static str,
$($field_name: $outer_type<Tracked<$inner_type>>,)+
}
impl $builder_name {
/// Creates a new builder.
///
/// Since multiple builders are merged together to make the final [`RuntimeComponents`],
/// all components added by this builder are associated with the given `name` so that
/// the origin of a component can be easily found when debugging.
pub const fn new(name: &'static str) -> Self {
Self {
builder_name: name,
$($field_name: empty_builder_value!($outer_type),)+
}
}
/// Merge in components from another builder.
pub fn merge_from(mut self, other: &Self) -> Self {
$(merge!($outer_type other.$field_name => self);)+
self
}
/// Builds [`RuntimeComponents`] from this builder.
pub fn build(self) -> Result<$rc_name, BuildError> {
Ok($rc_name {
$($field_name: builder_field_value!($outer_type self.$field_name $($option)?),)+
})
}
}
};
}
declare_runtime_components! {
fields for RuntimeComponents and RuntimeComponentsBuilder {
#[required]
auth_option_resolver: Option<SharedAuthOptionResolver>,
// A connector is not required since a client could technically only be used for presigning
connector: Option<SharedConnector>,
#[required]
endpoint_resolver: Option<SharedEndpointResolver>,
#[atLeastOneRequired]
http_auth_schemes: Vec<SharedHttpAuthScheme>,
#[atLeastOneRequired]
identity_resolvers: Vec<ConfiguredIdentityResolver>,
interceptors: Vec<SharedInterceptor>,
retry_classifiers: Option<RetryClassifiers>,
#[required]
retry_strategy: Option<SharedRetryStrategy>,
time_source: Option<SharedTimeSource>,
sleep_impl: Option<SharedAsyncSleep>,
}
}
impl RuntimeComponents {
/// Returns a builder for runtime components.
pub fn builder() -> RuntimeComponentsBuilder {
Default::default()
}
/// Returns the auth option resolver.
pub fn auth_option_resolver(&self) -> SharedAuthOptionResolver {
self.auth_option_resolver.value.clone()
}
/// Returns the connector.
pub fn connector(&self) -> Option<SharedConnector> {
self.connector.as_ref().map(|s| s.value.clone())
}
/// Returns the endpoint resolver.
pub fn endpoint_resolver(&self) -> SharedEndpointResolver {
self.endpoint_resolver.value.clone()
}
/// Returns the requested auth scheme if it is set.
pub fn http_auth_scheme(&self, scheme_id: AuthSchemeId) -> Option<SharedHttpAuthScheme> {
self.http_auth_schemes
.iter()
.find(|s| s.value.scheme_id() == scheme_id)
.map(|s| s.value.clone())
}
/// Returns an iterator over the interceptors.
pub fn interceptors(&self) -> impl Iterator<Item = SharedInterceptor> + '_ {
self.interceptors.iter().map(|s| s.value.clone())
}
/// Returns the retry classifiers.
pub fn retry_classifiers(&self) -> Option<&RetryClassifiers> {
self.retry_classifiers.as_ref().map(|s| &s.value)
}
/// Returns the retry strategy.
pub fn retry_strategy(&self) -> SharedRetryStrategy {
self.retry_strategy.value.clone()
}
/// Returns the async sleep implementation.
pub fn sleep_impl(&self) -> Option<SharedAsyncSleep> {
self.sleep_impl.as_ref().map(|s| s.value.clone())
}
/// Returns the time source.
pub fn time_source(&self) -> Option<SharedTimeSource> {
self.time_source.as_ref().map(|s| s.value.clone())
}
}
impl RuntimeComponentsBuilder {
/// Returns the auth option resolver.
pub fn auth_option_resolver(&self) -> Option<SharedAuthOptionResolver> {
self.auth_option_resolver.as_ref().map(|s| s.value.clone())
}
/// Sets the auth option resolver.
pub fn set_auth_option_resolver(
&mut self,
auth_option_resolver: Option<SharedAuthOptionResolver>,
) -> &mut Self {
self.auth_option_resolver =
auth_option_resolver.map(|r| Tracked::new(self.builder_name, r));
self
}
/// Sets the auth option resolver.
pub fn with_auth_option_resolver(
mut self,
auth_option_resolver: Option<SharedAuthOptionResolver>,
) -> Self {
self.set_auth_option_resolver(auth_option_resolver);
self
}
/// Returns the connector.
pub fn connector(&self) -> Option<SharedConnector> {
self.connector.as_ref().map(|s| s.value.clone())
}
/// Sets the connector.
pub fn set_connector(&mut self, connector: Option<SharedConnector>) -> &mut Self {
self.connector = connector.map(|c| Tracked::new(self.builder_name, c));
self
}
/// Sets the connector.
pub fn with_connector(mut self, connector: Option<SharedConnector>) -> Self {
self.set_connector(connector);
self
}
/// Returns the endpoint resolver.
pub fn endpoint_resolver(&self) -> Option<SharedEndpointResolver> {
self.endpoint_resolver.as_ref().map(|s| s.value.clone())
}
/// Sets the endpoint resolver.
pub fn set_endpoint_resolver(
&mut self,
endpoint_resolver: Option<SharedEndpointResolver>,
) -> &mut Self {
self.endpoint_resolver = endpoint_resolver.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Sets the endpoint resolver.
pub fn with_endpoint_resolver(
mut self,
endpoint_resolver: Option<SharedEndpointResolver>,
) -> Self {
self.set_endpoint_resolver(endpoint_resolver);
self
}
/// Returns the HTTP auth schemes.
pub fn http_auth_schemes(&self) -> impl Iterator<Item = SharedHttpAuthScheme> + '_ {
self.http_auth_schemes.iter().map(|s| s.value.clone())
}
/// Adds a HTTP auth scheme.
pub fn push_http_auth_scheme(&mut self, auth_scheme: SharedHttpAuthScheme) -> &mut Self {
self.http_auth_schemes
.push(Tracked::new(self.builder_name, auth_scheme));
self
}
/// Adds a HTTP auth scheme.
pub fn with_http_auth_scheme(mut self, auth_scheme: SharedHttpAuthScheme) -> Self {
self.push_http_auth_scheme(auth_scheme);
self
}
/// Adds an identity resolver.
pub fn push_identity_resolver(
&mut self,
scheme_id: AuthSchemeId,
identity_resolver: SharedIdentityResolver,
) -> &mut Self {
self.identity_resolvers.push(Tracked::new(
self.builder_name,
ConfiguredIdentityResolver::new(scheme_id, identity_resolver),
));
self
}
/// Adds an identity resolver.
pub fn with_identity_resolver(
mut self,
scheme_id: AuthSchemeId,
identity_resolver: SharedIdentityResolver,
) -> Self {
self.push_identity_resolver(scheme_id, identity_resolver);
self
}
/// Returns the interceptors.
pub fn interceptors(&self) -> impl Iterator<Item = SharedInterceptor> + '_ {
self.interceptors.iter().map(|s| s.value.clone())
}
/// Adds all the given interceptors.
pub fn extend_interceptors(
&mut self,
interceptors: impl Iterator<Item = SharedInterceptor>,
) -> &mut Self {
self.interceptors
.extend(interceptors.map(|s| Tracked::new(self.builder_name, s)));
self
}
/// Adds an interceptor.
pub fn push_interceptor(&mut self, interceptor: SharedInterceptor) -> &mut Self {
self.interceptors
.push(Tracked::new(self.builder_name, interceptor));
self
}
/// Adds an interceptor.
pub fn with_interceptor(mut self, interceptor: SharedInterceptor) -> Self {
self.push_interceptor(interceptor);
self
}
/// Directly sets the interceptors and clears out any that were previously pushed.
pub fn set_interceptors(
&mut self,
interceptors: impl Iterator<Item = SharedInterceptor>,
) -> &mut Self {
self.interceptors.clear();
self.interceptors
.extend(interceptors.map(|s| Tracked::new(self.builder_name, s)));
self
}
/// Directly sets the interceptors and clears out any that were previously pushed.
pub fn with_interceptors(
mut self,
interceptors: impl Iterator<Item = SharedInterceptor>,
) -> Self {
self.set_interceptors(interceptors);
self
}
/// Returns the retry classifiers.
pub fn retry_classifiers(&self) -> Option<&RetryClassifiers> {
self.retry_classifiers.as_ref().map(|s| &s.value)
}
/// Sets the retry classifiers.
pub fn set_retry_classifiers(
&mut self,
retry_classifiers: Option<RetryClassifiers>,
) -> &mut Self {
self.retry_classifiers = retry_classifiers.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Sets the retry classifiers.
pub fn with_retry_classifiers(mut self, retry_classifiers: Option<RetryClassifiers>) -> Self {
self.retry_classifiers = retry_classifiers.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Returns the retry strategy.
pub fn retry_strategy(&self) -> Option<SharedRetryStrategy> {
self.retry_strategy.as_ref().map(|s| s.value.clone())
}
/// Sets the retry strategy.
pub fn set_retry_strategy(&mut self, retry_strategy: Option<SharedRetryStrategy>) -> &mut Self {
self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Sets the retry strategy.
pub fn with_retry_strategy(mut self, retry_strategy: Option<SharedRetryStrategy>) -> Self {
self.retry_strategy = retry_strategy.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Returns the async sleep implementation.
pub fn sleep_impl(&self) -> Option<SharedAsyncSleep> {
self.sleep_impl.as_ref().map(|s| s.value.clone())
}
/// Sets the async sleep implementation.
pub fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) -> &mut Self {
self.sleep_impl = sleep_impl.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Sets the async sleep implementation.
pub fn with_sleep_impl(mut self, sleep_impl: Option<SharedAsyncSleep>) -> Self {
self.sleep_impl = sleep_impl.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Returns the time source.
pub fn time_source(&self) -> Option<SharedTimeSource> {
self.time_source.as_ref().map(|s| s.value.clone())
}
/// Sets the time source.
pub fn set_time_source(&mut self, time_source: Option<SharedTimeSource>) -> &mut Self {
self.time_source = time_source.map(|s| Tracked::new(self.builder_name, s));
self
}
/// Sets the time source.
pub fn with_time_source(mut self, time_source: Option<SharedTimeSource>) -> Self {
self.time_source = time_source.map(|s| Tracked::new(self.builder_name, s));
self
}
}
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
struct Tracked<T> {
_origin: &'static str,
value: T,
}
impl<T> Tracked<T> {
fn new(origin: &'static str, value: T) -> Self {
Self {
_origin: origin,
value,
}
}
}
impl RuntimeComponentsBuilder {
/// Creates a runtime components builder with all the required components filled in with fake (panicking) implementations.
#[cfg(feature = "test-util")]
pub fn for_tests() -> Self {
use crate::client::auth::AuthOptionResolver;
use crate::client::connectors::Connector;
use crate::client::identity::Identity;
use crate::client::identity::IdentityResolver;
use crate::client::orchestrator::{EndpointResolver, EndpointResolverParams, Future};
use crate::client::retries::RetryStrategy;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::time::TimeSource;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
#[derive(Debug)]
struct FakeAuthOptionResolver;
impl AuthOptionResolver for FakeAuthOptionResolver {
fn resolve_auth_options(
&self,
_: &crate::client::auth::AuthOptionResolverParams,
) -> Result<std::borrow::Cow<'_, [AuthSchemeId]>, crate::box_error::BoxError>
{
unreachable!("fake auth option resolver must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeConnector;
impl Connector for FakeConnector {
fn call(
&self,
_: crate::client::orchestrator::HttpRequest,
) -> crate::client::orchestrator::BoxFuture<crate::client::orchestrator::HttpResponse>
{
unreachable!("fake connector must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeEndpointResolver;
impl EndpointResolver for FakeEndpointResolver {
fn resolve_endpoint(&self, _: &EndpointResolverParams) -> Future<Endpoint> {
unreachable!("fake endpoint resolver must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeHttpAuthScheme;
impl HttpAuthScheme for FakeHttpAuthScheme {
fn scheme_id(&self) -> AuthSchemeId {
AuthSchemeId::new("fake")
}
fn identity_resolver(
&self,
_: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
None
}
fn request_signer(&self) -> &dyn crate::client::auth::HttpRequestSigner {
unreachable!("fake http auth scheme must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeIdentityResolver;
impl IdentityResolver for FakeIdentityResolver {
fn resolve_identity(&self, _: &ConfigBag) -> Future<Identity> {
unreachable!("fake identity resolver must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeRetryStrategy;
impl RetryStrategy for FakeRetryStrategy {
fn should_attempt_initial_request(
&self,
_: &RuntimeComponents,
_: &ConfigBag,
) -> Result<crate::client::retries::ShouldAttempt, crate::box_error::BoxError>
{
unreachable!("fake retry strategy must be overridden for this test")
}
fn should_attempt_retry(
&self,
_: &crate::client::interceptors::context::InterceptorContext,
_: &RuntimeComponents,
_: &ConfigBag,
) -> Result<crate::client::retries::ShouldAttempt, crate::box_error::BoxError>
{
unreachable!("fake retry strategy must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeTimeSource;
impl TimeSource for FakeTimeSource {
fn now(&self) -> std::time::SystemTime {
unreachable!("fake time source must be overridden for this test")
}
}
#[derive(Debug)]
struct FakeSleep;
impl AsyncSleep for FakeSleep {
fn sleep(&self, _: std::time::Duration) -> aws_smithy_async::rt::sleep::Sleep {
unreachable!("fake sleep must be overridden for this test")
}
}
Self::new("aws_smithy_runtime_api::client::runtime_components::RuntimeComponentBuilder::for_tests")
.with_auth_option_resolver(Some(SharedAuthOptionResolver::new(FakeAuthOptionResolver)))
.with_connector(Some(SharedConnector::new(FakeConnector)))
.with_endpoint_resolver(Some(SharedEndpointResolver::new(FakeEndpointResolver)))
.with_http_auth_scheme(SharedHttpAuthScheme::new(FakeHttpAuthScheme))
.with_identity_resolver(AuthSchemeId::new("fake"), SharedIdentityResolver::new(FakeIdentityResolver))
.with_retry_classifiers(Some(RetryClassifiers::new()))
.with_retry_strategy(Some(SharedRetryStrategy::new(FakeRetryStrategy)))
.with_time_source(Some(SharedTimeSource::new(FakeTimeSource)))
.with_sleep_impl(Some(SharedAsyncSleep::new(FakeSleep)))
}
}
#[derive(Debug)]
pub struct BuildError(&'static str);
impl std::error::Error for BuildError {}
impl fmt::Display for BuildError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// A trait for retrieving a shared identity resolver.
///
/// This trait exists so that [`HttpAuthScheme::identity_resolver`](crate::client::auth::HttpAuthScheme::identity_resolver)
/// can have access to configured identity resolvers without having access to all the runtime components.
pub trait GetIdentityResolver: Send + Sync {
/// Returns the requested identity resolver if it is set.
fn identity_resolver(&self, scheme_id: AuthSchemeId) -> Option<SharedIdentityResolver>;
}
impl GetIdentityResolver for RuntimeComponents {
fn identity_resolver(&self, scheme_id: AuthSchemeId) -> Option<SharedIdentityResolver> {
self.identity_resolvers
.iter()
.find(|s| s.value.scheme_id() == scheme_id)
.map(|s| s.value.identity_resolver())
}
}
#[cfg(all(test, feature = "test-util"))]
mod tests {
use super::*;
#[test]
#[allow(unreachable_pub)]
#[allow(dead_code)]
fn the_builders_should_merge() {
declare_runtime_components! {
fields for TestRc and TestRcBuilder {
#[required]
some_required_string: Option<String>,
some_optional_string: Option<String>,
#[atLeastOneRequired]
some_required_vec: Vec<String>,
some_optional_vec: Vec<String>,
}
}
let builder1 = TestRcBuilder {
builder_name: "builder1",
some_required_string: Some(Tracked::new("builder1", "override_me".into())),
some_optional_string: Some(Tracked::new("builder1", "override_me optional".into())),
some_required_vec: vec![Tracked::new("builder1", "first".into())],
some_optional_vec: vec![Tracked::new("builder1", "first optional".into())],
};
let builder2 = TestRcBuilder {
builder_name: "builder2",
some_required_string: Some(Tracked::new("builder2", "override_me_too".into())),
some_optional_string: Some(Tracked::new("builder2", "override_me_too optional".into())),
some_required_vec: vec![Tracked::new("builder2", "second".into())],
some_optional_vec: vec![Tracked::new("builder2", "second optional".into())],
};
let builder3 = TestRcBuilder {
builder_name: "builder3",
some_required_string: Some(Tracked::new("builder3", "correct".into())),
some_optional_string: Some(Tracked::new("builder3", "correct optional".into())),
some_required_vec: vec![Tracked::new("builder3", "third".into())],
some_optional_vec: vec![Tracked::new("builder3", "third optional".into())],
};
let rc = TestRcBuilder::new("root")
.merge_from(&builder1)
.merge_from(&builder2)
.merge_from(&builder3)
.build()
.expect("success");
assert_eq!(
Tracked::new("builder3", "correct".to_string()),
rc.some_required_string
);
assert_eq!(
Some(Tracked::new("builder3", "correct optional".to_string())),
rc.some_optional_string
);
assert_eq!(
vec![
Tracked::new("builder1", "first".to_string()),
Tracked::new("builder2", "second".into()),
Tracked::new("builder3", "third".into())
],
rc.some_required_vec
);
assert_eq!(
vec![
Tracked::new("builder1", "first optional".to_string()),
Tracked::new("builder2", "second optional".into()),
Tracked::new("builder3", "third optional".into())
],
rc.some_optional_vec
);
}
#[test]
#[allow(unreachable_pub)]
#[allow(dead_code)]
#[should_panic(expected = "the `_some_string` runtime component is required")]
fn require_field_singular() {
declare_runtime_components! {
fields for TestRc and TestRcBuilder {
#[required]
_some_string: Option<String>,
}
}
let rc = TestRcBuilder::new("test").build().unwrap();
// Ensure the correct types were used
let _: Tracked<String> = rc._some_string;
}
#[test]
#[allow(unreachable_pub)]
#[allow(dead_code)]
#[should_panic(expected = "at least one `_some_vec` runtime component is required")]
fn require_field_plural() {
declare_runtime_components! {
fields for TestRc and TestRcBuilder {
#[atLeastOneRequired]
_some_vec: Vec<String>,
}
}
let rc = TestRcBuilder::new("test").build().unwrap();
// Ensure the correct types were used
let _: Vec<Tracked<String>> = rc._some_vec;
}
#[test]
#[allow(unreachable_pub)]
#[allow(dead_code)]
fn optional_fields_dont_panic() {
declare_runtime_components! {
fields for TestRc and TestRcBuilder {
_some_optional_string: Option<String>,
_some_optional_vec: Vec<String>,
}
}
let rc = TestRcBuilder::new("test").build().unwrap();
// Ensure the correct types were used
let _: Option<Tracked<String>> = rc._some_optional_string;
let _: Vec<Tracked<String>> = rc._some_optional_vec;
}
#[test]
fn building_test_builder_should_not_panic() {
let _ = RuntimeComponentsBuilder::for_tests().build(); // should not panic
}
}

View File

@ -4,31 +4,34 @@
*/
use crate::box_error::BoxError;
use crate::client::interceptors::InterceptorRegistrar;
use crate::client::runtime_components::{
RuntimeComponentsBuilder, EMPTY_RUNTIME_COMPONENTS_BUILDER,
};
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer};
use std::borrow::Cow;
use std::fmt::Debug;
use std::sync::Arc;
/// RuntimePlugin Trait
///
/// A RuntimePlugin is the unit of configuration for augmenting the SDK with new behavior
/// A RuntimePlugin is the unit of configuration for augmenting the SDK with new behavior.
///
/// Runtime plugins can set configuration and register interceptors.
/// Runtime plugins can register interceptors, set runtime components, and modify configuration.
pub trait RuntimePlugin: Debug + Send + Sync {
fn config(&self) -> Option<FrozenLayer> {
None
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
let _ = interceptors;
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&EMPTY_RUNTIME_COMPONENTS_BUILDER)
}
}
#[derive(Debug, Clone)]
struct SharedRuntimePlugin(Arc<dyn RuntimePlugin>);
pub struct SharedRuntimePlugin(Arc<dyn RuntimePlugin>);
impl SharedRuntimePlugin {
fn new(plugin: impl RuntimePlugin + 'static) -> Self {
pub fn new(plugin: impl RuntimePlugin + 'static) -> Self {
Self(Arc::new(plugin))
}
}
@ -38,8 +41,8 @@ impl RuntimePlugin for SharedRuntimePlugin {
self.0.config()
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
self.0.interceptors(interceptors)
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
self.0.runtime_components()
}
}
@ -68,33 +71,66 @@ impl RuntimePlugins {
pub fn apply_client_configuration(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
) -> Result<RuntimeComponentsBuilder, BoxError> {
tracing::trace!("applying client runtime plugins");
let mut builder = RuntimeComponentsBuilder::new("apply_client_configuration");
for plugin in self.client_plugins.iter() {
if let Some(layer) = plugin.config() {
cfg.push_shared_layer(layer);
}
plugin.interceptors(interceptors);
builder = builder.merge_from(&plugin.runtime_components());
}
Ok(())
Ok(builder)
}
pub fn apply_operation_configuration(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
) -> Result<RuntimeComponentsBuilder, BoxError> {
tracing::trace!("applying operation runtime plugins");
let mut builder = RuntimeComponentsBuilder::new("apply_operation_configuration");
for plugin in self.operation_plugins.iter() {
if let Some(layer) = plugin.config() {
cfg.push_shared_layer(layer);
}
plugin.interceptors(interceptors);
builder = builder.merge_from(&plugin.runtime_components());
}
Ok(builder)
}
}
Ok(())
#[derive(Default, Debug)]
pub struct StaticRuntimePlugin {
config: Option<FrozenLayer>,
runtime_components: Option<RuntimeComponentsBuilder>,
}
impl StaticRuntimePlugin {
pub fn new() -> Self {
Default::default()
}
pub fn with_config(mut self, config: FrozenLayer) -> Self {
self.config = Some(config);
self
}
pub fn with_runtime_components(mut self, runtime_components: RuntimeComponentsBuilder) -> Self {
self.runtime_components = Some(runtime_components);
self
}
}
impl RuntimePlugin for StaticRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
self.config.clone()
}
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
self.runtime_components
.as_ref()
.map(Cow::Borrowed)
.unwrap_or_else(|| RuntimePlugin::runtime_components(self))
}
}

View File

@ -14,6 +14,9 @@ pub mod auth;
/// By default, the orchestrator uses a connector provided by `hyper`.
pub mod connectors;
/// Utility to simplify config building for config and config overrides.
pub mod config_override;
/// The client orchestrator implementation
pub mod orchestrator;

View File

@ -13,10 +13,9 @@ use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::identity::http::{Login, Token};
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityResolvers, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::identity::{Identity, SharedIdentityResolver};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::{GetIdentityResolver, RuntimeComponents};
use aws_smithy_types::base64::encode;
use aws_smithy_types::config_bag::ConfigBag;
use http::header::HeaderName;
@ -59,7 +58,7 @@ impl HttpAuthScheme for ApiKeyAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -82,6 +81,7 @@ impl HttpRequestSigner for ApiKeySigner {
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let api_key = identity
@ -129,7 +129,7 @@ impl HttpAuthScheme for BasicAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -148,6 +148,7 @@ impl HttpRequestSigner for BasicAuthSigner {
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let login = identity
@ -187,7 +188,7 @@ impl HttpAuthScheme for BearerAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -206,6 +207,7 @@ impl HttpRequestSigner for BearerAuthSigner {
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let token = identity
@ -243,7 +245,7 @@ impl HttpAuthScheme for DigestAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -262,6 +264,7 @@ impl HttpRequestSigner for DigestAuthSigner {
_request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
unimplemented!(
@ -275,6 +278,7 @@ mod tests {
use super::*;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::identity::http::Login;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
#[test]
fn test_api_key_signing_headers() {
@ -283,6 +287,7 @@ mod tests {
location: ApiKeyLocation::Header,
name: "some-header-name".into(),
};
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let config_bag = ConfigBag::base();
let identity = Identity::new(Token::new("some-token", None), None);
let mut request = http::Request::builder()
@ -294,6 +299,7 @@ mod tests {
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&runtime_components,
&config_bag,
)
.expect("success");
@ -311,6 +317,7 @@ mod tests {
location: ApiKeyLocation::Query,
name: "some-query-name".into(),
};
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let config_bag = ConfigBag::base();
let identity = Identity::new(Token::new("some-token", None), None);
let mut request = http::Request::builder()
@ -322,6 +329,7 @@ mod tests {
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&runtime_components,
&config_bag,
)
.expect("success");
@ -335,6 +343,7 @@ mod tests {
#[test]
fn test_basic_auth() {
let signer = BasicAuthSigner;
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let config_bag = ConfigBag::base();
let identity = Identity::new(Login::new("Aladdin", "open sesame", None), None);
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
@ -344,6 +353,7 @@ mod tests {
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&runtime_components,
&config_bag,
)
.expect("success");
@ -358,6 +368,7 @@ mod tests {
let signer = BearerAuthSigner;
let config_bag = ConfigBag::base();
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let identity = Identity::new(Token::new("some-token", None), None);
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
signer
@ -365,6 +376,7 @@ mod tests {
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&runtime_components,
&config_bag,
)
.expect("success");

View File

@ -10,13 +10,14 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner, SharedHttpAuthScheme,
};
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityResolvers, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::identity::{Identity, SharedIdentityResolver};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::{
GetIdentityResolver, RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
use aws_smithy_types::config_bag::ConfigBag;
use std::borrow::Cow;
pub const NO_AUTH_SCHEME_ID: AuthSchemeId = AuthSchemeId::new("no_auth");
@ -26,7 +27,7 @@ pub const NO_AUTH_SCHEME_ID: AuthSchemeId = AuthSchemeId::new("no_auth");
/// a Smithy `@optionalAuth` trait.
#[non_exhaustive]
#[derive(Debug)]
pub struct NoAuthRuntimePlugin(FrozenLayer);
pub struct NoAuthRuntimePlugin(RuntimeComponentsBuilder);
impl Default for NoAuthRuntimePlugin {
fn default() -> Self {
@ -36,19 +37,20 @@ impl Default for NoAuthRuntimePlugin {
impl NoAuthRuntimePlugin {
pub fn new() -> Self {
let mut cfg = Layer::new("NoAuth");
cfg.push_identity_resolver(
NO_AUTH_SCHEME_ID,
SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
);
cfg.push_http_auth_scheme(SharedHttpAuthScheme::new(NoAuthScheme::new()));
Self(cfg.freeze())
Self(
RuntimeComponentsBuilder::new("NoAuthRuntimePlugin")
.with_identity_resolver(
NO_AUTH_SCHEME_ID,
SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
)
.with_http_auth_scheme(SharedHttpAuthScheme::new(NoAuthScheme::new())),
)
}
}
impl RuntimePlugin for NoAuthRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
Some(self.0.clone())
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.0)
}
}
@ -72,6 +74,7 @@ impl HttpRequestSigner for NoAuthSigner {
_request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
Ok(())
@ -85,7 +88,7 @@ impl HttpAuthScheme for NoAuthScheme {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(NO_AUTH_SCHEME_ID)
}

View File

@ -0,0 +1,254 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, Store, StoreReplace};
macro_rules! component {
($typ:ty, $accessor:ident, $latest_accessor:ident) => {
pub fn $accessor(&self) -> Option<$typ> {
fallback_component!(self, $typ, $accessor)
}
pub fn $latest_accessor(&self) -> Option<$typ> {
latest_component!(self, $typ, $accessor)
}
};
}
macro_rules! fallback_component {
($self:ident, $typ:ty, $accessor:ident) => {
match &$self.inner {
Inner::Initial(initial) => initial.components.$accessor(),
Inner::Override(overrid) => overrid
.components
.$accessor()
.or_else(|| overrid.initial_components.$accessor()),
}
};
}
macro_rules! latest_component {
($self:ident, $typ:ty, $accessor:ident) => {
match &$self.inner {
Inner::Initial(initial) => initial.components.$accessor(),
Inner::Override(overrid) => overrid.components.$accessor(),
}
};
}
struct Initial<'a> {
config: &'a mut Layer,
components: &'a mut RuntimeComponentsBuilder,
}
struct Override<'a> {
initial_config: FrozenLayer,
initial_components: &'a RuntimeComponentsBuilder,
config: &'a mut Layer,
components: &'a mut RuntimeComponentsBuilder,
}
enum Inner<'a> {
Initial(Initial<'a>),
Override(Override<'a>),
}
/// Utility to simplify config building and config overrides.
///
/// The resolver allows the same initialization logic to be reused
/// for both initial config and override config.
///
/// This resolver can be initialized to one of two modes:
/// 1. _Initial mode_: The resolver is being used in a service `Config` builder's `build()` method, and thus,
/// there is no config override at this point.
/// 2. _Override mode_: The resolver is being used by the `ConfigOverrideRuntimePlugin`'s constructor and needs
/// to incorporate both the original config and the given config override for this operation.
///
/// In all the methods on [`Resolver`], the term "latest" refers to the initial config when in _Initial mode_,
/// and to config override when in _Override mode_.
pub struct Resolver<'a> {
inner: Inner<'a>,
}
impl<'a> Resolver<'a> {
/// Construct a new [`Resolver`] in _initial mode_.
pub fn initial(config: &'a mut Layer, components: &'a mut RuntimeComponentsBuilder) -> Self {
Self {
inner: Inner::Initial(Initial { config, components }),
}
}
/// Construct a new [`Resolver`] in _override mode_.
pub fn overrid(
initial_config: FrozenLayer,
initial_components: &'a RuntimeComponentsBuilder,
config: &'a mut Layer,
components: &'a mut RuntimeComponentsBuilder,
) -> Self {
Self {
inner: Inner::Override(Override {
initial_config,
initial_components,
config,
components,
}),
}
}
/// Returns true if in _initial mode_.
pub fn is_initial(&self) -> bool {
matches!(self.inner, Inner::Initial(_))
}
/// Returns a mutable reference to the latest config.
pub fn config_mut(&mut self) -> &mut Layer {
match &mut self.inner {
Inner::Initial(initial) => initial.config,
Inner::Override(overrid) => overrid.config,
}
}
/// Returns a mutable reference to the latest runtime components.
pub fn runtime_components_mut(&mut self) -> &mut RuntimeComponentsBuilder {
match &mut self.inner {
Inner::Initial(initial) => initial.components,
Inner::Override(overrid) => overrid.components,
}
}
/// Returns true if the latest config has `T` set.
///
/// The "latest" is initial for `Resolver::Initial`, and override for `Resolver::Override`.
pub fn is_latest_set<T>(&self) -> bool
where
T: Storable<Storer = StoreReplace<T>>,
{
self.config().load::<T>().is_some()
}
/// Returns true if `T` is set anywhere.
pub fn is_set<T>(&self) -> bool
where
T: Storable<Storer = StoreReplace<T>>,
{
match &self.inner {
Inner::Initial(initial) => initial.config.load::<T>().is_some(),
Inner::Override(overrid) => {
overrid.initial_config.load::<T>().is_some() || overrid.config.load::<T>().is_some()
}
}
}
/// Resolves the value `T` with fallback
pub fn resolve_config<T>(&self) -> <T::Storer as Store>::ReturnedType<'_>
where
T: Storable<Storer = StoreReplace<T>>,
{
let mut maybe_value = self.config().load::<T>();
if maybe_value.is_none() {
// Try to fallback
if let Inner::Override(overrid) = &self.inner {
maybe_value = overrid.initial_config.load::<T>()
}
}
maybe_value
}
// Add additional component methods as needed
component!(SharedAsyncSleep, sleep_impl, latest_sleep_impl);
fn config(&self) -> &Layer {
match &self.inner {
Inner::Initial(initial) => initial.config,
Inner::Override(overrid) => overrid.config,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_types::config_bag::CloneableLayer;
#[derive(Clone, Debug)]
struct TestStorable(String);
impl Storable for TestStorable {
type Storer = StoreReplace<Self>;
}
#[test]
fn initial_mode_config() {
let mut config = Layer::new("test");
let mut components = RuntimeComponentsBuilder::new("test");
let mut resolver = Resolver::initial(&mut config, &mut components);
assert!(resolver.is_initial());
assert!(!resolver.is_latest_set::<TestStorable>());
assert!(!resolver.is_set::<TestStorable>());
assert!(resolver.resolve_config::<TestStorable>().is_none());
resolver.config_mut().store_put(TestStorable("test".into()));
assert!(resolver.is_latest_set::<TestStorable>());
assert!(resolver.is_set::<TestStorable>());
assert_eq!("test", resolver.resolve_config::<TestStorable>().unwrap().0);
}
#[test]
fn override_mode_config() {
let mut initial_config = CloneableLayer::new("initial");
let initial_components = RuntimeComponentsBuilder::new("initial");
let mut config = Layer::new("override");
let mut components = RuntimeComponentsBuilder::new("override");
let resolver = Resolver::overrid(
initial_config.clone().freeze(),
&initial_components,
&mut config,
&mut components,
);
assert!(!resolver.is_initial());
assert!(!resolver.is_latest_set::<TestStorable>());
assert!(!resolver.is_set::<TestStorable>());
assert!(resolver.resolve_config::<TestStorable>().is_none());
initial_config.store_put(TestStorable("test".into()));
let resolver = Resolver::overrid(
initial_config.clone().freeze(),
&initial_components,
&mut config,
&mut components,
);
assert!(!resolver.is_latest_set::<TestStorable>());
assert!(resolver.is_set::<TestStorable>());
assert_eq!("test", resolver.resolve_config::<TestStorable>().unwrap().0);
initial_config.unset::<TestStorable>();
config.store_put(TestStorable("test".into()));
let resolver = Resolver::overrid(
initial_config.clone().freeze(),
&initial_components,
&mut config,
&mut components,
);
assert!(resolver.is_latest_set::<TestStorable>());
assert!(resolver.is_set::<TestStorable>());
assert_eq!("test", resolver.resolve_config::<TestStorable>().unwrap().0);
initial_config.store_put(TestStorable("override me".into()));
config.store_put(TestStorable("override".into()));
let resolver = Resolver::overrid(
initial_config.freeze(),
&initial_components,
&mut config,
&mut components,
);
assert!(resolver.is_latest_set::<TestStorable>());
assert!(resolver.is_set::<TestStorable>());
assert_eq!(
"override",
resolver.resolve_config::<TestStorable>().unwrap().0
);
}
}

View File

@ -5,12 +5,12 @@
use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata};
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::interceptors::context::{
BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::retry::{ErrorKind, ReconnectMode, RetryConfig};
use std::fmt;
@ -43,6 +43,7 @@ impl Interceptor for ConnectionPoisoningInterceptor {
fn modify_before_transmit(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let capture_smithy_connection = CaptureSmithyConnectionWrapper::new();
@ -58,6 +59,7 @@ impl Interceptor for ConnectionPoisoningInterceptor {
fn modify_before_deserialization(
&self,
context: &mut BeforeDeserializationInterceptorContextMut<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let reconnect_mode = cfg
@ -65,7 +67,9 @@ impl Interceptor for ConnectionPoisoningInterceptor {
.map(RetryConfig::reconnect_mode)
.unwrap_or(ReconnectMode::ReconnectOnTransientError);
let captured_connection = cfg.load::<CaptureSmithyConnectionWrapper>().cloned();
let retry_classifiers = cfg.retry_classifiers();
let retry_classifiers = runtime_components
.retry_classifiers()
.ok_or("retry classifiers are required for connection poisoning to work")?;
let error_is_transient = retry_classifiers
.classify_retry(context.into_inner())

View File

@ -8,6 +8,7 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use std::error::Error as StdError;
use std::fmt;
@ -41,6 +42,7 @@ where
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let mut request = HttpRequest::new(SdkBody::taken());
@ -75,6 +77,7 @@ where
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let request = context.request_mut();

View File

@ -9,22 +9,24 @@
use self::auth::orchestrate_auth;
use crate::client::orchestrator::endpoints::orchestrate_endpoint;
use crate::client::orchestrator::http::read_body;
use crate::client::timeout::{MaybeTimeout, ProvideMaybeTimeoutConfig, TimeoutKind};
use crate::client::timeout::{MaybeTimeout, MaybeTimeoutConfig, TimeoutKind};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::byte_stream::ByteStream;
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::connectors::Connector;
use aws_smithy_runtime_api::client::interceptors::context::{
Error, Input, InterceptorContext, Output, RewindResult,
};
use aws_smithy_runtime_api::client::interceptors::Interceptors;
use aws_smithy_runtime_api::client::orchestrator::{
HttpResponse, LoadedRequestBody, OrchestratorError, RequestSerializer,
DynResponseDeserializer, HttpResponse, LoadedRequestBody, OrchestratorError, RequestSerializer,
ResponseDeserializer, SharedRequestSerializer,
};
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::ShouldAttempt;
use aws_smithy_runtime_api::client::retries::{RetryStrategy, ShouldAttempt};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugins;
use aws_smithy_types::config_bag::ConfigBag;
use std::mem;
@ -62,6 +64,24 @@ macro_rules! continue_on_err {
};
}
macro_rules! run_interceptors {
(continue_on_err: { $($interceptor:ident($ctx:ident, $rc:ident, $cfg:ident);)+ }) => {
$(run_interceptors!(continue_on_err: $interceptor($ctx, $rc, $cfg));)+
};
(continue_on_err: $interceptor:ident($ctx:ident, $rc:ident, $cfg:ident)) => {
continue_on_err!([$ctx] => run_interceptors!(__private $interceptor($ctx, $rc, $cfg)))
};
(halt_on_err: { $($interceptor:ident($ctx:ident, $rc:ident, $cfg:ident);)+ }) => {
$(run_interceptors!(halt_on_err: $interceptor($ctx, $rc, $cfg));)+
};
(halt_on_err: $interceptor:ident($ctx:ident, $rc:ident, $cfg:ident)) => {
halt_on_err!([$ctx] => run_interceptors!(__private $interceptor($ctx, $rc, $cfg)))
};
(__private $interceptor:ident($ctx:ident, $rc:ident, $cfg:ident)) => {
Interceptors::new($rc.interceptors()).$interceptor($ctx, $rc, $cfg)
};
}
pub async fn invoke(
service_name: &str,
operation_name: &str,
@ -101,24 +121,25 @@ pub async fn invoke_with_stop_point(
let mut cfg = ConfigBag::base();
let cfg = &mut cfg;
let mut interceptors = Interceptors::new();
let mut ctx = InterceptorContext::new(input);
if let Err(err) = apply_configuration(&mut ctx, cfg, &mut interceptors, runtime_plugins) {
return Err(SdkError::construction_failure(err));
}
let operation_timeout_config = cfg.maybe_timeout_config(TimeoutKind::Operation);
let runtime_components = apply_configuration(&mut ctx, cfg, runtime_plugins)
.map_err(SdkError::construction_failure)?;
trace!(runtime_components = ?runtime_components);
let operation_timeout_config =
MaybeTimeoutConfig::new(&runtime_components, cfg, TimeoutKind::Operation);
trace!(operation_timeout_config = ?operation_timeout_config);
async {
// If running the pre-execution interceptors failed, then we skip running the op and run the
// final interceptors instead.
if !ctx.is_failed() {
try_op(&mut ctx, cfg, &interceptors, stop_point).await;
try_op(&mut ctx, cfg, &runtime_components, stop_point).await;
}
finally_op(&mut ctx, cfg, &interceptors).await;
finally_op(&mut ctx, cfg, &runtime_components).await;
Ok(ctx)
}
.maybe_timeout_with_config(operation_timeout_config)
.maybe_timeout(operation_timeout_config)
.await
}
.instrument(debug_span!("invoke", service = %service_name, operation = %operation_name))
@ -132,42 +153,49 @@ pub async fn invoke_with_stop_point(
fn apply_configuration(
ctx: &mut InterceptorContext,
cfg: &mut ConfigBag,
interceptors: &mut Interceptors,
runtime_plugins: &RuntimePlugins,
) -> Result<(), BoxError> {
runtime_plugins.apply_client_configuration(cfg, interceptors.client_interceptors_mut())?;
continue_on_err!([ctx] => interceptors.client_read_before_execution(ctx, cfg));
) -> Result<RuntimeComponents, BoxError> {
let client_rc_builder = runtime_plugins.apply_client_configuration(cfg)?;
continue_on_err!([ctx] => Interceptors::new(client_rc_builder.interceptors()).read_before_execution(false, ctx, cfg));
runtime_plugins
.apply_operation_configuration(cfg, interceptors.operation_interceptors_mut())?;
continue_on_err!([ctx] => interceptors.operation_read_before_execution(ctx, cfg));
let operation_rc_builder = runtime_plugins.apply_operation_configuration(cfg)?;
continue_on_err!([ctx] => Interceptors::new(operation_rc_builder.interceptors()).read_before_execution(true, ctx, cfg));
Ok(())
// The order below is important. Client interceptors must run before operation interceptors.
Ok(RuntimeComponents::builder()
.merge_from(&client_rc_builder)
.merge_from(&operation_rc_builder)
.build()?)
}
#[instrument(skip_all)]
async fn try_op(
ctx: &mut InterceptorContext,
cfg: &mut ConfigBag,
interceptors: &Interceptors,
runtime_components: &RuntimeComponents,
stop_point: StopPoint,
) {
// Before serialization
halt_on_err!([ctx] => interceptors.read_before_serialization(ctx, cfg));
halt_on_err!([ctx] => interceptors.modify_before_serialization(ctx, cfg));
run_interceptors!(halt_on_err: {
read_before_serialization(ctx, runtime_components, cfg);
modify_before_serialization(ctx, runtime_components, cfg);
});
// Serialization
ctx.enter_serialization_phase();
{
let _span = debug_span!("serialization").entered();
let request_serializer = cfg.request_serializer();
let request_serializer = cfg
.load::<SharedRequestSerializer>()
.expect("request serializer must be in the config bag")
.clone();
let input = ctx.take_input().expect("input set at this point");
let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg).map_err(OrchestratorError::other));
ctx.set_request(request);
}
// Load the request body into memory if configured to do so
if let LoadedRequestBody::Requested = cfg.loaded_request_body() {
if let Some(&LoadedRequestBody::Requested) = cfg.load::<LoadedRequestBody>() {
debug!("loading request body into memory");
let mut body = SdkBody::taken();
mem::swap(&mut body, ctx.request_mut().expect("set above").body_mut());
@ -175,20 +203,21 @@ async fn try_op(
*ctx.request_mut().as_mut().expect("set above").body_mut() =
SdkBody::from(loaded_body.clone());
cfg.interceptor_state()
.set_loaded_request_body(LoadedRequestBody::Loaded(loaded_body));
.store_put(LoadedRequestBody::Loaded(loaded_body));
}
// Before transmit
ctx.enter_before_transmit_phase();
halt_on_err!([ctx] => interceptors.read_after_serialization(ctx, cfg));
halt_on_err!([ctx] => interceptors.modify_before_retry_loop(ctx, cfg));
run_interceptors!(halt_on_err: {
read_after_serialization(ctx, runtime_components, cfg);
modify_before_retry_loop(ctx, runtime_components, cfg);
});
let retry_strategy = cfg.retry_strategy();
// If we got a retry strategy from the bag, ask it what to do.
// Otherwise, assume we should attempt the initial request.
let should_attempt = retry_strategy
.map(|rs| rs.should_attempt_initial_request(cfg))
.unwrap_or(Ok(ShouldAttempt::Yes));
let should_attempt = runtime_components
.retry_strategy()
.should_attempt_initial_request(runtime_components, cfg);
match should_attempt {
// Yes, let's make a request
Ok(ShouldAttempt::Yes) => debug!("retry strategy has OKed initial request"),
@ -200,7 +229,7 @@ async fn try_op(
// No, we shouldn't make a request because...
Err(err) => halt!([ctx] => OrchestratorError::other(err)),
Ok(ShouldAttempt::YesAfterDelay(delay)) => {
let sleep_impl = halt_on_err!([ctx] => cfg.sleep_impl().ok_or(OrchestratorError::other(
let sleep_impl = halt_on_err!([ctx] => runtime_components.sleep_impl().ok_or_else(|| OrchestratorError::other(
"the retry strategy requested a delay before sending the initial request, but no 'async sleep' implementation was set"
)));
debug!("retry strategy has OKed initial request after a {delay:?} delay");
@ -228,31 +257,28 @@ async fn try_op(
debug!("delaying for {delay:?}");
sleep.await;
}
let attempt_timeout_config = cfg.maybe_timeout_config(TimeoutKind::OperationAttempt);
let attempt_timeout_config =
MaybeTimeoutConfig::new(runtime_components, cfg, TimeoutKind::OperationAttempt);
trace!(attempt_timeout_config = ?attempt_timeout_config);
let maybe_timeout = async {
debug!("beginning attempt #{i}");
try_attempt(ctx, cfg, interceptors, stop_point).await;
finally_attempt(ctx, cfg, interceptors).await;
try_attempt(ctx, cfg, runtime_components, stop_point).await;
finally_attempt(ctx, cfg, runtime_components).await;
Result::<_, SdkError<Error, HttpResponse>>::Ok(())
}
.maybe_timeout_with_config(attempt_timeout_config)
.maybe_timeout(attempt_timeout_config)
.await
.map_err(|err| OrchestratorError::timeout(err.into_source().unwrap()));
// We continue when encountering a timeout error. The retry classifier will decide what to do with it.
continue_on_err!([ctx] => maybe_timeout);
let retry_strategy = cfg.retry_strategy();
// If we got a retry strategy from the bag, ask it what to do.
// If no strategy was set, we won't retry.
let should_attempt = match retry_strategy {
Some(retry_strategy) => halt_on_err!(
[ctx] => retry_strategy.should_attempt_retry(ctx, cfg).map_err(OrchestratorError::other)
),
None => ShouldAttempt::No,
};
let should_attempt = halt_on_err!([ctx] => runtime_components
.retry_strategy()
.should_attempt_retry(ctx, runtime_components, cfg)
.map_err(OrchestratorError::other));
match should_attempt {
// Yes, let's retry the request
ShouldAttempt::Yes => continue,
@ -262,7 +288,7 @@ async fn try_op(
break;
}
ShouldAttempt::YesAfterDelay(delay) => {
let sleep_impl = halt_on_err!([ctx] => cfg.sleep_impl().ok_or(OrchestratorError::other(
let sleep_impl = halt_on_err!([ctx] => runtime_components.sleep_impl().ok_or_else(|| OrchestratorError::other(
"the retry strategy requested a delay before sending the retry request, but no 'async sleep' implementation was set"
)));
retry_delay = Some((delay, sleep_impl.sleep(delay)));
@ -276,21 +302,25 @@ async fn try_op(
async fn try_attempt(
ctx: &mut InterceptorContext,
cfg: &mut ConfigBag,
interceptors: &Interceptors,
runtime_components: &RuntimeComponents,
stop_point: StopPoint,
) {
halt_on_err!([ctx] => interceptors.read_before_attempt(ctx, cfg));
run_interceptors!(halt_on_err: read_before_attempt(ctx, runtime_components, cfg));
halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg).await.map_err(OrchestratorError::other));
halt_on_err!([ctx] => orchestrate_endpoint(ctx, runtime_components, cfg).await.map_err(OrchestratorError::other));
halt_on_err!([ctx] => interceptors.modify_before_signing(ctx, cfg));
halt_on_err!([ctx] => interceptors.read_before_signing(ctx, cfg));
run_interceptors!(halt_on_err: {
modify_before_signing(ctx, runtime_components, cfg);
read_before_signing(ctx, runtime_components, cfg);
});
halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await.map_err(OrchestratorError::other));
halt_on_err!([ctx] => orchestrate_auth(ctx, runtime_components, cfg).await.map_err(OrchestratorError::other));
halt_on_err!([ctx] => interceptors.read_after_signing(ctx, cfg));
halt_on_err!([ctx] => interceptors.modify_before_transmit(ctx, cfg));
halt_on_err!([ctx] => interceptors.read_before_transmit(ctx, cfg));
run_interceptors!(halt_on_err: {
read_after_signing(ctx, runtime_components, cfg);
modify_before_transmit(ctx, runtime_components, cfg);
read_before_transmit(ctx, runtime_components, cfg);
});
// Return early if a stop point is set for before transmit
if let StopPoint::BeforeTransmit = stop_point {
@ -304,7 +334,10 @@ async fn try_attempt(
let response = halt_on_err!([ctx] => {
let request = ctx.take_request().expect("set during serialization");
trace!(request = ?request, "transmitting request");
cfg.connector().call(request).await.map_err(|err| {
let connector = halt_on_err!([ctx] => runtime_components.connector().ok_or_else(||
OrchestratorError::other("a connector is required to send requests")
));
connector.call(request).await.map_err(|err| {
match err.downcast() {
Ok(connector_error) => OrchestratorError::connector(*connector_error),
Err(box_err) => OrchestratorError::other(box_err)
@ -315,14 +348,18 @@ async fn try_attempt(
ctx.set_response(response);
ctx.enter_before_deserialization_phase();
halt_on_err!([ctx] => interceptors.read_after_transmit(ctx, cfg));
halt_on_err!([ctx] => interceptors.modify_before_deserialization(ctx, cfg));
halt_on_err!([ctx] => interceptors.read_before_deserialization(ctx, cfg));
run_interceptors!(halt_on_err: {
read_after_transmit(ctx, runtime_components, cfg);
modify_before_deserialization(ctx, runtime_components, cfg);
read_before_deserialization(ctx, runtime_components, cfg);
});
ctx.enter_deserialization_phase();
let output_or_error = async {
let response = ctx.response_mut().expect("set during transmit");
let response_deserializer = cfg.response_deserializer();
let response_deserializer = cfg
.load::<DynResponseDeserializer>()
.expect("a request deserializer must be in the config bag");
let maybe_deserialized = {
let _span = debug_span!("deserialize_streaming").entered();
response_deserializer.deserialize_streaming(response)
@ -345,44 +382,48 @@ async fn try_attempt(
ctx.set_output_or_error(output_or_error);
ctx.enter_after_deserialization_phase();
halt_on_err!([ctx] => interceptors.read_after_deserialization(ctx, cfg));
run_interceptors!(halt_on_err: read_after_deserialization(ctx, runtime_components, cfg));
}
#[instrument(skip_all)]
async fn finally_attempt(
ctx: &mut InterceptorContext,
cfg: &mut ConfigBag,
interceptors: &Interceptors,
runtime_components: &RuntimeComponents,
) {
continue_on_err!([ctx] => interceptors.modify_before_attempt_completion(ctx, cfg));
continue_on_err!([ctx] => interceptors.read_after_attempt(ctx, cfg));
run_interceptors!(continue_on_err: {
modify_before_attempt_completion(ctx, runtime_components, cfg);
read_after_attempt(ctx, runtime_components, cfg);
});
}
#[instrument(skip_all)]
async fn finally_op(
ctx: &mut InterceptorContext,
cfg: &mut ConfigBag,
interceptors: &Interceptors,
runtime_components: &RuntimeComponents,
) {
continue_on_err!([ctx] => interceptors.modify_before_completion(ctx, cfg));
continue_on_err!([ctx] => interceptors.read_after_execution(ctx, cfg));
run_interceptors!(continue_on_err: {
modify_before_completion(ctx, runtime_components, cfg);
read_after_execution(ctx, runtime_components, cfg);
});
}
#[cfg(all(test, feature = "test-util"))]
mod tests {
use super::*;
use crate::client::auth::no_auth::{NoAuthRuntimePlugin, NO_AUTH_SCHEME_ID};
use crate::client::orchestrator::endpoints::{
StaticUriEndpointResolver, StaticUriEndpointResolverParams,
};
use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
use crate::client::retries::strategy::NeverRetryStrategy;
use crate::client::test_util::{
deserializer::CannedResponseDeserializer, serializer::CannedRequestSerializer,
};
use ::http::{Request, Response, StatusCode};
use aws_smithy_runtime_api::client::auth::option_resolver::StaticAuthOptionResolver;
use aws_smithy_runtime_api::client::auth::{AuthOptionResolverParams, DynAuthOptionResolver};
use aws_smithy_runtime_api::client::connectors::{Connector, DynConnector};
use aws_smithy_runtime_api::client::auth::{
AuthOptionResolverParams, SharedAuthOptionResolver,
};
use aws_smithy_runtime_api::client::connectors::{Connector, SharedConnector};
use aws_smithy_runtime_api::client::interceptors::context::{
AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut,
BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut,
@ -390,17 +431,17 @@ mod tests {
BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
FinalizerInterceptorContextRef,
};
use aws_smithy_runtime_api::client::interceptors::{
Interceptor, InterceptorRegistrar, SharedInterceptor,
};
use aws_smithy_runtime_api::client::interceptors::{Interceptor, SharedInterceptor};
use aws_smithy_runtime_api::client::orchestrator::{
BoxFuture, DynEndpointResolver, DynResponseDeserializer, Future, HttpRequest,
SharedRequestSerializer,
BoxFuture, DynResponseDeserializer, EndpointResolverParams, Future, HttpRequest,
SharedEndpointResolver, SharedRequestSerializer,
};
use aws_smithy_runtime_api::client::retries::DynRetryStrategy;
use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, RuntimePlugins};
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
use aws_smithy_types::type_erasure::{TypeErasedBox, TypedBox};
use std::borrow::Cow;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tracing_test::traced_test;
@ -442,36 +483,58 @@ mod tests {
}
#[derive(Debug)]
struct TestOperationRuntimePlugin;
struct TestOperationRuntimePlugin {
builder: RuntimeComponentsBuilder,
}
impl TestOperationRuntimePlugin {
fn new() -> Self {
Self {
builder: RuntimeComponentsBuilder::new("TestOperationRuntimePlugin")
.with_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())))
.with_endpoint_resolver(Some(SharedEndpointResolver::new(
StaticUriEndpointResolver::http_localhost(8080),
)))
.with_connector(Some(SharedConnector::new(OkConnector::new())))
.with_auth_option_resolver(Some(SharedAuthOptionResolver::new(
StaticAuthOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
))),
}
}
}
impl RuntimePlugin for TestOperationRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
let mut cfg = Layer::new("test operation");
cfg.set_request_serializer(SharedRequestSerializer::new(new_request_serializer()));
cfg.set_response_deserializer(
DynResponseDeserializer::new(new_response_deserializer()),
);
cfg.set_retry_strategy(DynRetryStrategy::new(NeverRetryStrategy::new()));
cfg.set_endpoint_resolver(DynEndpointResolver::new(
StaticUriEndpointResolver::http_localhost(8080),
));
cfg.set_endpoint_resolver_params(StaticUriEndpointResolverParams::new().into());
cfg.set_connector(DynConnector::new(OkConnector::new()));
cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("idontcare"));
cfg.set_auth_option_resolver(DynAuthOptionResolver::new(
StaticAuthOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
));
let mut layer = Layer::new("TestOperationRuntimePlugin");
layer.store_put(AuthOptionResolverParams::new("idontcare"));
layer.store_put(EndpointResolverParams::new("dontcare"));
layer.store_put(SharedRequestSerializer::new(new_request_serializer()));
layer.store_put(DynResponseDeserializer::new(new_response_deserializer()));
Some(layer.freeze())
}
Some(cfg.freeze())
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.builder)
}
}
macro_rules! interceptor_error_handling_test {
(read_before_execution, $ctx:ty, $expected:expr,) => {
interceptor_error_handling_test!(__private read_before_execution, $ctx, $expected,);
};
($interceptor:ident, $ctx:ty, $expected:expr) => {
interceptor_error_handling_test!(__private $interceptor, $ctx, $expected, _rc: &RuntimeComponents,);
};
(__private $interceptor:ident, $ctx:ty, $expected:expr, $($rc_arg:tt)*) => {
#[derive(Debug)]
struct FailingInterceptorA;
impl Interceptor for FailingInterceptorA {
fn $interceptor(&self, _ctx: $ctx, _cfg: &mut ConfigBag) -> Result<(), BoxError> {
fn $interceptor(
&self,
_ctx: $ctx,
$($rc_arg)*
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
tracing::debug!("FailingInterceptorA called!");
Err("FailingInterceptorA".into())
}
@ -480,7 +543,12 @@ mod tests {
#[derive(Debug)]
struct FailingInterceptorB;
impl Interceptor for FailingInterceptorB {
fn $interceptor(&self, _ctx: $ctx, _cfg: &mut ConfigBag) -> Result<(), BoxError> {
fn $interceptor(
&self,
_ctx: $ctx,
$($rc_arg)*
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
tracing::debug!("FailingInterceptorB called!");
Err("FailingInterceptorB".into())
}
@ -489,37 +557,53 @@ mod tests {
#[derive(Debug)]
struct FailingInterceptorC;
impl Interceptor for FailingInterceptorC {
fn $interceptor(&self, _ctx: $ctx, _cfg: &mut ConfigBag) -> Result<(), BoxError> {
fn $interceptor(
&self,
_ctx: $ctx,
$($rc_arg)*
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
tracing::debug!("FailingInterceptorC called!");
Err("FailingInterceptorC".into())
}
}
#[derive(Debug)]
struct FailingInterceptorsClientRuntimePlugin;
struct FailingInterceptorsClientRuntimePlugin(RuntimeComponentsBuilder);
impl FailingInterceptorsClientRuntimePlugin {
fn new() -> Self {
Self(RuntimeComponentsBuilder::new("test").with_interceptor(SharedInterceptor::new(FailingInterceptorA)))
}
}
impl RuntimePlugin for FailingInterceptorsClientRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(FailingInterceptorA));
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.0)
}
}
#[derive(Debug)]
struct FailingInterceptorsOperationRuntimePlugin;
struct FailingInterceptorsOperationRuntimePlugin(RuntimeComponentsBuilder);
impl FailingInterceptorsOperationRuntimePlugin {
fn new() -> Self {
Self(
RuntimeComponentsBuilder::new("test")
.with_interceptor(SharedInterceptor::new(FailingInterceptorB))
.with_interceptor(SharedInterceptor::new(FailingInterceptorC))
)
}
}
impl RuntimePlugin for FailingInterceptorsOperationRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(FailingInterceptorB));
interceptors.register(SharedInterceptor::new(FailingInterceptorC));
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.0)
}
}
let input = TypeErasedBox::new(Box::new(()));
let runtime_plugins = RuntimePlugins::new()
.with_client_plugin(FailingInterceptorsClientRuntimePlugin)
.with_operation_plugin(TestOperationRuntimePlugin)
.with_client_plugin(FailingInterceptorsClientRuntimePlugin::new())
.with_operation_plugin(TestOperationRuntimePlugin::new())
.with_operation_plugin(NoAuthRuntimePlugin::new())
.with_operation_plugin(FailingInterceptorsOperationRuntimePlugin);
.with_operation_plugin(FailingInterceptorsOperationRuntimePlugin::new());
let actual = invoke("test", "test", input, &runtime_plugins)
.await
.expect_err("should error");
@ -539,7 +623,7 @@ mod tests {
interceptor_error_handling_test!(
read_before_execution,
&BeforeSerializationInterceptorContextRef<'_>,
expected
expected,
);
}
@ -742,13 +826,20 @@ mod tests {
}
macro_rules! interceptor_error_redirection_test {
(read_before_execution, $origin_ctx:ty, $destination_interceptor:ident, $destination_ctx:ty, $expected:expr) => {
interceptor_error_redirection_test!(__private read_before_execution, $origin_ctx, $destination_interceptor, $destination_ctx, $expected,);
};
($origin_interceptor:ident, $origin_ctx:ty, $destination_interceptor:ident, $destination_ctx:ty, $expected:expr) => {
interceptor_error_redirection_test!(__private $origin_interceptor, $origin_ctx, $destination_interceptor, $destination_ctx, $expected, _rc: &RuntimeComponents,);
};
(__private $origin_interceptor:ident, $origin_ctx:ty, $destination_interceptor:ident, $destination_ctx:ty, $expected:expr, $($rc_arg:tt)*) => {
#[derive(Debug)]
struct OriginInterceptor;
impl Interceptor for OriginInterceptor {
fn $origin_interceptor(
&self,
_ctx: $origin_ctx,
$($rc_arg)*
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
tracing::debug!("OriginInterceptor called!");
@ -762,6 +853,7 @@ mod tests {
fn $destination_interceptor(
&self,
_ctx: $destination_ctx,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
tracing::debug!("DestinationInterceptor called!");
@ -770,20 +862,27 @@ mod tests {
}
#[derive(Debug)]
struct InterceptorsTestOperationRuntimePlugin;
struct InterceptorsTestOperationRuntimePlugin(RuntimeComponentsBuilder);
impl InterceptorsTestOperationRuntimePlugin {
fn new() -> Self {
Self(
RuntimeComponentsBuilder::new("test")
.with_interceptor(SharedInterceptor::new(OriginInterceptor))
.with_interceptor(SharedInterceptor::new(DestinationInterceptor))
)
}
}
impl RuntimePlugin for InterceptorsTestOperationRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(OriginInterceptor));
interceptors.register(SharedInterceptor::new(DestinationInterceptor));
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.0)
}
}
let input = TypeErasedBox::new(Box::new(()));
let runtime_plugins = RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(TestOperationRuntimePlugin::new())
.with_operation_plugin(NoAuthRuntimePlugin::new())
.with_operation_plugin(InterceptorsTestOperationRuntimePlugin);
.with_operation_plugin(InterceptorsTestOperationRuntimePlugin::new());
let actual = invoke("test", "test", input, &runtime_plugins)
.await
.expect_err("should error");
@ -1006,12 +1105,6 @@ mod tests {
);
}
// #[tokio::test]
// #[traced_test]
// async fn test_read_after_attempt_error_causes_jump_to_modify_before_attempt_completion() {
// todo!("I'm confused by the behavior described in the spec")
// }
#[tokio::test]
#[traced_test]
async fn test_modify_before_completion_error_causes_jump_to_read_after_execution() {
@ -1029,7 +1122,7 @@ mod tests {
async fn test_stop_points() {
let runtime_plugins = || {
RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(TestOperationRuntimePlugin::new())
.with_operation_plugin(NoAuthRuntimePlugin::new())
};
@ -1076,6 +1169,7 @@ mod tests {
fn modify_before_retry_loop(
&self,
_context: &mut BeforeTransmitInterceptorContextMut<'_>,
_rc: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.inner
@ -1087,6 +1181,7 @@ mod tests {
fn modify_before_completion(
&self,
_context: &mut FinalizerInterceptorContextMut<'_>,
_rc: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.inner
@ -1098,6 +1193,7 @@ mod tests {
fn read_after_execution(
&self,
_context: &FinalizerInterceptorContextRef<'_>,
_rc: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
self.inner
@ -1109,21 +1205,22 @@ mod tests {
#[derive(Debug)]
struct TestInterceptorRuntimePlugin {
interceptor: TestInterceptor,
builder: RuntimeComponentsBuilder,
}
impl RuntimePlugin for TestInterceptorRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(self.interceptor.clone()));
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.builder)
}
}
let interceptor = TestInterceptor::default();
let runtime_plugins = || {
RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(TestOperationRuntimePlugin::new())
.with_operation_plugin(NoAuthRuntimePlugin::new())
.with_operation_plugin(TestInterceptorRuntimePlugin {
interceptor: interceptor.clone(),
builder: RuntimeComponentsBuilder::new("test")
.with_interceptor(SharedInterceptor::new(interceptor.clone())),
})
};

View File

@ -5,11 +5,12 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme,
AuthOptionResolver, AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme,
};
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::identity::IdentityResolver;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::Document;
@ -66,22 +67,22 @@ impl StdError for AuthOrchestrationError {}
pub(super) async fn orchestrate_auth(
ctx: &mut InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<(), BoxError> {
let params = cfg.auth_option_resolver_params();
let auth_options = cfg.auth_option_resolver().resolve_auth_options(params)?;
let identity_resolvers = cfg.identity_resolvers();
let auth_option_resolver = runtime_components.auth_option_resolver();
let auth_options = auth_option_resolver.resolve_auth_options(params)?;
trace!(
auth_option_resolver_params = ?params,
auth_options = ?auth_options,
identity_resolvers = ?identity_resolvers,
"orchestrating auth",
);
for &scheme_id in auth_options.as_ref() {
if let Some(auth_scheme) = cfg.http_auth_schemes().scheme(scheme_id) {
if let Some(identity_resolver) = auth_scheme.identity_resolver(&identity_resolvers) {
if let Some(auth_scheme) = runtime_components.http_auth_scheme(scheme_id) {
if let Some(identity_resolver) = auth_scheme.identity_resolver(runtime_components) {
let request_signer = auth_scheme.request_signer();
trace!(
auth_scheme = ?auth_scheme,
@ -106,6 +107,7 @@ pub(super) async fn orchestrate_auth(
request,
&identity,
auth_scheme_endpoint_config,
runtime_components,
cfg,
)?;
return Ok(());
@ -145,21 +147,23 @@ fn extract_endpoint_auth_scheme_config(
Ok(AuthSchemeEndpointConfig::new(Some(auth_scheme_config)))
}
#[cfg(test)]
#[cfg(all(test, feature = "test-util"))]
mod tests {
use super::*;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::auth::option_resolver::StaticAuthOptionResolver;
use aws_smithy_runtime_api::client::auth::{
AuthOptionResolverParams, AuthSchemeId, DynAuthOptionResolver, HttpAuthScheme,
HttpRequestSigner, SharedHttpAuthScheme,
AuthOptionResolverParams, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
SharedAuthOptionResolver, SharedHttpAuthScheme,
};
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityResolver, IdentityResolvers, SharedIdentityResolver,
Identity, IdentityResolver, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{Future, HttpRequest};
use aws_smithy_runtime_api::client::runtime_components::{
GetIdentityResolver, RuntimeComponentsBuilder,
};
use aws_smithy_types::config_bag::Layer;
use aws_smithy_types::type_erasure::TypedBox;
use std::collections::HashMap;
@ -183,6 +187,7 @@ mod tests {
request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_runtime_components: &RuntimeComponents,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
request
@ -205,7 +210,7 @@ mod tests {
fn identity_resolver(
&self,
identity_resolvers: &IdentityResolvers,
identity_resolvers: &dyn GetIdentityResolver,
) -> Option<SharedIdentityResolver> {
identity_resolvers.identity_resolver(self.scheme_id())
}
@ -221,23 +226,28 @@ mod tests {
let _ = ctx.take_input();
ctx.enter_before_transmit_phase();
let mut layer = Layer::new("test");
layer.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
layer.set_auth_option_resolver(DynAuthOptionResolver::new(StaticAuthOptionResolver::new(
vec![TEST_SCHEME_ID],
)));
layer.push_identity_resolver(
TEST_SCHEME_ID,
SharedIdentityResolver::new(TestIdentityResolver),
);
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(TestAuthScheme {
signer: TestSigner,
}));
layer.store_put(Endpoint::builder().url("dontcare").build());
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_auth_option_resolver(Some(SharedAuthOptionResolver::new(
StaticAuthOptionResolver::new(vec![TEST_SCHEME_ID]),
)))
.with_identity_resolver(
TEST_SCHEME_ID,
SharedIdentityResolver::new(TestIdentityResolver),
)
.with_http_auth_scheme(SharedHttpAuthScheme::new(TestAuthScheme {
signer: TestSigner,
}))
.build()
.unwrap();
let mut cfg = ConfigBag::base();
cfg.push_layer(layer);
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
let mut layer: Layer = Layer::new("test");
layer.store_put(AuthOptionResolverParams::new("doesntmatter"));
layer.store_put(Endpoint::builder().url("dontcare").build());
let cfg = ConfigBag::of_layers(vec![layer]);
orchestrate_auth(&mut ctx, &runtime_components, &cfg)
.await
.expect("success");
assert_eq!(
"success!",
@ -267,26 +277,33 @@ mod tests {
fn config_with_identity(
scheme_id: AuthSchemeId,
identity: impl IdentityResolver + 'static,
) -> ConfigBag {
let mut layer = Layer::new("test");
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(BasicAuthScheme::new()));
layer.push_http_auth_scheme(SharedHttpAuthScheme::new(BearerAuthScheme::new()));
layer.store_put(Endpoint::builder().url("dontcare").build());
) -> (RuntimeComponents, ConfigBag) {
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_http_auth_scheme(SharedHttpAuthScheme::new(BasicAuthScheme::new()))
.with_http_auth_scheme(SharedHttpAuthScheme::new(BearerAuthScheme::new()))
.with_auth_option_resolver(Some(SharedAuthOptionResolver::new(
StaticAuthOptionResolver::new(vec![
HTTP_BASIC_AUTH_SCHEME_ID,
HTTP_BEARER_AUTH_SCHEME_ID,
]),
)))
.with_identity_resolver(scheme_id, SharedIdentityResolver::new(identity))
.build()
.unwrap();
layer.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
layer.set_auth_option_resolver(DynAuthOptionResolver::new(
StaticAuthOptionResolver::new(vec![
HTTP_BASIC_AUTH_SCHEME_ID,
HTTP_BEARER_AUTH_SCHEME_ID,
]),
));
layer.push_identity_resolver(scheme_id, SharedIdentityResolver::new(identity));
ConfigBag::of_layers(vec![layer])
let mut layer = Layer::new("test");
layer.store_put(Endpoint::builder().url("dontcare").build());
layer.store_put(AuthOptionResolverParams::new("doesntmatter"));
(runtime_components, ConfigBag::of_layers(vec![layer]))
}
// First, test the presence of a basic auth login and absence of a bearer token
let cfg = config_with_identity(HTTP_BASIC_AUTH_SCHEME_ID, Login::new("a", "b", None));
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
let (runtime_components, cfg) =
config_with_identity(HTTP_BASIC_AUTH_SCHEME_ID, Login::new("a", "b", None));
orchestrate_auth(&mut ctx, &runtime_components, &cfg)
.await
.expect("success");
assert_eq!(
// "YTpi" == "a:b" in base64
"Basic YTpi",
@ -298,13 +315,16 @@ mod tests {
);
// Next, test the presence of a bearer token and absence of basic auth
let cfg = config_with_identity(HTTP_BEARER_AUTH_SCHEME_ID, Token::new("t", None));
let (runtime_components, cfg) =
config_with_identity(HTTP_BEARER_AUTH_SCHEME_ID, Token::new("t", None));
let mut ctx = InterceptorContext::new(TypedBox::new("doesnt-matter").erase());
ctx.enter_serialization_phase();
ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let _ = ctx.take_input();
ctx.enter_before_transmit_phase();
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
orchestrate_auth(&mut ctx, &runtime_components, &cfg)
.await
.expect("success");
assert_eq!(
"Bearer t",
ctx.request()

View File

@ -14,6 +14,7 @@ use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{
EndpointResolver, EndpointResolverParams, Future, HttpRequest,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::endpoint::Endpoint;
use http::header::HeaderName;
@ -103,6 +104,7 @@ where
pub(super) async fn orchestrate_endpoint(
ctx: &mut InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
trace!("orchestrating endpoint resolution");
@ -111,8 +113,10 @@ pub(super) async fn orchestrate_endpoint(
let endpoint_prefix = cfg.load::<EndpointPrefix>();
let request = ctx.request_mut().expect("set during serialization");
let endpoint_resolver = cfg.endpoint_resolver();
let endpoint = endpoint_resolver.resolve_endpoint(params).await?;
let endpoint = runtime_components
.endpoint_resolver()
.resolve_endpoint(params)
.await?;
apply_endpoint(request, &endpoint, endpoint_prefix)?;
// Make the endpoint config available to interceptors

View File

@ -6,6 +6,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeDeserializationInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::date_time::Format;
use aws_smithy_types::DateTime;
@ -68,6 +69,7 @@ impl Interceptor for ServiceClockSkewInterceptor {
fn modify_before_deserialization(
&self,
ctx: &mut BeforeDeserializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let time_received = DateTime::from(SystemTime::now());

View File

@ -298,11 +298,8 @@ mod tests {
use super::{cubic_throttle, ClientRateLimiter};
use crate::client::retries::client_rate_limiter::RequestReason;
use approx::assert_relative_eq;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::test_util::instant_time_and_sleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_types::config_bag::ConfigBag;
use std::time::{Duration, SystemTime};
const ONE_SECOND: Duration = Duration::from_secs(1);
@ -325,12 +322,6 @@ mod tests {
#[tokio::test]
async fn throttling_is_enabled_once_throttling_error_is_received() {
let mut cfg = ConfigBag::base();
let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH);
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(time_source));
cfg.interceptor_state()
.set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl)));
let rate_limiter = ClientRateLimiter::builder()
.previous_time_bucket(0.0)
.time_of_last_throttle(0.0)
@ -349,13 +340,6 @@ mod tests {
#[tokio::test]
async fn test_calculated_rate_with_successes() {
let mut cfg = ConfigBag::base();
let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH);
sleep_impl.sleep(Duration::from_secs(5)).await;
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(time_source));
cfg.interceptor_state()
.set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone())));
let rate_limiter = ClientRateLimiter::builder()
.time_of_last_throttle(5.0)
.tokens_retrieved_per_second_at_time_of_last_throttle(10.0)
@ -414,13 +398,6 @@ mod tests {
#[tokio::test]
async fn test_calculated_rate_with_throttles() {
let mut cfg = ConfigBag::base();
let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH);
sleep_impl.sleep(Duration::from_secs(5)).await;
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(time_source));
cfg.interceptor_state()
.set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone())));
let rate_limiter = ClientRateLimiter::builder()
.tokens_retrieved_per_second_at_time_of_last_throttle(10.0)
.time_of_last_throttle(5.0)
@ -496,12 +473,7 @@ mod tests {
#[tokio::test]
async fn test_client_sending_rates() {
let mut cfg = ConfigBag::base();
let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH);
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(time_source));
cfg.interceptor_state()
.set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone())));
let (_, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH);
let rate_limiter = ClientRateLimiter::builder().build();
struct Attempt {

View File

@ -7,8 +7,9 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::{
ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy, ShouldAttempt,
ClassifyRetry, RetryReason, RetryStrategy, ShouldAttempt,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use std::time::Duration;
@ -34,13 +35,18 @@ impl FixedDelayRetryStrategy {
}
impl RetryStrategy for FixedDelayRetryStrategy {
fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
fn should_attempt_initial_request(
&self,
_runtime_components: &RuntimeComponents,
_cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
Ok(ShouldAttempt::Yes)
}
fn should_attempt_retry(
&self,
ctx: &InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
// Look a the result. If it's OK then we're done; No retry required. Otherwise, we need to inspect it
@ -65,8 +71,8 @@ impl RetryStrategy for FixedDelayRetryStrategy {
return Ok(ShouldAttempt::No);
}
let retry_classifiers = cfg
.load::<RetryClassifiers>()
let retry_classifiers = runtime_components
.retry_classifiers()
.expect("a retry classifier is set");
let retry_reason = retry_classifiers.classify_retry(ctx);

View File

@ -6,6 +6,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::retries::{RetryStrategy, ShouldAttempt};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
#[non_exhaustive]
@ -19,13 +20,18 @@ impl NeverRetryStrategy {
}
impl RetryStrategy for NeverRetryStrategy {
fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
fn should_attempt_initial_request(
&self,
_runtime_components: &RuntimeComponents,
_cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
Ok(ShouldAttempt::Yes)
}
fn should_attempt_retry(
&self,
_context: &InterceptorContext,
_runtime_components: &RuntimeComponents,
_cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
Ok(ShouldAttempt::No)

View File

@ -9,12 +9,12 @@ use crate::client::retries::strategy::standard::ReleaseResult::{
};
use crate::client::retries::token_bucket::TokenBucket;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::{
ClassifyRetry, RetryReason, RetryStrategy, ShouldAttempt,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::retry::{ErrorKind, RetryConfig};
use std::sync::Mutex;
@ -96,6 +96,7 @@ impl StandardRetryStrategy {
fn calculate_backoff(
&self,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
retry_reason: Option<&RetryReason>,
) -> Result<Duration, ShouldAttempt> {
@ -108,8 +109,12 @@ impl StandardRetryStrategy {
match retry_reason {
Some(RetryReason::Explicit(backoff)) => Ok(*backoff),
Some(RetryReason::Error(kind)) => {
update_rate_limiter_if_exists(cfg, *kind == ErrorKind::ThrottlingError);
if let Some(delay) = check_rate_limiter_for_delay(cfg, *kind) {
update_rate_limiter_if_exists(
runtime_components,
cfg,
*kind == ErrorKind::ThrottlingError,
);
if let Some(delay) = check_rate_limiter_for_delay(runtime_components, cfg, *kind) {
let delay = delay.min(self.max_backoff);
debug!("rate limiter has requested a {delay:?} delay before retrying");
Ok(delay)
@ -138,7 +143,7 @@ impl StandardRetryStrategy {
}
Some(_) => unreachable!("RetryReason is non-exhaustive"),
None => {
update_rate_limiter_if_exists(cfg, false);
update_rate_limiter_if_exists(runtime_components, cfg, false);
debug!(
attempts = request_attempts,
max_attempts = self.max_attempts,
@ -169,9 +174,13 @@ impl Default for StandardRetryStrategy {
}
impl RetryStrategy for StandardRetryStrategy {
fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
fn should_attempt_initial_request(
&self,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
if let Some(crl) = cfg.load::<ClientRateLimiter>() {
let seconds_since_unix_epoch = get_seconds_since_unix_epoch(cfg);
let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
if let Err(delay) = crl.acquire_permission_to_send_a_request(
seconds_since_unix_epoch,
RequestReason::InitialRequest,
@ -188,6 +197,7 @@ impl RetryStrategy for StandardRetryStrategy {
fn should_attempt_retry(
&self,
ctx: &InterceptorContext,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<ShouldAttempt, BoxError> {
// Look a the result. If it's OK then we're done; No retry required. Otherwise, we need to inspect it
@ -207,7 +217,7 @@ impl RetryStrategy for StandardRetryStrategy {
tb.regenerate_a_token();
}
}
update_rate_limiter_if_exists(cfg, false);
update_rate_limiter_if_exists(runtime_components, cfg, false);
return Ok(ShouldAttempt::No);
}
@ -218,7 +228,7 @@ impl RetryStrategy for StandardRetryStrategy {
.expect("at least one request attempt is made before any retry is attempted")
.attempts();
if request_attempts >= self.max_attempts {
update_rate_limiter_if_exists(cfg, false);
update_rate_limiter_if_exists(runtime_components, cfg, false);
debug!(
attempts = request_attempts,
@ -229,11 +239,13 @@ impl RetryStrategy for StandardRetryStrategy {
}
// Run the classifiers against the context to determine if we should retry
let retry_classifiers = cfg.retry_classifiers();
let retry_classifiers = runtime_components
.retry_classifiers()
.ok_or("retry classifiers are required by the retry configuration")?;
let retry_reason = retry_classifiers.classify_retry(ctx);
// Calculate the appropriate backoff time.
let backoff = match self.calculate_backoff(cfg, retry_reason.as_ref()) {
let backoff = match self.calculate_backoff(runtime_components, cfg, retry_reason.as_ref()) {
Ok(value) => value,
// In some cases, backoff calculation will decide that we shouldn't retry at all.
Err(value) => return Ok(value),
@ -248,23 +260,32 @@ impl RetryStrategy for StandardRetryStrategy {
}
}
fn update_rate_limiter_if_exists(cfg: &ConfigBag, is_throttling_error: bool) {
fn update_rate_limiter_if_exists(
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
is_throttling_error: bool,
) {
if let Some(crl) = cfg.load::<ClientRateLimiter>() {
let seconds_since_unix_epoch = get_seconds_since_unix_epoch(cfg);
let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error);
}
}
fn check_rate_limiter_for_delay(cfg: &ConfigBag, kind: ErrorKind) -> Option<Duration> {
fn check_rate_limiter_for_delay(
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
kind: ErrorKind,
) -> Option<Duration> {
if let Some(crl) = cfg.load::<ClientRateLimiter>() {
let retry_reason = if kind == ErrorKind::ThrottlingError {
RequestReason::RetryTimeout
} else {
RequestReason::Retry
};
if let Err(delay) = crl
.acquire_permission_to_send_a_request(get_seconds_since_unix_epoch(cfg), retry_reason)
{
if let Err(delay) = crl.acquire_permission_to_send_a_request(
get_seconds_since_unix_epoch(runtime_components),
retry_reason,
) {
return Some(delay);
}
}
@ -276,8 +297,10 @@ fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts
base * initial_backoff * 2_u32.pow(retry_attempts) as f64
}
fn get_seconds_since_unix_epoch(cfg: &ConfigBag) -> f64 {
let request_time = cfg.request_time().unwrap();
fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 {
let request_time = runtime_components
.time_source()
.expect("time source required for retries");
request_time
.now()
.duration_since(SystemTime::UNIX_EPOCH)
@ -292,6 +315,7 @@ mod tests {
use aws_smithy_runtime_api::client::retries::{
AlwaysRetry, ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::Layer;
use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
use aws_smithy_types::type_erasure::TypeErasedBox;
@ -305,11 +329,12 @@ mod tests {
#[test]
fn no_retry_necessary_for_ok_result() {
let cfg = ConfigBag::base();
let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
let strategy = StandardRetryStrategy::default();
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
let actual = strategy
.should_attempt_retry(&ctx, &cfg)
.should_attempt_retry(&ctx, &rc, &cfg)
.expect("method is infallible for this use");
assert_eq!(ShouldAttempt::No, actual);
}
@ -317,26 +342,29 @@ mod tests {
fn set_up_cfg_and_context(
error_kind: ErrorKind,
current_request_attempts: u32,
) -> (InterceptorContext, ConfigBag) {
) -> (InterceptorContext, RuntimeComponents, ConfigBag) {
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
let rc = RuntimeComponentsBuilder::for_tests()
.with_retry_classifiers(Some(
RetryClassifiers::new().with_classifier(AlwaysRetry(error_kind)),
))
.build()
.unwrap();
let mut layer = Layer::new("test");
layer.set_retry_classifiers(
RetryClassifiers::new().with_classifier(AlwaysRetry(error_kind)),
);
layer.store_put(RequestAttempts::new(current_request_attempts));
let cfg = ConfigBag::of_layers(vec![layer]);
(ctx, cfg)
(ctx, rc, cfg)
}
// Test that error kinds produce the correct "retry after X seconds" output.
// All error kinds are handled in the same way for the standard strategy.
fn test_should_retry_error_kind(error_kind: ErrorKind) {
let (ctx, cfg) = set_up_cfg_and_context(error_kind, 3);
let (ctx, rc, cfg) = set_up_cfg_and_context(error_kind, 3);
let strategy = StandardRetryStrategy::default().with_base(|| 1.0);
let actual = strategy
.should_attempt_retry(&ctx, &cfg)
.should_attempt_retry(&ctx, &rc, &cfg)
.expect("method is infallible for this use");
assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
}
@ -365,12 +393,12 @@ mod tests {
fn dont_retry_when_out_of_attempts() {
let current_attempts = 4;
let max_attempts = current_attempts;
let (ctx, cfg) = set_up_cfg_and_context(ErrorKind::TransientError, current_attempts);
let (ctx, rc, cfg) = set_up_cfg_and_context(ErrorKind::TransientError, current_attempts);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(max_attempts);
let actual = strategy
.should_attempt_retry(&ctx, &cfg)
.should_attempt_retry(&ctx, &rc, &cfg)
.expect("method is infallible for this use");
assert_eq!(ShouldAttempt::No, actual);
}
@ -431,23 +459,28 @@ mod tests {
}
#[cfg(feature = "test-util")]
fn setup_test(retry_reasons: Vec<RetryReason>) -> (ConfigBag, InterceptorContext) {
let mut cfg = ConfigBag::base();
cfg.interceptor_state().set_retry_classifiers(
RetryClassifiers::new()
.with_classifier(PresetReasonRetryClassifier::new(retry_reasons)),
);
fn setup_test(
retry_reasons: Vec<RetryReason>,
) -> (ConfigBag, RuntimeComponents, InterceptorContext) {
let rc = RuntimeComponentsBuilder::for_tests()
.with_retry_classifiers(Some(
RetryClassifiers::new()
.with_classifier(PresetReasonRetryClassifier::new(retry_reasons)),
))
.build()
.unwrap();
let cfg = ConfigBag::base();
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
// This type doesn't matter b/c the classifier will just return whatever we tell it to.
ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
(cfg, ctx)
(cfg, rc, ctx)
}
#[cfg(feature = "test-util")]
#[test]
fn eventual_success() {
let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let (mut cfg, rc, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
@ -455,13 +488,13 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
@ -469,7 +502,7 @@ mod tests {
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
cfg.interceptor_state().store_put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 495);
}
@ -477,7 +510,7 @@ mod tests {
#[cfg(feature = "test-util")]
#[test]
fn no_more_attempts() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let (mut cfg, rc, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(3);
@ -485,19 +518,19 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().store_put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 490);
}
@ -505,7 +538,7 @@ mod tests {
#[cfg(feature = "test-util")]
#[test]
fn no_quota() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let (mut cfg, rc, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
@ -513,13 +546,13 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 0);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 0);
}
@ -527,7 +560,7 @@ mod tests {
#[cfg(feature = "test-util")]
#[test]
fn quota_replenishes_on_success() {
let (mut cfg, mut ctx) = setup_test(vec![
let (mut cfg, rc, mut ctx) = setup_test(vec![
RetryReason::Error(ErrorKind::TransientError),
RetryReason::Explicit(Duration::from_secs(1)),
]);
@ -538,13 +571,13 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 90);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 90);
@ -552,7 +585,7 @@ mod tests {
ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
cfg.interceptor_state().store_put(RequestAttempts::new(3));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 100);
@ -562,7 +595,8 @@ mod tests {
#[test]
fn quota_replenishes_on_first_try_success() {
const PERMIT_COUNT: usize = 20;
let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]);
let (mut cfg, rc, mut ctx) =
setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(u32::MAX);
@ -581,7 +615,7 @@ mod tests {
cfg.interceptor_state()
.store_put(RequestAttempts::new(attempt));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
attempt += 1;
}
@ -600,7 +634,7 @@ mod tests {
cfg.interceptor_state()
.store_put(RequestAttempts::new(attempt));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
attempt += 1;
}
@ -612,7 +646,7 @@ mod tests {
#[cfg(feature = "test-util")]
#[test]
fn backoff_timing() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let (mut cfg, rc, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5);
@ -620,31 +654,31 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().store_put(RequestAttempts::new(3));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(4));
assert_eq!(token_bucket.available_permits(), 485);
cfg.interceptor_state().store_put(RequestAttempts::new(4));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(8));
assert_eq!(token_bucket.available_permits(), 480);
cfg.interceptor_state().store_put(RequestAttempts::new(5));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 480);
}
@ -652,7 +686,7 @@ mod tests {
#[cfg(feature = "test-util")]
#[test]
fn max_backoff_time() {
let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let (mut cfg, rc, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]);
let strategy = StandardRetryStrategy::default()
.with_base(|| 1.0)
.with_max_attempts(5)
@ -662,31 +696,31 @@ mod tests {
let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
cfg.interceptor_state().store_put(RequestAttempts::new(1));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(1));
assert_eq!(token_bucket.available_permits(), 495);
cfg.interceptor_state().store_put(RequestAttempts::new(2));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(2));
assert_eq!(token_bucket.available_permits(), 490);
cfg.interceptor_state().store_put(RequestAttempts::new(3));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(token_bucket.available_permits(), 485);
cfg.interceptor_state().store_put(RequestAttempts::new(4));
let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
let dur = should_retry.expect_delay();
assert_eq!(dur, Duration::from_secs(3));
assert_eq!(token_bucket.available_permits(), 480);
cfg.interceptor_state().store_put(RequestAttempts::new(5));
let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap();
let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
assert_eq!(no_retry, ShouldAttempt::No);
assert_eq!(token_bucket.available_permits(), 480);
}

View File

@ -8,6 +8,7 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use std::fmt;
@ -34,6 +35,7 @@ where
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
(self.f)(context, cfg);
@ -41,34 +43,3 @@ where
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_types::type_erasure::TypedBox;
use std::time::{Duration, UNIX_EPOCH};
#[test]
fn set_test_request_time() {
let mut cfg = ConfigBag::base();
let mut ctx = InterceptorContext::new(TypedBox::new("anything").erase());
ctx.enter_serialization_phase();
ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let _ = ctx.take_input();
ctx.enter_before_transmit_phase();
let mut ctx = Into::into(&mut ctx);
let request_time = UNIX_EPOCH + Duration::from_secs(1624036048);
let interceptor = TestParamsSetterInterceptor::new(
move |_: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag| {
cfg.interceptor_state().set_request_time(request_time);
},
);
interceptor
.modify_before_signing(&mut ctx, &mut cfg)
.unwrap();
assert_eq!(cfg.request_time().unwrap().now(), request_time);
}
}

View File

@ -6,8 +6,8 @@
use aws_smithy_async::future::timeout::Timeout;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep, Sleep};
use aws_smithy_client::SdkError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::timeout::TimeoutConfig;
use pin_project_lite::pin_project;
@ -109,14 +109,14 @@ pub(super) struct MaybeTimeoutConfig {
timeout_kind: TimeoutKind,
}
pub(super) trait ProvideMaybeTimeoutConfig {
fn maybe_timeout_config(&self, timeout_kind: TimeoutKind) -> MaybeTimeoutConfig;
}
impl ProvideMaybeTimeoutConfig for ConfigBag {
fn maybe_timeout_config(&self, timeout_kind: TimeoutKind) -> MaybeTimeoutConfig {
if let Some(timeout_config) = self.load::<TimeoutConfig>() {
let sleep_impl = self.sleep_impl();
impl MaybeTimeoutConfig {
pub(super) fn new(
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
timeout_kind: TimeoutKind,
) -> MaybeTimeoutConfig {
if let Some(timeout_config) = cfg.load::<TimeoutConfig>() {
let sleep_impl = runtime_components.sleep_impl();
let timeout = match (sleep_impl.as_ref(), timeout_kind) {
(None, _) => None,
(Some(_), TimeoutKind::Operation) => timeout_config.operation_timeout(),
@ -142,23 +142,14 @@ impl ProvideMaybeTimeoutConfig for ConfigBag {
/// Trait to conveniently wrap a future with an optional timeout.
pub(super) trait MaybeTimeout<T>: Sized {
/// Wraps a future in a timeout if one is set.
fn maybe_timeout_with_config(
self,
timeout_config: MaybeTimeoutConfig,
) -> MaybeTimeoutFuture<Self>;
/// Wraps a future in a timeout if one is set.
fn maybe_timeout(self, cfg: &ConfigBag, kind: TimeoutKind) -> MaybeTimeoutFuture<Self>;
fn maybe_timeout(self, timeout_config: MaybeTimeoutConfig) -> MaybeTimeoutFuture<Self>;
}
impl<T> MaybeTimeout<T> for T
where
T: Future,
{
fn maybe_timeout_with_config(
self,
timeout_config: MaybeTimeoutConfig,
) -> MaybeTimeoutFuture<Self> {
fn maybe_timeout(self, timeout_config: MaybeTimeoutConfig) -> MaybeTimeoutFuture<Self> {
match timeout_config {
MaybeTimeoutConfig {
sleep_impl: Some(sleep_impl),
@ -172,22 +163,18 @@ where
_ => MaybeTimeoutFuture::NoTimeout { future: self },
}
}
fn maybe_timeout(self, cfg: &ConfigBag, kind: TimeoutKind) -> MaybeTimeoutFuture<Self> {
self.maybe_timeout_with_config(cfg.maybe_timeout_config(kind))
}
}
#[cfg(test)]
mod tests {
use crate::client::timeout::{MaybeTimeout, TimeoutKind};
use super::*;
use aws_smithy_async::assert_elapsed;
use aws_smithy_async::future::never::Never;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep, TokioSleep};
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors;
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::config_bag::{CloneableLayer, ConfigBag};
use aws_smithy_types::timeout::TimeoutConfig;
use std::time::Duration;
@ -203,14 +190,19 @@ mod tests {
let now = tokio::time::Instant::now();
tokio::time::pause();
let mut cfg = ConfigBag::base();
let mut timeout_config = Layer::new("timeout");
timeout_config.store_put(TimeoutConfig::builder().build());
timeout_config.set_sleep_impl(Some(sleep_impl));
cfg.push_layer(timeout_config);
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_sleep_impl(Some(sleep_impl))
.build()
.unwrap();
let mut timeout_config = CloneableLayer::new("timeout");
timeout_config.store_put(TimeoutConfig::builder().build());
let cfg = ConfigBag::of_layers(vec![timeout_config.into()]);
let maybe_timeout =
MaybeTimeoutConfig::new(&runtime_components, &cfg, TimeoutKind::Operation);
underlying_future
.maybe_timeout(&cfg, TimeoutKind::Operation)
.maybe_timeout(maybe_timeout)
.await
.expect("success");
@ -229,19 +221,21 @@ mod tests {
let now = tokio::time::Instant::now();
tokio::time::pause();
let mut cfg = ConfigBag::base();
let mut timeout_config = Layer::new("timeout");
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_sleep_impl(Some(sleep_impl))
.build()
.unwrap();
let mut timeout_config = CloneableLayer::new("timeout");
timeout_config.store_put(
TimeoutConfig::builder()
.operation_timeout(Duration::from_millis(250))
.build(),
);
timeout_config.set_sleep_impl(Some(sleep_impl));
cfg.push_layer(timeout_config);
let cfg = ConfigBag::of_layers(vec![timeout_config.into()]);
let result = underlying_future
.maybe_timeout(&cfg, TimeoutKind::Operation)
.await;
let maybe_timeout =
MaybeTimeoutConfig::new(&runtime_components, &cfg, TimeoutKind::Operation);
let result = underlying_future.maybe_timeout(maybe_timeout).await;
let err = result.expect_err("should have timed out");
assert_eq!(format!("{:?}", err), "TimeoutError(TimeoutError { source: MaybeTimeoutError { kind: Operation, duration: 250ms } })");

View File

@ -119,7 +119,7 @@ impl CloneableLayer {
/// Removes `T` from this bag
pub fn unset<T: Send + Sync + Clone + Debug + 'static>(&mut self) -> &mut Self {
self.0.unset::<T>();
self.put_directly_cloneable::<StoreReplace<T>>(Value::ExplicitlyUnset(type_name::<T>()));
self
}
@ -851,6 +851,10 @@ mod test {
let layer_1_cloned = layer_1.clone();
assert_eq!(expected_str, &layer_1_cloned.load::<TestStr>().unwrap().0);
// Should still be cloneable after unsetting a field
layer_1.unset::<TestStr>();
assert!(layer_1.try_clone().unwrap().load::<TestStr>().is_none());
#[derive(Clone, Debug)]
struct Rope(String);
impl Storable for Rope {

View File

@ -5,20 +5,33 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::{
Interceptor, InterceptorRegistrar, SharedInterceptor,
use aws_smithy_runtime_api::client::interceptors::{Interceptor, SharedInterceptor};
use aws_smithy_runtime_api::client::runtime_components::{
RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::base64;
use aws_smithy_types::config_bag::ConfigBag;
use http::header::HeaderName;
use std::borrow::Cow;
#[derive(Debug)]
pub(crate) struct HttpChecksumRequiredRuntimePlugin;
pub(crate) struct HttpChecksumRequiredRuntimePlugin {
runtime_components: RuntimeComponentsBuilder,
}
impl HttpChecksumRequiredRuntimePlugin {
pub(crate) fn new() -> Self {
Self {
runtime_components: RuntimeComponentsBuilder::new("HttpChecksumRequiredRuntimePlugin")
.with_interceptor(SharedInterceptor::new(HttpChecksumRequiredInterceptor)),
}
}
}
impl RuntimePlugin for HttpChecksumRequiredRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(HttpChecksumRequiredInterceptor));
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.runtime_components)
}
}
@ -29,6 +42,7 @@ impl Interceptor for HttpChecksumRequiredInterceptor {
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let request = context.request_mut();

View File

@ -8,16 +8,18 @@ use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
BeforeSerializationInterceptorContextMut, Input,
};
use aws_smithy_runtime_api::client::interceptors::{
Interceptor, InterceptorRegistrar, SharedInterceptor,
use aws_smithy_runtime_api::client::interceptors::{Interceptor, SharedInterceptor};
use aws_smithy_runtime_api::client::runtime_components::{
RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::ConfigBag;
use std::borrow::Cow;
use std::fmt;
#[derive(Debug)]
pub(crate) struct IdempotencyTokenRuntimePlugin {
interceptor: SharedInterceptor,
runtime_components: RuntimeComponentsBuilder,
}
impl IdempotencyTokenRuntimePlugin {
@ -26,14 +28,17 @@ impl IdempotencyTokenRuntimePlugin {
S: Fn(IdempotencyTokenProvider, &mut Input) + Send + Sync + 'static,
{
Self {
interceptor: SharedInterceptor::new(IdempotencyTokenInterceptor { set_token }),
runtime_components: RuntimeComponentsBuilder::new("IdempotencyTokenRuntimePlugin")
.with_interceptor(SharedInterceptor::new(IdempotencyTokenInterceptor {
set_token,
})),
}
}
}
impl RuntimePlugin for IdempotencyTokenRuntimePlugin {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(self.interceptor.clone());
fn runtime_components(&self) -> Cow<'_, RuntimeComponentsBuilder> {
Cow::Borrowed(&self.runtime_components)
}
}
@ -54,6 +59,7 @@ where
fn modify_before_serialization(
&self,
context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let token_provider = cfg

View File

@ -31,6 +31,7 @@ services_that_pass_tests=(\
"polly"\
"qldbsession"\
"route53"\
"s3"\
"s3control"\
"sso"\
"sts"\