Update runtime plugin trait (#2754)

## Motivation and Context
Update the RuntimePlugin trait based on discussion:
1. Methods are infallible
2. Split out `Config` and `Interceptors`
3. `ConfigBag` now has an explicit field `interceptor_state`
4. Refactor `ConfigBagAccessors` so that we can build the core around
the trait and keep the trait together with a `where Self` trick

## Description
- Update the `RuntimePlugin` trait
- Deal with resulting implications
## Testing
- [x] CI


----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: John DiSanti <jdisanti@amazon.com>
This commit is contained in:
Russell Cohen 2023-06-13 15:43:03 -04:00 committed by GitHub
parent 8e37d42f3c
commit 5473192d3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 661 additions and 507 deletions

View File

@ -14,7 +14,7 @@ use aws_smithy_runtime_api::client::interceptors::{
BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, BoxError,
Interceptor,
};
use aws_smithy_runtime_api::client::orchestrator::LoadedRequestBody;
use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, LoadedRequestBody};
use aws_smithy_types::config_bag::ConfigBag;
use bytes::Bytes;
use http::header::{HeaderName, HeaderValue};
@ -119,7 +119,8 @@ impl Interceptor for GlacierTreeHashHeaderInterceptor {
) -> Result<(), BoxError> {
// Request the request body to be loaded into memory immediately after serialization
// so that it can be checksummed before signing and transmit
cfg.put(LoadedRequestBody::Requested);
cfg.interceptor_state()
.set_loaded_request_body(LoadedRequestBody::Requested);
Ok(())
}
@ -139,7 +140,7 @@ impl Interceptor for GlacierTreeHashHeaderInterceptor {
.clone();
signing_config.signing_options.payload_override =
Some(SignableBody::Precomputed(content_sha256));
cfg.put(signing_config);
cfg.interceptor_state().put(signing_config);
} else {
return Err(
"the request body wasn't loaded into memory before the retry loop, \

View File

@ -19,7 +19,7 @@ use aws_smithy_runtime_api::client::interceptors::{
};
use aws_smithy_runtime_api::client::orchestrator::ConfigBagAccessors;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
/// Interceptor that tells the SigV4 signer to add the signature to query params,
/// and sets the request expiration time from the presigning config.
@ -40,14 +40,15 @@ impl Interceptor for SigV4PresigningInterceptor {
_context: &mut BeforeSerializationInterceptorContextMut<'_>,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
cfg.put::<HeaderSerializationSettings>(
cfg.interceptor_state().put::<HeaderSerializationSettings>(
HeaderSerializationSettings::new()
.omit_default_content_length()
.omit_default_content_type(),
);
cfg.set_request_time(SharedTimeSource::new(StaticTimeSource::new(
self.config.start_time(),
)));
cfg.interceptor_state()
.set_request_time(SharedTimeSource::new(StaticTimeSource::new(
self.config.start_time(),
)));
Ok(())
}
@ -61,7 +62,8 @@ impl Interceptor for SigV4PresigningInterceptor {
config.signing_options.signature_type = HttpSignatureType::HttpRequestQueryParams;
config.signing_options.payload_override =
Some(aws_sigv4::http_request::SignableBody::UnsignedPayload);
cfg.put::<SigV4OperationSigningConfig>(config);
cfg.interceptor_state()
.put::<SigV4OperationSigningConfig>(config);
Ok(())
} else {
Err(
@ -87,18 +89,15 @@ impl SigV4PresigningRuntimePlugin {
}
impl RuntimePlugin for SigV4PresigningRuntimePlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
// Disable some SDK interceptors that shouldn't run for presigning
cfg.put(disable_interceptor::<InvocationIdInterceptor>("presigning"));
cfg.put(disable_interceptor::<RequestInfoInterceptor>("presigning"));
cfg.put(disable_interceptor::<UserAgentInterceptor>("presigning"));
fn config(&self) -> Option<FrozenLayer> {
let mut layer = Layer::new("Presigning");
layer.put(disable_interceptor::<InvocationIdInterceptor>("presigning"));
layer.put(disable_interceptor::<RequestInfoInterceptor>("presigning"));
layer.put(disable_interceptor::<UserAgentInterceptor>("presigning"));
Some(layer.freeze())
}
// Register the presigning interceptor
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(self.interceptor.clone());
Ok(())
}
}

View File

@ -506,6 +506,7 @@ mod tests {
use super::*;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::SigningSettings;
use aws_smithy_types::config_bag::Layer;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use std::collections::HashMap;
@ -556,8 +557,8 @@ mod tests {
#[test]
fn endpoint_config_overrides_region_and_service() {
let mut cfg = ConfigBag::base();
cfg.put(SigV4OperationSigningConfig {
let mut layer = Layer::new("test");
layer.put(SigV4OperationSigningConfig {
region: Some(SigningRegion::from(Region::new("override-this-region"))),
service: Some(SigningService::from_static("override-this-service")),
signing_options: Default::default(),
@ -577,6 +578,7 @@ mod tests {
});
let config = AuthSchemeEndpointConfig::new(Some(&config));
let cfg = ConfigBag::of_layers(vec![layer]);
let result =
SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");
@ -593,12 +595,13 @@ mod tests {
#[test]
fn endpoint_config_supports_fallback_when_region_or_service_are_unset() {
let mut cfg = ConfigBag::base();
cfg.put(SigV4OperationSigningConfig {
let mut layer = Layer::new("test");
layer.put(SigV4OperationSigningConfig {
region: Some(SigningRegion::from(Region::new("us-east-1"))),
service: Some(SigningService::from_static("qldb")),
signing_options: Default::default(),
});
let cfg = ConfigBag::of_layers(vec![layer]);
let config = AuthSchemeEndpointConfig::empty();
let result =

View File

@ -48,7 +48,8 @@ impl Interceptor for InvocationIdInterceptor {
.map(|gen| gen.generate())
.transpose()?
.flatten();
cfg.put::<InvocationId>(id.unwrap_or_default());
cfg.interceptor_state()
.put::<InvocationId>(id.unwrap_or_default());
Ok(())
}

View File

@ -166,7 +166,7 @@ mod tests {
use crate::request_info::RequestPairs;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::RetryConfig;
use aws_smithy_types::timeout::TimeoutConfig;
use aws_smithy_types::type_erasure::TypeErasedBox;
@ -190,13 +190,14 @@ mod tests {
context.enter_serialization_phase();
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let mut config = ConfigBag::base();
config.put(RetryConfig::standard());
config.put(
let mut layer = Layer::new("test");
layer.put(RetryConfig::standard());
layer.put(
TimeoutConfig::builder()
.read_timeout(Duration::from_secs(30))
.build(),
);
let mut config = ConfigBag::of_layers(vec![layer]);
let _ = context.take_input();
context.enter_before_transmit_phase();

View File

@ -110,7 +110,7 @@ mod tests {
use super::*;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::type_erasure::TypeErasedBox;
@ -138,14 +138,15 @@ mod tests {
fn test_overridden_ua() {
let mut context = context();
let mut config = ConfigBag::base();
config.put(AwsUserAgent::for_tests());
config.put(ApiMetadata::new("unused", "unused"));
let mut layer = Layer::new("test");
layer.put(AwsUserAgent::for_tests());
layer.put(ApiMetadata::new("unused", "unused"));
let mut cfg = ConfigBag::of_layers(vec![layer]);
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
interceptor
.modify_before_signing(&mut ctx, &mut config)
.modify_before_signing(&mut ctx, &mut cfg)
.unwrap();
let header = expect_header(&context, "user-agent");
@ -163,8 +164,9 @@ mod tests {
let mut context = context();
let api_metadata = ApiMetadata::new("some-service", "some-version");
let mut config = ConfigBag::base();
config.put(api_metadata.clone());
let mut layer = Layer::new("test");
layer.put(api_metadata.clone());
let mut config = ConfigBag::of_layers(vec![layer]);
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);
@ -192,9 +194,10 @@ mod tests {
let mut context = context();
let api_metadata = ApiMetadata::new("some-service", "some-version");
let mut config = ConfigBag::base();
config.put(api_metadata);
config.put(AppName::new("my_awesome_app").unwrap());
let mut layer = Layer::new("test");
layer.put(api_metadata);
layer.put(AppName::new("my_awesome_app").unwrap());
let mut config = ConfigBag::of_layers(vec![layer]);
let interceptor = UserAgentInterceptor::new();
let mut ctx = Into::into(&mut context);

View File

@ -47,7 +47,7 @@ class CustomizableOperationTestHelpers(runtimeConfig: RuntimeConfig) :
pub fn request_time_for_tests(mut self, request_time: ::std::time::SystemTime) -> Self {
use #{ConfigBagAccessors};
let interceptor = #{TestParamsSetterInterceptor}::new(move |_: &mut #{BeforeTransmitInterceptorContextMut}<'_>, cfg: &mut #{ConfigBag}| {
cfg.set_request_time(#{SharedTimeSource}::new(request_time));
cfg.interceptor_state().set_request_time(#{SharedTimeSource}::new(request_time));
});
self.interceptors.push(#{SharedInterceptor}::new(interceptor));
self

View File

@ -36,7 +36,7 @@ private class InvocationIdRuntimePluginCustomization(
)
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate("#{InvocationIdInterceptor}::new()", *codegenScope)
}

View File

@ -31,7 +31,7 @@ private class RecursionDetectionRuntimePluginCustomization(
private val codegenContext: ClientCodegenContext,
) : ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rust(
"#T::new()",

View File

@ -35,7 +35,7 @@ private class AddRetryInformationHeaderInterceptors(codegenContext: ClientCodege
private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig)
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
// Track the latency between client and server.
section.registerInterceptor(runtimeConfig, this) {
rust(

View File

@ -138,7 +138,7 @@ private class AuthOperationCustomization(private val codegenContext: ClientCodeg
signing_options.signing_optional = $signingOptional;
signing_options.payload_override = #{payload_override};
${section.configBagName}.put(#{SigV4OperationSigningConfig} {
${section.newLayerName}.put(#{SigV4OperationSigningConfig} {
region: None,
service: None,
signing_options,
@ -147,7 +147,7 @@ private class AuthOperationCustomization(private val codegenContext: ClientCodeg
let auth_option_resolver = #{StaticAuthOptionResolver}::new(
vec![#{SIGV4_SCHEME_ID}]
);
${section.configBagName}.set_auth_option_resolver(auth_option_resolver);
${section.newLayerName}.set_auth_option_resolver(auth_option_resolver);
""",
*codegenScope,
"payload_override" to writable {

View File

@ -103,13 +103,19 @@ class UserAgentDecorator : ClientCodegenDecorator {
private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig)
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
section.putConfigValue(this) {
rust("#T.clone()", ClientRustModule.Meta.toType().resolve("API_METADATA"))
when (section) {
is ServiceRuntimePluginSection.AdditionalConfig -> {
section.putConfigValue(this) {
rust("#T.clone()", ClientRustModule.Meta.toType().resolve("API_METADATA"))
}
}
section.registerInterceptor(runtimeConfig, this) {
rust("#T::new()", awsRuntime.resolve("user_agent::UserAgentInterceptor"))
is ServiceRuntimePluginSection.RegisterInterceptor -> {
section.registerInterceptor(runtimeConfig, this) {
rust("#T::new()", awsRuntime.resolve("user_agent::UserAgentInterceptor"))
}
}
else -> emptySection
}
}
}

View File

@ -66,7 +66,7 @@ private class ApiGatewayAddAcceptHeader : OperationCustomization() {
private class ApiGatewayAcceptHeaderInterceptorCustomization(private val codegenContext: ClientCodegenContext) :
ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate(
"#{Interceptor}::default()",

View File

@ -97,11 +97,14 @@ private class GlacierAccountIdCustomization(private val codegenContext: ClientCo
}
}
// TODO(enableNewSmithyRuntime): Install the glacier customizations as a single additional runtime plugin instead
// of wiring up the interceptors individually
/** Adds the `x-amz-glacier-version` header to all requests */
private class GlacierApiVersionCustomization(private val codegenContext: ClientCodegenContext) :
ServiceRuntimePluginCustomization() {
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
if (section is ServiceRuntimePluginSection.RegisterInterceptor) {
val apiVersion = codegenContext.serviceShape.version
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rustTemplate(
@ -122,7 +125,7 @@ private class GlacierApiVersionCustomization(private val codegenContext: ClientC
private class GlacierOperationInterceptorsCustomization(private val codegenContext: ClientCodegenContext) :
OperationCustomization() {
override fun section(section: OperationSection): Writable = writable {
if (section is OperationSection.AdditionalRuntimePluginConfig) {
if (section is OperationSection.AdditionalInterceptors) {
val inputShape = codegenContext.model.expectShape(section.operationShape.inputShape) as StructureShape
val inlineModule = inlineModule(codegenContext.runtimeConfig)
if (inputShape.inputWithAccountId()) {

View File

@ -224,7 +224,7 @@ private class HttpAuthOperationCustomization(codegenContext: ClientCodegenContex
}
// TODO(enableNewSmithyRuntime): Make auth options additive in the config bag so that multiple codegen decorators can register them
rustTemplate("${section.configBagName}.set_auth_option_resolver(auth_option_resolver);", *codegenScope)
rustTemplate("${section.newLayerName}.set_auth_option_resolver(auth_option_resolver);", *codegenScope)
}
else -> emptySection

View File

@ -177,9 +177,9 @@ class InterceptorConfigCustomization(codegenContext: CodegenContext) : ConfigCus
""",
)
ServiceConfig.ToRuntimePlugin -> rust(
is ServiceConfig.RuntimePluginInterceptors -> rust(
"""
interceptors.extend(self.interceptors.iter().cloned());
${section.interceptors}.extend(self.interceptors.iter().cloned());
""",
)

View File

@ -271,12 +271,12 @@ class ResiliencyServiceRuntimePluginCustomization : ServiceRuntimePluginCustomiz
rust(
"""
if let Some(sleep_impl) = self.handle.conf.sleep_impl() {
${section.configBagName}.put(sleep_impl);
${section.newLayerName}.put(sleep_impl);
}
if let Some(timeout_config) = self.handle.conf.timeout_config() {
${section.configBagName}.put(timeout_config.clone());
${section.newLayerName}.put(timeout_config.clone());
}
${section.configBagName}.put(self.handle.conf.time_source.clone());
${section.newLayerName}.put(self.handle.conf.time_source.clone());
""",
)
}

View File

@ -43,7 +43,7 @@ private class EndpointParametersCustomization(
override fun section(section: OperationSection): Writable = writable {
val symbolProvider = codegenContext.symbolProvider
val operationName = symbolProvider.toSymbol(operation).name
if (section is OperationSection.AdditionalRuntimePluginConfig) {
if (section is OperationSection.AdditionalInterceptors) {
section.registerInterceptor(codegenContext.runtimeConfig, this) {
rust("${operationName}EndpointParamsInterceptor")
}

View File

@ -24,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
@ -44,6 +45,8 @@ class EndpointParamsInterceptorGenerator(
arrayOf(
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"ConfigBagAccessors" to RuntimeType.smithyRuntimeApi(rc)
.resolve("client::orchestrator::ConfigBagAccessors"),
"ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"),
"EndpointResolverParams" to orchestrator.resolve("EndpointResolverParams"),
"HttpRequest" to orchestrator.resolve("HttpRequest"),
@ -74,6 +77,7 @@ class EndpointParamsInterceptorGenerator(
context: &#{BeforeSerializationInterceptorContextRef}<'_, #{Input}, #{Output}, #{Error}>,
cfg: &mut #{ConfigBag},
) -> Result<(), #{BoxError}> {
use #{ConfigBagAccessors};
let _input = context.input()
.downcast_ref::<${operationInput.name}>()
.ok_or("failed to downcast to ${operationInput.name}")?;
@ -89,7 +93,7 @@ class EndpointParamsInterceptorGenerator(
#{param_setters}
.build()
.map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
cfg.put(#{EndpointResolverParams}::new(params));
cfg.interceptor_state().set_endpoint_resolver_params(#{EndpointResolverParams}::new(params));
Ok(())
}
}
@ -167,7 +171,7 @@ class EndpointParamsInterceptorGenerator(
codegenContext.smithyRuntimeMode,
)
}
rust("cfg.put(endpoint_prefix);")
rust("cfg.interceptor_state().put(endpoint_prefix);")
}
}
}

View File

@ -101,10 +101,15 @@ sealed class OperationSection(name: String) : Section(name) {
*/
data class AdditionalRuntimePluginConfig(
override val customizations: List<OperationCustomization>,
val configBagName: String,
val newLayerName: String,
val operationShape: OperationShape,
) : OperationSection("AdditionalRuntimePluginConfig")
data class AdditionalInterceptors(
override val customizations: List<OperationCustomization>,
val interceptorRegistrarName: String,
val operationShape: OperationShape,
) : OperationSection("AdditionalConfig") {
) : OperationSection("AdditionalInterceptors") {
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
val smithyRuntimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig)
writer.rustTemplate(

View File

@ -11,7 +11,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.util.dq
/**
* Generates operation-level runtime plugins
@ -25,6 +27,8 @@ class OperationRuntimePluginGenerator(
arrayOf(
"AuthOptionResolverParams" to runtimeApi.resolve("client::auth::AuthOptionResolverParams"),
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"ConfigBagAccessors" to runtimeApi.resolve("client::orchestrator::ConfigBagAccessors"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
@ -43,7 +47,8 @@ class OperationRuntimePluginGenerator(
writer.rustTemplate(
"""
impl #{RuntimePlugin} for $operationStructName {
fn configure(&self, cfg: &mut #{ConfigBag}, _interceptors: &mut #{InterceptorRegistrar}) -> Result<(), #{BoxError}> {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
let mut cfg = #{Layer}::new(${operationShape.id.name.dq()});
use #{ConfigBagAccessors} as _;
cfg.set_request_serializer(${operationStructName}RequestSerializer);
cfg.set_response_deserializer(${operationStructName}ResponseDeserializer);
@ -57,21 +62,33 @@ class OperationRuntimePluginGenerator(
cfg.set_retry_classifiers(retry_classifiers);
#{additional_config}
Ok(())
Some(cfg.freeze())
}
fn interceptors(&self, _interceptors: &mut #{InterceptorRegistrar}) {
#{interceptors}
}
}
#{runtime_plugin_supporting_types}
""",
*codegenScope,
*preludeScope,
"additional_config" to writable {
writeCustomizations(
customizations,
OperationSection.AdditionalRuntimePluginConfig(customizations, "cfg", "_interceptors", operationShape),
OperationSection.AdditionalRuntimePluginConfig(
customizations,
newLayerName = "cfg",
operationShape,
),
)
},
"retry_classifier_customizations" to writable {
writeCustomizations(customizations, OperationSection.RetryClassifier(customizations, "cfg", operationShape))
writeCustomizations(
customizations,
OperationSection.RetryClassifier(customizations, "cfg", operationShape),
)
},
"runtime_plugin_supporting_types" to writable {
writeCustomizations(
@ -79,6 +96,12 @@ class OperationRuntimePluginGenerator(
OperationSection.RuntimePluginSupportingTypes(customizations, "cfg", operationShape),
)
},
"interceptors" to writable {
writeCustomizations(
customizations,
OperationSection.AdditionalInterceptors(customizations, "_interceptors", operationShape),
)
},
)
}
}

View File

@ -54,7 +54,7 @@ class ServiceGenerator(
ServiceRuntimePluginGenerator(codegenContext)
.render(this, decorator.serviceRuntimePluginCustomizations(codegenContext, emptyList()))
serviceConfigGenerator.renderRuntimePluginImplForBuilder(this, codegenContext)
serviceConfigGenerator.renderRuntimePluginImplForBuilder(this)
}
}

View File

@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.util.dq
sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
/**
@ -42,12 +43,14 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
/**
* Hook for adding additional things to config inside service runtime plugins.
*/
data class AdditionalConfig(val configBagName: String, val interceptorRegistrarName: String) : ServiceRuntimePluginSection("AdditionalConfig") {
data class AdditionalConfig(val newLayerName: String) : ServiceRuntimePluginSection("AdditionalConfig") {
/** Adds a value to the config bag */
fun putConfigValue(writer: RustWriter, value: Writable) {
writer.rust("$configBagName.put(#T);", value)
writer.rust("$newLayerName.put(#T);", value)
}
}
data class RegisterInterceptor(val interceptorRegistrarName: String) : ServiceRuntimePluginSection("RegisterInterceptor") {
/** Generates the code to register an interceptor */
fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) {
val smithyRuntimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig)
@ -67,7 +70,7 @@ typealias ServiceRuntimePluginCustomization = NamedCustomization<ServiceRuntimeP
* Generates the service-level runtime plugin
*/
class ServiceRuntimePluginGenerator(
codegenContext: ClientCodegenContext,
private val codegenContext: ClientCodegenContext,
) {
private val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext)
private val codegenScope = codegenContext.runtimeConfig.let { rc ->
@ -82,6 +85,8 @@ class ServiceRuntimePluginGenerator(
"AnonymousIdentityResolver" to runtimeApi.resolve("client::identity::AnonymousIdentityResolver"),
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"ConfigBagAccessors" to runtimeApi.resolve("client::orchestrator::ConfigBagAccessors"),
"Connection" to runtimeApi.resolve("client::orchestrator::Connection"),
"ConnectorSettings" to RuntimeType.smithyClient(rc).resolve("http_connector::ConnectorSettings"),
@ -118,8 +123,9 @@ class ServiceRuntimePluginGenerator(
}
impl #{RuntimePlugin} for ServiceRuntimePlugin {
fn configure(&self, cfg: &mut #{ConfigBag}, _interceptors: &mut #{InterceptorRegistrar}) -> #{Result}<(), #{BoxError}> {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
use #{ConfigBagAccessors};
let mut cfg = #{Layer}::new(${codegenContext.serviceShape.id.name.dq()});
// HACK: Put the handle into the config bag to work around config not being fully implemented yet
cfg.put(self.handle.clone());
@ -147,27 +153,26 @@ class ServiceRuntimePluginGenerator(
if let Some(retry_config) = retry_config {
cfg.set_retry_strategy(#{StandardRetryStrategy}::new(retry_config));
} else if cfg.retry_strategy().is_none() {
cfg.set_retry_strategy(#{NeverRetryStrategy}::new());
}
let connector_settings = timeout_config.map(#{ConnectorSettings}::from_timeout_config).unwrap_or_default();
let connection: #{Box}<dyn #{Connection}> = #{Box}::new(#{DynConnectorAdapter}::new(
// TODO(enableNewSmithyRuntime): Replace the tower-based DynConnector and remove DynConnectorAdapter when deleting the middleware implementation
#{require_connector}(
self.handle.conf.http_connector()
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
.or_else(|| #{default_connector}(&connector_settings, sleep_impl))
)?
)) as _;
cfg.set_connection(connection);
if let Some(connection) = self.handle.conf.http_connector()
.and_then(|c| c.connector(&connector_settings, sleep_impl.clone()))
.or_else(|| #{default_connector}(&connector_settings, sleep_impl)) {
let connection: #{Box}<dyn #{Connection}> = #{Box}::new(#{DynConnectorAdapter}::new(
// TODO(enableNewSmithyRuntime): Replace the tower-based DynConnector and remove DynConnectorAdapter when deleting the middleware implementation
connection
)) as _;
cfg.set_connection(connection);
}
#{additional_config}
// Client-level Interceptors are registered after default Interceptors.
_interceptors.extend(self.handle.conf.interceptors.iter().cloned());
Some(cfg.freeze())
}
Ok(())
fn interceptors(&self, interceptors: &mut #{InterceptorRegistrar}) {
interceptors.extend(self.handle.conf.interceptors.iter().cloned());
#{additional_interceptors}
}
}
""",
@ -179,7 +184,10 @@ class ServiceRuntimePluginGenerator(
writeCustomizations(customizations, ServiceRuntimePluginSection.RetryClassifier("cfg"))
},
"additional_config" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg", "_interceptors"))
writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg"))
},
"additional_interceptors" to writable {
writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterInterceptor("interceptors"))
},
)
}

View File

@ -20,13 +20,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.docsOrFallback
import software.amazon.smithy.rust.codegen.core.rustlang.raw
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
@ -91,7 +92,9 @@ sealed class ServiceConfig(name: String) : Section(name) {
/**
* A section for setting up a field to be used by RuntimePlugin
*/
object ToRuntimePlugin : ServiceConfig("ToRuntimePlugin")
data class RuntimePluginConfig(val cfg: String) : ServiceConfig("ToRuntimePlugin")
data class RuntimePluginInterceptors(val interceptors: String) : ServiceConfig("ToRuntimePluginInterceptors")
/**
* A section for extra functionality that needs to be defined with the config module
@ -163,7 +166,7 @@ fun standardConfigParam(param: ConfigParam): ConfigCustomization = object : Conf
rust("${param.name}: self.${param.name}$default,")
}
ServiceConfig.ToRuntimePlugin -> emptySection
is ServiceConfig.RuntimePluginConfig -> emptySection
else -> emptySection
}
@ -196,7 +199,10 @@ typealias ConfigCustomization = NamedCustomization<ServiceConfig>
* // builder implementation
* }
*/
class ServiceConfigGenerator(private val customizations: List<ConfigCustomization> = listOf()) {
class ServiceConfigGenerator(
private val codegenContext: CodegenContext,
private val customizations: List<ConfigCustomization> = listOf(),
) {
companion object {
fun withBaseBehavior(
@ -207,10 +213,22 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
if (codegenContext.serviceShape.needsIdempotencyToken(codegenContext.model)) {
baseFeatures.add(IdempotencyTokenProviderCustomization())
}
return ServiceConfigGenerator(baseFeatures + extraCustomizations)
return ServiceConfigGenerator(codegenContext, baseFeatures + extraCustomizations)
}
}
private val runtimeApi = RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig)
private val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig)
val codegenScope = arrayOf(
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"Layer" to smithyTypes.resolve("config_bag::Layer"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
*preludeScope,
)
fun render(writer: RustWriter) {
writer.docs("Service config.\n")
customizations.forEach {
@ -306,30 +324,29 @@ class ServiceConfigGenerator(private val customizations: List<ConfigCustomizatio
}
}
fun renderRuntimePluginImplForBuilder(writer: RustWriter, codegenContext: CodegenContext) {
val runtimeApi = RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig)
val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig)
writer.rustBlockTemplate(
"impl #{RuntimePlugin} for Builder",
"RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
) {
rustBlockTemplate(
"""
fn configure(&self, _cfg: &mut #{ConfigBag}, interceptors: &mut #{InterceptorRegistrar}) -> Result<(), #{BoxError}>
""",
"BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
"ConfigBag" to smithyTypes.resolve("config_bag::ConfigBag"),
"InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"),
) {
rust("// TODO(enableNewSmithyRuntime): Put into `cfg` the fields in `self.config_override` that are not `None`")
customizations.forEach {
it.section(ServiceConfig.ToRuntimePlugin)(writer)
fun renderRuntimePluginImplForBuilder(writer: RustWriter) {
writer.rustTemplate(
"""
impl #{RuntimePlugin} for Builder {
fn config(&self) -> #{Option}<#{FrozenLayer}> {
// TODO(enableNewSmithyRuntime): Put into `cfg` the fields in `self.config_override` that are not `None`
##[allow(unused_mut)]
let mut cfg = #{Layer}::new("service config");
#{config}
Some(cfg.freeze())
}
rust("Ok(())")
fn interceptors(&self, _interceptors: &mut #{InterceptorRegistrar}) {
#{interceptors}
}
}
}
""",
*codegenScope,
"config" to writable { writeCustomizations(customizations, ServiceConfig.RuntimePluginConfig("cfg")) },
"interceptors" to writable {
writeCustomizations(customizations, ServiceConfig.RuntimePluginInterceptors("_interceptors"))
},
)
}
}

View File

@ -87,7 +87,7 @@ class ClientHttpBoundProtocolPayloadGenerator(
if (propertyBagAvailable) {
rust("properties.acquire_mut().insert(signer_sender);")
} else {
rust("_cfg.put(signer_sender);")
rust("_cfg.interceptor_state().put(signer_sender);")
}
},
)

View File

@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.TestWriterDelegator
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
@ -73,7 +74,7 @@ fun validateConfigCustomizations(
fun stubConfigProject(customization: ConfigCustomization, project: TestWriterDelegator): TestWriterDelegator {
val customizations = listOf(stubConfigCustomization("a")) + customization + stubConfigCustomization("b")
val generator = ServiceConfigGenerator(customizations = customizations.toList())
val generator = ServiceConfigGenerator(testClientCodegenContext("namespace test".asSmithyModel()), customizations = customizations.toList())
project.withModule(ClientRustModule.Config) {
generator.render(this)
unitTest(

View File

@ -23,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestModuleDocProvider
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestWriterDelegator
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
fun testClientRustSettings(
service: ShapeId = ShapeId.from("notrelevant#notrelevant"),
@ -72,7 +73,7 @@ fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSy
)
fun testClientCodegenContext(
model: Model,
model: Model = "namespace empty".asSmithyModel(),
symbolProvider: RustSymbolProvider? = null,
serviceShape: ServiceShape? = null,
settings: ClientRustSettings = testClientRustSettings(),

View File

@ -9,7 +9,7 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@ -101,8 +101,9 @@ internal class ServiceConfigGeneratorTest {
}
}
}
val sut = ServiceConfigGenerator(listOf(ServiceCustomizer()))
val symbolProvider = testSymbolProvider("namespace empty".asSmithyModel())
val ctx = testClientCodegenContext()
val sut = ServiceConfigGenerator(ctx, listOf(ServiceCustomizer()))
val symbolProvider = ctx.symbolProvider
val project = TestWorkspace.testProject(symbolProvider)
project.withModule(ClientRustModule.Config) {
sut.render(this)

View File

@ -180,7 +180,11 @@ fun RustWriter.rustInline(
/* rewrite #{foo} to #{foo:T} (the smithy template format) */
private fun transformTemplate(template: String, scope: Array<out Pair<String, Any>>, trim: Boolean = true): String {
check(scope.distinctBy { it.first.lowercase() }.size == scope.size) { "Duplicate cased keys not supported" }
check(
scope.distinctBy {
it.first.lowercase()
}.size == scope.distinctBy { it.first }.size,
) { "Duplicate cased keys not supported" }
val output = template.replace(Regex("""#\{([a-zA-Z_0-9]+)(:\w)?\}""")) { matchResult ->
val keyName = matchResult.groupValues[1]
val templateType = matchResult.groupValues[2].ifEmpty { ":T" }

View File

@ -959,7 +959,8 @@ mod tests {
interceptors
.read_before_transmit(&mut InterceptorContext::new(Input::new(5)), &mut cfg)
.expect_err("interceptor returns error");
cfg.put(disable_interceptor::<PanicInterceptor>("test"));
cfg.interceptor_state()
.put(disable_interceptor::<PanicInterceptor>("test"));
assert_eq!(
interceptors
.interceptors()

View File

@ -12,11 +12,12 @@ use aws_smithy_async::future::now_or_later::NowOrLater;
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::{SharedTimeSource, TimeSource};
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::type_erasure::{TypeErasedBox, TypedBox};
use bytes::Bytes;
use std::fmt;
use std::fmt::Debug;
use std::future::Future as StdFuture;
use std::pin::Pin;
use std::sync::Arc;
@ -98,196 +99,252 @@ pub enum LoadedRequestBody {
Loaded(Bytes),
}
pub trait ConfigBagAccessors {
fn auth_option_resolver_params(&self) -> &AuthOptionResolverParams;
fn set_auth_option_resolver_params(
&mut self,
auth_option_resolver_params: AuthOptionResolverParams,
);
fn auth_option_resolver(&self) -> &dyn AuthOptionResolver;
fn set_auth_option_resolver(&mut self, auth_option_resolver: impl AuthOptionResolver + 'static);
fn endpoint_resolver_params(&self) -> &EndpointResolverParams;
fn set_endpoint_resolver_params(&mut self, endpoint_resolver_params: EndpointResolverParams);
fn endpoint_resolver(&self) -> &dyn EndpointResolver;
fn set_endpoint_resolver(&mut self, endpoint_resolver: impl EndpointResolver + 'static);
fn identity_resolvers(&self) -> &IdentityResolvers;
fn set_identity_resolvers(&mut self, identity_resolvers: IdentityResolvers);
fn connection(&self) -> &dyn Connection;
fn set_connection(&mut self, connection: impl Connection + 'static);
fn http_auth_schemes(&self) -> &HttpAuthSchemes;
fn set_http_auth_schemes(&mut self, http_auth_schemes: HttpAuthSchemes);
fn request_serializer(&self) -> Arc<dyn RequestSerializer>;
fn set_request_serializer(&mut self, request_serializer: impl RequestSerializer + 'static);
fn response_deserializer(&self) -> &dyn ResponseDeserializer;
fn set_response_deserializer(
&mut self,
response_serializer: impl ResponseDeserializer + 'static,
);
fn retry_classifiers(&self) -> &RetryClassifiers;
fn set_retry_classifiers(&mut self, retry_classifier: RetryClassifiers);
fn retry_strategy(&self) -> Option<&dyn RetryStrategy>;
fn set_retry_strategy(&mut self, retry_strategy: impl RetryStrategy + 'static);
fn request_time(&self) -> Option<SharedTimeSource>;
fn set_request_time(&mut self, time_source: impl TimeSource + 'static);
fn sleep_impl(&self) -> Option<SharedAsyncSleep>;
fn set_sleep_impl(&mut self, async_sleep: Option<SharedAsyncSleep>);
fn loaded_request_body(&self) -> &LoadedRequestBody;
fn set_loaded_request_body(&mut self, loaded_request_body: LoadedRequestBody);
pub trait Settable {
fn layer(&mut self) -> &mut Layer;
fn put<T: Send + Sync + Debug + 'static>(&mut self, value: T) {
self.layer().put(value);
}
}
const NOT_NEEDED: LoadedRequestBody = LoadedRequestBody::NotNeeded;
pub trait Gettable {
fn config_bag(&self) -> &ConfigBag;
fn get<T: Send + Sync + Debug + 'static>(&self) -> Option<&T> {
self.config_bag().get::<T>()
}
}
impl ConfigBagAccessors for ConfigBag {
fn auth_option_resolver_params(&self) -> &AuthOptionResolverParams {
self.get::<AuthOptionResolverParams>()
impl Settable for Layer {
fn layer(&mut self) -> &mut Layer {
self
}
}
impl Gettable for ConfigBag {
fn config_bag(&self) -> &ConfigBag {
self
}
}
pub trait ConfigBagAccessors {
fn auth_option_resolver_params(&self) -> &AuthOptionResolverParams
where
Self: Gettable,
{
self.config_bag()
.get::<AuthOptionResolverParams>()
.expect("auth option resolver params must be set")
}
fn set_auth_option_resolver_params(
&mut self,
auth_option_resolver_params: AuthOptionResolverParams,
) {
) where
Self: Settable,
{
self.put::<AuthOptionResolverParams>(auth_option_resolver_params);
}
fn auth_option_resolver(&self) -> &dyn AuthOptionResolver {
fn auth_option_resolver(&self) -> &dyn AuthOptionResolver
where
Self: Gettable,
{
&**self
.config_bag()
.get::<Box<dyn AuthOptionResolver>>()
.expect("an auth option resolver must be set")
}
fn set_auth_option_resolver(
&mut self,
auth_option_resolver: impl AuthOptionResolver + 'static,
) {
fn set_auth_option_resolver(&mut self, auth_option_resolver: impl AuthOptionResolver + 'static)
where
Self: Settable,
{
self.put::<Box<dyn AuthOptionResolver>>(Box::new(auth_option_resolver));
}
fn endpoint_resolver_params(&self) -> &EndpointResolverParams {
self.get::<EndpointResolverParams>()
fn endpoint_resolver_params(&self) -> &EndpointResolverParams
where
Self: Gettable,
{
self.config_bag()
.get::<EndpointResolverParams>()
.expect("endpoint resolver params must be set")
}
fn set_endpoint_resolver_params(&mut self, endpoint_resolver_params: EndpointResolverParams) {
fn set_endpoint_resolver_params(&mut self, endpoint_resolver_params: EndpointResolverParams)
where
Self: Settable,
{
self.put::<EndpointResolverParams>(endpoint_resolver_params);
}
fn endpoint_resolver(&self) -> &dyn EndpointResolver {
fn endpoint_resolver(&self) -> &dyn EndpointResolver
where
Self: Gettable,
{
&**self
.config_bag()
.get::<Box<dyn EndpointResolver>>()
.expect("an endpoint resolver must be set")
}
fn set_endpoint_resolver(&mut self, endpoint_resolver: impl EndpointResolver + 'static) {
fn set_endpoint_resolver(&mut self, endpoint_resolver: impl EndpointResolver + 'static)
where
Self: Settable,
{
self.put::<Box<dyn EndpointResolver>>(Box::new(endpoint_resolver));
}
fn identity_resolvers(&self) -> &IdentityResolvers {
self.get::<IdentityResolvers>()
fn identity_resolvers(&self) -> &IdentityResolvers
where
Self: Gettable,
{
self.config_bag()
.get::<IdentityResolvers>()
.expect("identity resolvers must be configured")
}
fn set_identity_resolvers(&mut self, identity_resolvers: IdentityResolvers) {
fn set_identity_resolvers(&mut self, identity_resolvers: IdentityResolvers)
where
Self: Settable,
{
self.put::<IdentityResolvers>(identity_resolvers);
}
fn connection(&self) -> &dyn Connection {
fn connection(&self) -> &dyn Connection
where
Self: Gettable,
{
&**self
.config_bag()
.get::<Box<dyn Connection>>()
.expect("missing connector")
}
fn set_connection(&mut self, connection: impl Connection + 'static) {
fn set_connection(&mut self, connection: impl Connection + 'static)
where
Self: Settable,
{
self.put::<Box<dyn Connection>>(Box::new(connection));
}
fn http_auth_schemes(&self) -> &HttpAuthSchemes {
self.get::<HttpAuthSchemes>()
fn http_auth_schemes(&self) -> &HttpAuthSchemes
where
Self: Gettable,
{
self.config_bag()
.get::<HttpAuthSchemes>()
.expect("auth schemes must be set")
}
fn set_http_auth_schemes(&mut self, http_auth_schemes: HttpAuthSchemes) {
fn set_http_auth_schemes(&mut self, http_auth_schemes: HttpAuthSchemes)
where
Self: Settable,
{
self.put::<HttpAuthSchemes>(http_auth_schemes);
}
fn request_serializer(&self) -> Arc<dyn RequestSerializer> {
fn request_serializer(&self) -> Arc<dyn RequestSerializer>
where
Self: Gettable,
{
self.get::<Arc<dyn RequestSerializer>>()
.expect("missing request serializer")
.clone()
}
fn set_request_serializer(&mut self, request_serializer: impl RequestSerializer + 'static) {
fn set_request_serializer(&mut self, request_serializer: impl RequestSerializer + 'static)
where
Self: Settable,
{
self.put::<Arc<dyn RequestSerializer>>(Arc::new(request_serializer));
}
fn response_deserializer(&self) -> &dyn ResponseDeserializer {
fn response_deserializer(&self) -> &dyn ResponseDeserializer
where
Self: Gettable,
{
&**self
.get::<Box<dyn ResponseDeserializer>>()
.expect("missing response deserializer")
}
fn set_response_deserializer(
&mut self,
response_deserializer: impl ResponseDeserializer + 'static,
) {
) where
Self: Settable,
{
self.put::<Box<dyn ResponseDeserializer>>(Box::new(response_deserializer));
}
fn retry_classifiers(&self) -> &RetryClassifiers {
fn retry_classifiers(&self) -> &RetryClassifiers
where
Self: Gettable,
{
self.get::<RetryClassifiers>()
.expect("retry classifiers must be set")
}
fn set_retry_classifiers(&mut self, retry_classifiers: RetryClassifiers) {
fn set_retry_classifiers(&mut self, retry_classifiers: RetryClassifiers)
where
Self: Settable,
{
self.put::<RetryClassifiers>(retry_classifiers);
}
fn retry_strategy(&self) -> Option<&dyn RetryStrategy> {
fn retry_strategy(&self) -> Option<&dyn RetryStrategy>
where
Self: Gettable,
{
self.get::<Box<dyn RetryStrategy>>().map(|rs| &**rs)
}
fn set_retry_strategy(&mut self, retry_strategy: impl RetryStrategy + 'static) {
fn set_retry_strategy(&mut self, retry_strategy: impl RetryStrategy + 'static)
where
Self: Settable,
{
self.put::<Box<dyn RetryStrategy>>(Box::new(retry_strategy));
}
fn request_time(&self) -> Option<SharedTimeSource> {
fn request_time(&self) -> Option<SharedTimeSource>
where
Self: Gettable,
{
self.get::<SharedTimeSource>().cloned()
}
fn set_request_time(&mut self, request_time: impl TimeSource + 'static) {
self.put::<SharedTimeSource>(SharedTimeSource::new(request_time));
fn set_request_time(&mut self, time_source: impl TimeSource + 'static)
where
Self: Settable,
{
self.put::<SharedTimeSource>(SharedTimeSource::new(time_source));
}
fn sleep_impl(&self) -> Option<SharedAsyncSleep> {
fn sleep_impl(&self) -> Option<SharedAsyncSleep>
where
Self: Gettable,
{
self.get::<SharedAsyncSleep>().cloned()
}
fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) {
if let Some(sleep_impl) = sleep_impl {
fn set_sleep_impl(&mut self, async_sleep: Option<SharedAsyncSleep>)
where
Self: Settable,
{
if let Some(sleep_impl) = async_sleep {
self.put::<SharedAsyncSleep>(sleep_impl);
} else {
self.unset::<SharedAsyncSleep>();
self.layer().unset::<SharedAsyncSleep>();
}
}
fn loaded_request_body(&self) -> &LoadedRequestBody {
fn loaded_request_body(&self) -> &LoadedRequestBody
where
Self: Gettable,
{
self.get::<LoadedRequestBody>().unwrap_or(&NOT_NEEDED)
}
fn set_loaded_request_body(&mut self, loaded_request_body: LoadedRequestBody) {
fn set_loaded_request_body(&mut self, loaded_request_body: LoadedRequestBody)
where
Self: Settable,
{
self.put::<LoadedRequestBody>(loaded_request_body);
}
}
const NOT_NEEDED: LoadedRequestBody = LoadedRequestBody::NotNeeded;
impl ConfigBagAccessors for ConfigBag {}
impl ConfigBagAccessors for Layer {}

View File

@ -4,27 +4,34 @@
*/
use crate::client::interceptors::InterceptorRegistrar;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer};
use std::fmt::Debug;
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub type BoxRuntimePlugin = Box<dyn RuntimePlugin + Send + Sync>;
/// RuntimePlugin Trait
///
/// A RuntimePlugin is the unit of configuration for augmenting the SDK with new behavior
///
/// Runtime plugins can set configuration and register interceptors.
pub trait RuntimePlugin: Debug {
fn configure(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError>;
fn config(&self) -> Option<FrozenLayer> {
None
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
let _ = interceptors;
}
}
impl RuntimePlugin for BoxRuntimePlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
self.as_ref().configure(cfg, interceptors)
fn config(&self) -> Option<FrozenLayer> {
self.as_ref().config()
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
self.as_ref().interceptors(interceptors)
}
}
@ -61,7 +68,10 @@ impl RuntimePlugins {
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
for plugin in self.client_plugins.iter() {
plugin.configure(cfg, interceptors)?;
if let Some(layer) = plugin.config() {
cfg.push_shared_layer(layer);
}
plugin.interceptors(interceptors);
}
Ok(())
@ -73,7 +83,10 @@ impl RuntimePlugins {
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
for plugin in self.operation_plugins.iter() {
plugin.configure(cfg, interceptors)?;
if let Some(layer) = plugin.config() {
cfg.push_shared_layer(layer);
}
plugin.interceptors(interceptors);
}
Ok(())
@ -82,22 +95,12 @@ impl RuntimePlugins {
#[cfg(test)]
mod tests {
use super::{BoxError, RuntimePlugin, RuntimePlugins};
use crate::client::interceptors::InterceptorRegistrar;
use aws_smithy_types::config_bag::ConfigBag;
use super::{RuntimePlugin, RuntimePlugins};
#[derive(Debug)]
struct SomeStruct;
impl RuntimePlugin for SomeStruct {
fn configure(
&self,
_cfg: &mut ConfigBag,
_inters: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
todo!()
}
}
impl RuntimePlugin for SomeStruct {}
#[test]
fn can_add_runtime_plugin_implementors_to_runtime_plugins() {

View File

@ -115,7 +115,7 @@ fn apply_configuration(
runtime_plugins: &RuntimePlugins,
) -> Result<(), BoxError> {
runtime_plugins.apply_client_configuration(cfg, interceptors.client_interceptors_mut())?;
continue_on_err!([ctx] =>interceptors.client_read_before_execution(ctx, cfg));
continue_on_err!([ctx] => interceptors.client_read_before_execution(ctx, cfg));
runtime_plugins
.apply_operation_configuration(cfg, interceptors.operation_interceptors_mut())?;
continue_on_err!([ctx] => interceptors.operation_read_before_execution(ctx, cfg));
@ -150,7 +150,8 @@ async fn try_op(
let loaded_body = halt_on_err!([ctx] => ByteStream::new(body).collect().await).into_bytes();
*ctx.request_mut().as_mut().expect("set above").body_mut() =
SdkBody::from(loaded_body.clone());
cfg.set_loaded_request_body(LoadedRequestBody::Loaded(loaded_body));
cfg.interceptor_state()
.set_loaded_request_body(LoadedRequestBody::Loaded(loaded_body));
}
// Before transmit
@ -191,7 +192,7 @@ async fn try_op(
break;
}
// Track which attempt we're currently on.
cfg.put::<RequestAttempts>(i.into());
cfg.interceptor_state().put::<RequestAttempts>(i.into());
let attempt_timeout_config = cfg.maybe_timeout_config(TimeoutKind::OperationAttempt);
let maybe_timeout = async {
try_attempt(ctx, cfg, interceptors, stop_point).await;
@ -338,7 +339,7 @@ mod tests {
};
use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError};
use aws_smithy_runtime_api::client::runtime_plugin::{BoxError, RuntimePlugin, RuntimePlugins};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
use aws_smithy_types::type_erasure::{TypeErasedBox, TypedBox};
use http::StatusCode;
use std::sync::atomic::{AtomicBool, Ordering};
@ -367,11 +368,8 @@ mod tests {
struct TestOperationRuntimePlugin;
impl RuntimePlugin for TestOperationRuntimePlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
_interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn config(&self) -> Option<FrozenLayer> {
let mut cfg = Layer::new("test operation");
cfg.set_request_serializer(new_request_serializer());
cfg.set_response_deserializer(new_response_deserializer());
cfg.set_retry_strategy(NeverRetryStrategy::new());
@ -379,7 +377,7 @@ mod tests {
cfg.set_endpoint_resolver_params(StaticUriEndpointResolverParams::new().into());
cfg.set_connection(OkConnector::new());
Ok(())
Some(cfg.freeze())
}
}
@ -416,14 +414,8 @@ mod tests {
struct FailingInterceptorsClientRuntimePlugin;
impl RuntimePlugin for FailingInterceptorsClientRuntimePlugin {
fn configure(
&self,
_cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(FailingInterceptorA));
Ok(())
}
}
@ -431,15 +423,9 @@ mod tests {
struct FailingInterceptorsOperationRuntimePlugin;
impl RuntimePlugin for FailingInterceptorsOperationRuntimePlugin {
fn configure(
&self,
_cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(FailingInterceptorB));
interceptors.register(SharedInterceptor::new(FailingInterceptorC));
Ok(())
}
}
@ -447,7 +433,7 @@ mod tests {
let runtime_plugins = RuntimePlugins::new()
.with_client_plugin(FailingInterceptorsClientRuntimePlugin)
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin::new())
.with_operation_plugin(FailingInterceptorsOperationRuntimePlugin);
let actual = invoke(input, &runtime_plugins)
.await
@ -702,22 +688,16 @@ mod tests {
struct InterceptorsTestOperationRuntimePlugin;
impl RuntimePlugin for InterceptorsTestOperationRuntimePlugin {
fn configure(
&self,
_cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(OriginInterceptor));
interceptors.register(SharedInterceptor::new(DestinationInterceptor));
Ok(())
}
}
let input = TypeErasedBox::new(Box::new(()));
let runtime_plugins = RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin::new())
.with_operation_plugin(InterceptorsTestOperationRuntimePlugin);
let actual = invoke(input, &runtime_plugins)
.await
@ -965,7 +945,7 @@ mod tests {
let runtime_plugins = || {
RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin::new())
};
// StopPoint::None should result in a response getting set since orchestration doesn't stop
@ -1043,15 +1023,14 @@ mod tests {
interceptor: TestInterceptor,
}
impl RuntimePlugin for TestInterceptorRuntimePlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
cfg.put(self.interceptor.clone());
fn config(&self) -> Option<FrozenLayer> {
let mut layer = Layer::new("test");
layer.put(self.interceptor.clone());
Some(layer.freeze())
}
fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
interceptors.register(SharedInterceptor::new(self.interceptor.clone()));
Ok(())
}
}
@ -1059,7 +1038,7 @@ mod tests {
let runtime_plugins = || {
RuntimePlugins::new()
.with_operation_plugin(TestOperationRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin)
.with_operation_plugin(AnonymousAuthRuntimePlugin::new())
.with_operation_plugin(TestInterceptorRuntimePlugin {
interceptor: interceptor.clone(),
})

View File

@ -140,6 +140,7 @@ mod tests {
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{Future, HttpRequest};
use aws_smithy_types::config_bag::Layer;
use aws_smithy_types::type_erasure::TypedBox;
use std::collections::HashMap;
@ -200,21 +201,23 @@ mod tests {
let _ = ctx.take_input();
ctx.enter_before_transmit_phase();
let mut cfg = ConfigBag::base();
cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
cfg.set_auth_option_resolver(StaticAuthOptionResolver::new(vec![TEST_SCHEME_ID]));
cfg.set_identity_resolvers(
let mut layer = Layer::new("test");
layer.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
layer.set_auth_option_resolver(StaticAuthOptionResolver::new(vec![TEST_SCHEME_ID]));
layer.set_identity_resolvers(
IdentityResolvers::builder()
.identity_resolver(TEST_SCHEME_ID, TestIdentityResolver)
.build(),
);
cfg.set_http_auth_schemes(
layer.set_http_auth_schemes(
HttpAuthSchemes::builder()
.auth_scheme(TEST_SCHEME_ID, TestAuthScheme { signer: TestSigner })
.build(),
);
cfg.put(Endpoint::builder().url("dontcare").build());
layer.put(Endpoint::builder().url("dontcare").build());
let mut cfg = ConfigBag::base();
cfg.push_layer(layer);
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
assert_eq!(
@ -242,26 +245,27 @@ mod tests {
let _ = ctx.take_input();
ctx.enter_before_transmit_phase();
let mut cfg = ConfigBag::base();
cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
cfg.set_auth_option_resolver(StaticAuthOptionResolver::new(vec![
let mut layer = Layer::new("test");
layer.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter"));
layer.set_auth_option_resolver(StaticAuthOptionResolver::new(vec![
HTTP_BASIC_AUTH_SCHEME_ID,
HTTP_BEARER_AUTH_SCHEME_ID,
]));
cfg.set_http_auth_schemes(
layer.set_http_auth_schemes(
HttpAuthSchemes::builder()
.auth_scheme(HTTP_BASIC_AUTH_SCHEME_ID, BasicAuthScheme::new())
.auth_scheme(HTTP_BEARER_AUTH_SCHEME_ID, BearerAuthScheme::new())
.build(),
);
cfg.put(Endpoint::builder().url("dontcare").build());
layer.put(Endpoint::builder().url("dontcare").build());
// First, test the presence of a basic auth login and absence of a bearer token
cfg.set_identity_resolvers(
layer.set_identity_resolvers(
IdentityResolvers::builder()
.identity_resolver(HTTP_BASIC_AUTH_SCHEME_ID, Login::new("a", "b", None))
.build(),
);
let mut cfg = ConfigBag::of_layers(vec![layer]);
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
assert_eq!(
@ -274,12 +278,15 @@ mod tests {
.unwrap()
);
let mut additional_resolver = Layer::new("extra");
// Next, test the presence of a bearer token and absence of basic auth
cfg.set_identity_resolvers(
additional_resolver.set_identity_resolvers(
IdentityResolvers::builder()
.identity_resolver(HTTP_BEARER_AUTH_SCHEME_ID, Token::new("t", None))
.build(),
);
cfg.push_layer(additional_resolver);
let mut ctx = InterceptorContext::new(TypedBox::new("doesnt-matter").erase());
ctx.enter_serialization_phase();

View File

@ -100,7 +100,7 @@ pub(super) fn orchestrate_endpoint(
apply_endpoint(request, &endpoint, endpoint_prefix)?;
// Make the endpoint config available to interceptors
cfg.put(endpoint);
cfg.interceptor_state().put(endpoint);
Ok(())
}

View File

@ -78,7 +78,7 @@ impl Interceptor for ServiceClockSkewInterceptor {
}
};
let skew = ServiceClockSkew::new(calculate_skew(time_sent, time_received));
cfg.put(skew);
cfg.interceptor_state().put(skew);
Ok(())
}
}

View File

@ -144,7 +144,7 @@ mod tests {
use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError};
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::{AlwaysRetry, RetryClassifiers, RetryStrategy};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, Layer};
use aws_smithy_types::retry::ErrorKind;
use aws_smithy_types::type_erasure::TypeErasedBox;
use std::time::Duration;
@ -167,9 +167,12 @@ mod tests {
) -> (InterceptorContext, ConfigBag) {
let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
let mut cfg = ConfigBag::base();
cfg.set_retry_classifiers(RetryClassifiers::new().with_classifier(AlwaysRetry(error_kind)));
cfg.put(RequestAttempts::new(current_request_attempts));
let mut layer = Layer::new("test");
layer.set_retry_classifiers(
RetryClassifiers::new().with_classifier(AlwaysRetry(error_kind)),
);
layer.put(RequestAttempts::new(current_request_attempts));
let cfg = ConfigBag::of_layers(vec![layer]);
(ctx, cfg)
}

View File

@ -13,10 +13,9 @@ use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpAuthSchemes, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar;
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
const ANONYMOUS_AUTH_SCHEME_ID: AuthSchemeId = AuthSchemeId::new("anonymous");
@ -30,21 +29,18 @@ const ANONYMOUS_AUTH_SCHEME_ID: AuthSchemeId = AuthSchemeId::new("anonymous");
/// - You only need to make anonymous requests, such as when interacting with [Open Data](https://aws.amazon.com/opendata/).
/// - You're writing orchestrator tests and don't care about authentication.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct AnonymousAuthRuntimePlugin;
#[derive(Debug)]
pub struct AnonymousAuthRuntimePlugin(FrozenLayer);
impl AnonymousAuthRuntimePlugin {
pub fn new() -> Self {
Self
impl Default for AnonymousAuthRuntimePlugin {
fn default() -> Self {
Self::new()
}
}
impl RuntimePlugin for AnonymousAuthRuntimePlugin {
fn configure(
&self,
cfg: &mut ConfigBag,
_interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
impl AnonymousAuthRuntimePlugin {
pub fn new() -> Self {
let mut cfg = Layer::new("AnonymousAuth");
cfg.set_auth_option_resolver_params(StaticAuthOptionResolverParams::new().into());
cfg.set_auth_option_resolver(StaticAuthOptionResolver::new(vec![
ANONYMOUS_AUTH_SCHEME_ID,
@ -59,8 +55,13 @@ impl RuntimePlugin for AnonymousAuthRuntimePlugin {
.auth_scheme(ANONYMOUS_AUTH_SCHEME_ID, AnonymousAuthScheme::new())
.build(),
);
Self(cfg.freeze())
}
}
Ok(())
impl RuntimePlugin for AnonymousAuthRuntimePlugin {
fn config(&self) -> Option<FrozenLayer> {
Some(self.0.clone())
}
}

View File

@ -4,12 +4,11 @@
*/
use aws_smithy_runtime_api::client::interceptors::context::{Error, Output};
use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar;
use aws_smithy_runtime_api::client::orchestrator::{
ConfigBagAccessors, HttpResponse, OrchestratorError, ResponseDeserializer,
};
use aws_smithy_runtime_api::client::runtime_plugin::{BoxError, RuntimePlugin};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::{FrozenLayer, Layer};
use std::sync::Mutex;
#[derive(Default, Debug)]
@ -44,15 +43,12 @@ impl ResponseDeserializer for CannedResponseDeserializer {
}
impl RuntimePlugin for CannedResponseDeserializer {
fn configure(
&self,
cfg: &mut ConfigBag,
_interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn config(&self) -> Option<FrozenLayer> {
let mut cfg = Layer::new("CannedResponse");
cfg.set_response_deserializer(Self {
inner: Mutex::new(self.take()),
});
Ok(())
Some(cfg.freeze())
}
}

View File

@ -63,7 +63,7 @@ mod tests {
let request_time = UNIX_EPOCH + Duration::from_secs(1624036048);
let interceptor = TestParamsSetterInterceptor::new(
move |_: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag| {
cfg.set_request_time(request_time);
cfg.interceptor_state().set_request_time(request_time);
},
);
interceptor

View File

@ -4,12 +4,11 @@
*/
use aws_smithy_runtime_api::client::interceptors::context::Input;
use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar;
use aws_smithy_runtime_api::client::orchestrator::{
ConfigBagAccessors, HttpRequest, RequestSerializer,
};
use aws_smithy_runtime_api::client::runtime_plugin::{BoxError, RuntimePlugin};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
use std::sync::Mutex;
#[derive(Default, Debug)]
@ -50,15 +49,11 @@ impl RequestSerializer for CannedRequestSerializer {
}
impl RuntimePlugin for CannedRequestSerializer {
fn configure(
&self,
cfg: &mut ConfigBag,
_interceptors: &mut InterceptorRegistrar,
) -> Result<(), BoxError> {
fn config(&self) -> Option<FrozenLayer> {
let mut cfg = Layer::new("CannedRequest");
cfg.set_request_serializer(Self {
inner: Mutex::new(self.take()),
});
Ok(())
Some(cfg.freeze())
}
}

View File

@ -183,6 +183,7 @@ mod tests {
use aws_smithy_async::assert_elapsed;
use aws_smithy_async::future::never::Never;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_types::config_bag::Layer;
#[tokio::test]
async fn test_no_timeout() {
@ -197,8 +198,10 @@ mod tests {
tokio::time::pause();
let mut cfg = ConfigBag::base();
cfg.put(TimeoutConfig::builder().build());
cfg.set_sleep_impl(Some(sleep_impl));
let mut timeout_config = Layer::new("timeout");
timeout_config.put(TimeoutConfig::builder().build());
timeout_config.set_sleep_impl(Some(sleep_impl));
cfg.push_layer(timeout_config);
underlying_future
.maybe_timeout(&cfg, TimeoutKind::Operation)
@ -221,12 +224,14 @@ mod tests {
tokio::time::pause();
let mut cfg = ConfigBag::base();
cfg.put(
let mut timeout_config = Layer::new("timeout");
timeout_config.put(
TimeoutConfig::builder()
.operation_timeout(Duration::from_millis(250))
.build(),
);
cfg.set_sleep_impl(Some(sleep_impl));
timeout_config.set_sleep_impl(Some(sleep_impl));
cfg.push_layer(timeout_config);
let result = underlying_future
.maybe_timeout(&cfg, TimeoutKind::Operation)

View File

@ -9,6 +9,7 @@
//! with the following properties:
//! 1. A new layer of configuration may be applied onto an existing configuration structure without modifying it or taking ownership.
//! 2. No lifetime shenanigans to deal with
mod storable;
mod typeid_map;
use crate::config_bag::typeid_map::TypeIdMap;
@ -18,17 +19,18 @@ use std::borrow::Cow;
use std::fmt::{Debug, Formatter};
use std::iter::Rev;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::slice;
use std::ops::Deref;
use std::slice::Iter;
use std::sync::Arc;
pub use storable::{AppendItemIter, Storable, Store, StoreAppend, StoreReplace};
/// Layered Configuration Structure
///
/// [`ConfigBag`] is the "unlocked" form of the bag. Only the top layer of the bag may be unlocked.
#[must_use]
pub struct ConfigBag {
head: Layer,
interceptor_state: Layer,
tail: Vec<FrozenLayer>,
}
@ -83,100 +85,6 @@ pub struct Layer {
props: TypeIdMap<TypeErasedBox>,
}
/// Trait defining how types can be stored and loaded from the config bag
pub trait Store: Sized + Send + Sync + 'static {
/// Denote the returned type when loaded from the config bag
type ReturnedType<'a>: Send + Sync;
/// Denote the stored type when stored into the config bag
type StoredType: Send + Sync + Debug;
/// Create a returned type from an iterable of items
fn merge_iter(iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_>;
}
/// Store an item in the config bag by replacing the existing value
#[non_exhaustive]
pub struct StoreReplace<U>(PhantomData<U>);
impl<U> Debug for StoreReplace<U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "StoreReplace")
}
}
/// Store an item in the config bag by effectively appending it to a list
#[non_exhaustive]
pub struct StoreAppend<U>(PhantomData<U>);
impl<U> Debug for StoreAppend<U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "StoreAppend")
}
}
/// Trait that marks the implementing types as able to be stored in the config bag
pub trait Storable: Send + Sync + Debug + 'static {
/// Specify how an item is stored in the config bag, e.g. [`StoreReplace`] and [`StoreAppend`]
type Storer: Store;
}
impl<U: Send + Sync + Debug + 'static> Store for StoreReplace<U> {
type ReturnedType<'a> = Option<&'a U>;
type StoredType = Value<U>;
fn merge_iter(mut iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_> {
iter.next().and_then(|item| match item {
Value::Set(item) => Some(item),
Value::ExplicitlyUnset(_) => None,
})
}
}
impl<U: Send + Sync + Debug + 'static> Store for StoreAppend<U> {
type ReturnedType<'a> = AppendItemIter<'a, U>;
type StoredType = Value<Vec<U>>;
fn merge_iter(iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_> {
AppendItemIter {
inner: iter,
cur: None,
}
}
}
/// Iterator of items returned by [`StoreAppend`]
pub struct AppendItemIter<'a, U> {
inner: ItemIter<'a, StoreAppend<U>>,
cur: Option<Rev<slice::Iter<'a, U>>>,
}
impl<'a, U> Debug for AppendItemIter<'a, U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "AppendItemIter")
}
}
impl<'a, U: 'a> Iterator for AppendItemIter<'a, U>
where
U: Send + Sync + Debug + 'static,
{
type Item = &'a U;
fn next(&mut self) -> Option<Self::Item> {
if let Some(buf) = &mut self.cur {
match buf.next() {
Some(item) => return Some(item),
None => self.cur = None,
}
}
match self.inner.next() {
None => None,
Some(Value::Set(u)) => {
self.cur = Some(u.iter().rev());
self.next()
}
Some(Value::ExplicitlyUnset(_)) => None,
}
}
}
impl Debug for Layer {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
struct Items<'a>(&'a Layer);
@ -269,20 +177,25 @@ impl Layer {
/// This can only be used for types that use [`StoreAppend`]
/// ```
/// use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreAppend, StoreReplace};
/// let mut bag = ConfigBag::base();
/// use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreAppend, StoreReplace};
/// let mut layer_1 = Layer::new("example");
/// #[derive(Debug, PartialEq, Eq)]
/// struct Interceptor(&'static str);
/// impl Storable for Interceptor {
/// type Storer = StoreAppend<Interceptor>;
/// }
///
/// bag.store_append(Interceptor("123"));
/// bag.store_append(Interceptor("456"));
/// layer_1.store_append(Interceptor("321"));
/// layer_1.store_append(Interceptor("654"));
///
/// let mut layer_2 = Layer::new("second layer");
/// layer_2.store_append(Interceptor("987"));
///
/// let bag = ConfigBag::of_layers(vec![layer_1, layer_2]);
///
/// assert_eq!(
/// bag.load::<Interceptor>().collect::<Vec<_>>(),
/// vec![&Interceptor("456"), &Interceptor("123")]
/// vec![&Interceptor("987"), &Interceptor("654"), &Interceptor("321")]
/// );
/// ```
pub fn store_append<T>(&mut self, item: T) -> &mut Self
@ -296,6 +209,17 @@ impl Layer {
self
}
/// Clears the value of type `T` from the config bag
///
/// This internally marks the item of type `T` as cleared as opposed to wiping it out from the
/// config bag.
pub fn clear<T>(&mut self)
where
T: Storable<Storer = StoreAppend<T>>,
{
self.put_directly::<StoreAppend<T>>(Value::ExplicitlyUnset(type_name::<T>()));
}
/// Retrieves the value of type `T` from this layer if exists
fn get<T: Send + Sync + Store + 'static>(&self) -> Option<&T::StoredType> {
self.props
@ -334,21 +258,6 @@ impl FrozenLayer {
}
}
// TODO(refactor of configbag): consider removing these Deref impls—they exist to keep existing code compiling
impl Deref for ConfigBag {
type Target = Layer;
fn deref(&self) -> &Self::Target {
&self.head
}
}
impl DerefMut for ConfigBag {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.head
}
}
impl ConfigBag {
/// Create a new config bag "base".
///
@ -357,38 +266,34 @@ impl ConfigBag {
/// of configuration may then be "frozen" (made immutable) by calling [`ConfigBag::freeze`].
pub fn base() -> Self {
ConfigBag {
head: Layer {
name: Cow::Borrowed("base"),
interceptor_state: Layer {
name: Cow::Borrowed("interceptor_state"),
props: Default::default(),
},
tail: vec![],
}
}
pub fn push_layer(&mut self, layer: &FrozenLayer) -> &mut Self {
if !self.head.empty() {
self.freeze_head();
pub fn of_layers(layers: Vec<Layer>) -> Self {
let mut bag = ConfigBag::base();
for layer in layers {
bag.push_layer(layer);
}
self.tail.push(layer.clone());
bag
}
pub fn push_layer(&mut self, layer: Layer) -> &mut Self {
self.tail.push(layer.freeze());
self
}
fn freeze_head(&mut self) {
let new_head = Layer::new("scratch");
let old_head = std::mem::replace(&mut self.head, new_head);
self.tail.push(old_head.freeze());
pub fn push_shared_layer(&mut self, layer: FrozenLayer) -> &mut Self {
self.tail.push(layer);
self
}
/// Clears the value of type `T` from the config bag
///
/// This internally marks the item of type `T` as cleared as opposed to wiping it out from the
/// config bag.
pub fn clear<T>(&mut self)
where
T: Storable<Storer = StoreAppend<T>>,
{
self.head
.put_directly::<StoreAppend<T>>(Value::ExplicitlyUnset(type_name::<T>()));
pub fn interceptor_state(&mut self) -> &mut Layer {
&mut self.interceptor_state
}
/// Load a value (or values) of type `T` depending on how `T` implements [`Storable`]
@ -410,19 +315,19 @@ impl ConfigBag {
// this code looks weird to satisfy the borrow checker—we can't keep the result of `get_mut`
// alive (even in a returned branch) and then call `store_put`. So: drop the borrow immediately
// store, the value, then pull it right back
if matches!(self.head.get_mut::<StoreReplace<T>>(), None) {
if matches!(self.interceptor_state.get_mut::<StoreReplace<T>>(), None) {
let new_item = match self.tail.iter().find_map(|b| b.load::<T>()) {
Some(item) => item.clone(),
None => return None,
};
self.store_put(new_item);
self.interceptor_state.store_put(new_item);
self.get_mut()
} else if matches!(
self.head.get::<StoreReplace<T>>(),
self.interceptor_state.get::<StoreReplace<T>>(),
Some(Value::ExplicitlyUnset(_))
) {
None
} else if let Some(Value::Set(t)) = self.head.get_mut::<StoreReplace<T>>() {
} else if let Some(Value::Set(t)) = self.interceptor_state.get_mut::<StoreReplace<T>>() {
Some(t)
} else {
unreachable!()
@ -457,7 +362,7 @@ impl ConfigBag {
// alive (even in a returned branch) and then call `store_put`. So: drop the borrow immediately
// store, the value, then pull it right back
if self.get_mut::<T>().is_none() {
self.store_put((default)());
self.interceptor_state.store_put((default)());
return self
.get_mut()
.expect("item was just stored in the top layer");
@ -491,10 +396,13 @@ impl ConfigBag {
) -> ConfigBag {
let mut new_layer = Layer::new(name);
next(&mut new_layer);
let ConfigBag { head, mut tail } = self;
let ConfigBag {
interceptor_state: head,
mut tail,
} = self;
tail.push(head.freeze());
ConfigBag {
head: new_layer,
interceptor_state: new_layer,
tail,
}
}
@ -518,7 +426,7 @@ impl ConfigBag {
fn layers(&self) -> BagIter<'_> {
BagIter {
head: Some(&self.head),
head: Some(&self.interceptor_state),
tail: self.tail.iter().rev(),
}
}
@ -599,7 +507,7 @@ mod test {
let mut base_bag = ConfigBag::base()
.with_fn("a", layer_a)
.with_fn("b", layer_b);
base_bag.put(Prop3);
base_bag.interceptor_state().put(Prop3);
assert!(base_bag.get::<Prop1>().is_some());
#[derive(Debug)]
@ -640,27 +548,29 @@ mod test {
assert_eq!(operation_config.get::<SigningName>().unwrap().0, "s3");
let mut open_bag = operation_config.with_fn("my_custom_info", |_bag: &mut Layer| {});
open_bag.put("foo");
open_bag.interceptor_state().put("foo");
assert_eq!(open_bag.layers().count(), 4);
}
#[test]
fn store_append() {
let mut bag = ConfigBag::base();
let mut layer = Layer::new("test");
#[derive(Debug, PartialEq, Eq)]
struct Interceptor(&'static str);
impl Storable for Interceptor {
type Storer = StoreAppend<Interceptor>;
}
bag.clear::<Interceptor>();
layer.clear::<Interceptor>();
// you can only call store_append because interceptor is marked with a vec
bag.store_append(Interceptor("123"));
bag.store_append(Interceptor("456"));
layer.store_append(Interceptor("123"));
layer.store_append(Interceptor("456"));
let mut bag = bag.add_layer("next");
bag.store_append(Interceptor("789"));
let mut second_layer = Layer::new("next");
second_layer.store_append(Interceptor("789"));
let mut bag = ConfigBag::of_layers(vec![layer, second_layer]);
assert_eq!(
bag.load::<Interceptor>().collect::<Vec<_>>(),
@ -671,7 +581,9 @@ mod test {
]
);
bag.clear::<Interceptor>();
let mut final_layer = Layer::new("final");
final_layer.clear::<Interceptor>();
bag.push_layer(final_layer);
assert_eq!(bag.load::<Interceptor>().count(), 0);
}
@ -684,12 +596,13 @@ mod test {
}
let mut expected = vec![];
let mut bag = ConfigBag::base();
for layer in 0..100 {
bag = bag.add_layer(format!("{}", layer));
for layer_idx in 0..100 {
let mut layer = Layer::new(format!("{}", layer_idx));
for item in 0..100 {
expected.push(TestItem(layer, item));
bag.store_append(TestItem(layer, item));
expected.push(TestItem(layer_idx, item));
layer.store_append(TestItem(layer_idx, item));
}
bag.push_layer(layer);
}
expected.reverse();
assert_eq!(
@ -718,12 +631,17 @@ mod test {
let mut bag_1 = ConfigBag::base();
let mut bag_2 = ConfigBag::base();
bag_1.push_layer(&layer_1).push_layer(&layer_2);
bag_2.push_layer(&layer_2).push_layer(&layer_1);
bag_1
.push_shared_layer(layer_1.clone())
.push_shared_layer(layer_2.clone());
bag_2.push_shared_layer(layer_2).push_shared_layer(layer_1);
// bags have same layers but in different orders
assert_eq!(bag_1.load::<Foo>(), Some(&Foo(1)));
assert_eq!(bag_2.load::<Foo>(), Some(&Foo(0)));
bag_1.interceptor_state().put(Foo(3));
assert_eq!(bag_1.load::<Foo>(), Some(&Foo(3)));
}
#[test]
@ -749,7 +667,7 @@ mod test {
let new_ref = bag.load::<Foo>().unwrap();
assert_eq!(new_ref, &Foo(2));
bag.unset::<Foo>();
bag.interceptor_state().unset::<Foo>();
// if it was unset, we can't clone the current one, that would be wrong
assert_eq!(bag.get_mut::<Foo>(), None);
assert_eq!(bag.get_mut_or_default::<Foo>(), &Foo(0));

View File

@ -0,0 +1,108 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use crate::config_bag::value::Value;
use crate::config_bag::ItemIter;
use std::fmt::{Debug, Formatter};
use std::iter::Rev;
use std::marker::PhantomData;
use std::slice;
/// Trait defining how types can be stored and loaded from the config bag
pub trait Store: Sized + Send + Sync + 'static {
/// Denote the returned type when loaded from the config bag
type ReturnedType<'a>: Send + Sync;
/// Denote the stored type when stored into the config bag
type StoredType: Send + Sync + Debug;
/// Create a returned type from an iterable of items
fn merge_iter(iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_>;
}
/// Store an item in the config bag by replacing the existing value
#[non_exhaustive]
pub struct StoreReplace<U>(PhantomData<U>);
impl<U> Debug for StoreReplace<U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "StoreReplace")
}
}
/// Store an item in the config bag by effectively appending it to a list
#[non_exhaustive]
pub struct StoreAppend<U>(PhantomData<U>);
impl<U> Debug for StoreAppend<U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "StoreAppend")
}
}
/// Trait that marks the implementing types as able to be stored in the config bag
pub trait Storable: Send + Sync + Debug + 'static {
/// Specify how an item is stored in the config bag, e.g. [`StoreReplace`] and [`StoreAppend`]
type Storer: Store;
}
impl<U: Send + Sync + Debug + 'static> Store for StoreReplace<U> {
type ReturnedType<'a> = Option<&'a U>;
type StoredType = Value<U>;
fn merge_iter(mut iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_> {
iter.next().and_then(|item| match item {
Value::Set(item) => Some(item),
Value::ExplicitlyUnset(_) => None,
})
}
}
impl<U: Send + Sync + Debug + 'static> Store for StoreAppend<U> {
type ReturnedType<'a> = AppendItemIter<'a, U>;
type StoredType = Value<Vec<U>>;
fn merge_iter(iter: ItemIter<'_, Self>) -> Self::ReturnedType<'_> {
AppendItemIter {
inner: iter,
cur: None,
}
}
}
/// Iterator of items returned by [`StoreAppend`]
pub struct AppendItemIter<'a, U> {
inner: ItemIter<'a, StoreAppend<U>>,
cur: Option<Rev<slice::Iter<'a, U>>>,
}
impl<'a, U> Debug for AppendItemIter<'a, U> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "AppendItemIter")
}
}
impl<'a, U: 'a> Iterator for AppendItemIter<'a, U>
where
U: Send + Sync + Debug + 'static,
{
type Item = &'a U;
fn next(&mut self) -> Option<Self::Item> {
if let Some(buf) = &mut self.cur {
match buf.next() {
Some(item) => return Some(item),
None => self.cur = None,
}
}
match self.inner.next() {
None => None,
Some(Value::Set(u)) => {
self.cur = Some(u.iter().rev());
self.next()
}
Some(Value::ExplicitlyUnset(_)) => None,
}
}
}