mirror of https://github.com/smithy-lang/smithy-rs
Make the SDK ad hoc tests pass against the orchestrator (#2708)
## Motivation and Context This PR refactors the client protocol test generator machinery to use a client instead of calling `make_operation` directly, and then fixes the ad hoc tests for the orchestrator. The ad hoc tests revealed that overriding the signing region/service via endpoint config was lost when porting SigV4 signing to the orchestrator, so this PR updates the SigV4 `HttpRequestSigner` implementation to restore this functionality. It is doing this in the signer directly rather than via an interceptor since it should only run this logic when SigV4 is the selected auth scheme. Other notable changes: - Adds `--no-fail-fast` arg to `cargoTest` targets so that all Rust tests run in CI rather than stopping on the first failure - Changes `EndpointResolver::resolve_and_apply_endpoint` to just `resolve_endpoint` so that the orchestrator can place the endpoint config into the request state, which is required for the signer to make use of it - Adds a `set_region` method to SDK service configs - Deletes the API Gateway model and integration test from the SDK smoke test since it is covered by the ad hoc tests - Adds a comment explaining where the API Gateway model comes from in the ad hoc tests - Adds a `smithy.runtime.mode` Gradle property to `aws:sdk` and `aws:sdk-adhoc-test` to trivially switch between middleware and orchestrator when testing/generating locally ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
This commit is contained in:
parent
41774b8405
commit
9bfe936fbc
|
@ -11,12 +11,18 @@ pub mod sigv4 {
|
|||
SignableRequest, SignatureLocation, SigningParams, SigningSettings,
|
||||
UriPathNormalizationMode,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
|
||||
use aws_smithy_runtime_api::client::auth::{
|
||||
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
|
||||
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
|
||||
use aws_smithy_runtime_api::config_bag::ConfigBag;
|
||||
use aws_types::region::SigningRegion;
|
||||
use aws_smithy_types::Document;
|
||||
use aws_types::region::{Region, SigningRegion};
|
||||
use aws_types::SigningService;
|
||||
use std::borrow::Cow;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \
|
||||
|
@ -25,6 +31,53 @@ pub mod sigv4 {
|
|||
/// Auth scheme ID for SigV4.
|
||||
pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4");
|
||||
|
||||
struct EndpointAuthSchemeConfig {
|
||||
signing_region_override: Option<SigningRegion>,
|
||||
signing_service_override: Option<SigningService>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum SigV4SigningError {
|
||||
MissingOperationSigningConfig,
|
||||
MissingSigningRegion,
|
||||
MissingSigningService,
|
||||
WrongIdentityType(Identity),
|
||||
BadTypeInEndpointAuthSchemeConfig(&'static str),
|
||||
}
|
||||
|
||||
impl fmt::Display for SigV4SigningError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use SigV4SigningError::*;
|
||||
let mut w = |s| f.write_str(s);
|
||||
match self {
|
||||
MissingOperationSigningConfig => w("missing operation signing config for SigV4"),
|
||||
MissingSigningRegion => w("missing signing region for SigV4 signing"),
|
||||
MissingSigningService => w("missing signing service for SigV4 signing"),
|
||||
WrongIdentityType(identity) => {
|
||||
write!(f, "wrong identity type for SigV4: {identity:?}")
|
||||
}
|
||||
BadTypeInEndpointAuthSchemeConfig(field_name) => {
|
||||
write!(
|
||||
f,
|
||||
"unexpected type for `{field_name}` in endpoint auth scheme config",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for SigV4SigningError {
|
||||
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
||||
match self {
|
||||
Self::MissingOperationSigningConfig => None,
|
||||
Self::MissingSigningRegion => None,
|
||||
Self::MissingSigningService => None,
|
||||
Self::WrongIdentityType(_) => None,
|
||||
Self::BadTypeInEndpointAuthSchemeConfig(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SigV4 auth scheme.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SigV4HttpAuthScheme {
|
||||
|
@ -111,9 +164,9 @@ pub mod sigv4 {
|
|||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct SigV4OperationSigningConfig {
|
||||
/// AWS Region to sign for.
|
||||
pub region: SigningRegion,
|
||||
pub region: Option<SigningRegion>,
|
||||
/// AWS Service to sign for.
|
||||
pub service: SigningService,
|
||||
pub service: Option<SigningService>,
|
||||
/// Signing options.
|
||||
pub signing_options: SigningOptions,
|
||||
}
|
||||
|
@ -165,7 +218,7 @@ pub mod sigv4 {
|
|||
credentials: &'a Credentials,
|
||||
operation_config: &'a SigV4OperationSigningConfig,
|
||||
request_timestamp: SystemTime,
|
||||
) -> SigningParams<'a> {
|
||||
) -> Result<SigningParams<'a>, SigV4SigningError> {
|
||||
if let Some(expires_in) = settings.expires_in {
|
||||
if let Some(creds_expires_time) = credentials.expiry() {
|
||||
let presigned_expires_time = request_timestamp + expires_in;
|
||||
|
@ -178,12 +231,75 @@ pub mod sigv4 {
|
|||
let mut builder = SigningParams::builder()
|
||||
.access_key(credentials.access_key_id())
|
||||
.secret_key(credentials.secret_access_key())
|
||||
.region(operation_config.region.as_ref())
|
||||
.service_name(operation_config.service.as_ref())
|
||||
.region(
|
||||
operation_config
|
||||
.region
|
||||
.as_ref()
|
||||
.ok_or(SigV4SigningError::MissingSigningRegion)?
|
||||
.as_ref(),
|
||||
)
|
||||
.service_name(
|
||||
operation_config
|
||||
.service
|
||||
.as_ref()
|
||||
.ok_or(SigV4SigningError::MissingSigningService)?
|
||||
.as_ref(),
|
||||
)
|
||||
.time(request_timestamp)
|
||||
.settings(settings);
|
||||
builder.set_security_token(credentials.session_token());
|
||||
builder.build().expect("all required fields set")
|
||||
Ok(builder.build().expect("all required fields set"))
|
||||
}
|
||||
|
||||
fn extract_operation_config<'a>(
|
||||
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'a>,
|
||||
config_bag: &'a ConfigBag,
|
||||
) -> Result<Cow<'a, SigV4OperationSigningConfig>, SigV4SigningError> {
|
||||
let operation_config = config_bag
|
||||
.get::<SigV4OperationSigningConfig>()
|
||||
.ok_or(SigV4SigningError::MissingOperationSigningConfig)?;
|
||||
|
||||
let EndpointAuthSchemeConfig {
|
||||
signing_region_override,
|
||||
signing_service_override,
|
||||
} = Self::extract_endpoint_auth_scheme_config(auth_scheme_endpoint_config)?;
|
||||
|
||||
match (signing_region_override, signing_service_override) {
|
||||
(None, None) => Ok(Cow::Borrowed(operation_config)),
|
||||
(region, service) => {
|
||||
let mut operation_config = operation_config.clone();
|
||||
if region.is_some() {
|
||||
operation_config.region = region;
|
||||
}
|
||||
if service.is_some() {
|
||||
operation_config.service = service;
|
||||
}
|
||||
Ok(Cow::Owned(operation_config))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_endpoint_auth_scheme_config(
|
||||
endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
) -> Result<EndpointAuthSchemeConfig, SigV4SigningError> {
|
||||
let (mut signing_region_override, mut signing_service_override) = (None, None);
|
||||
if let Some(config) = endpoint_config.config().and_then(Document::as_object) {
|
||||
use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType;
|
||||
signing_region_override = match config.get("signingRegion") {
|
||||
Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
|
||||
None => None,
|
||||
_ => return Err(UnexpectedType("signingRegion")),
|
||||
};
|
||||
signing_service_override = match config.get("signingName") {
|
||||
Some(Document::String(s)) => Some(SigningService::from(s.to_string())),
|
||||
None => None,
|
||||
_ => return Err(UnexpectedType("signingName")),
|
||||
};
|
||||
}
|
||||
Ok(EndpointAuthSchemeConfig {
|
||||
signing_region_override,
|
||||
signing_service_override,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,11 +308,11 @@ pub mod sigv4 {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
identity: &Identity,
|
||||
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
let operation_config = config_bag
|
||||
.get::<SigV4OperationSigningConfig>()
|
||||
.ok_or("missing operation signing config for SigV4")?;
|
||||
let operation_config =
|
||||
Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?;
|
||||
let request_time = config_bag.request_time().unwrap_or_default().system_time();
|
||||
|
||||
let credentials = if let Some(creds) = identity.data::<Credentials>() {
|
||||
|
@ -205,12 +321,12 @@ pub mod sigv4 {
|
|||
tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials");
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(format!("wrong identity type for SigV4: {identity:?}").into());
|
||||
return Err(SigV4SigningError::WrongIdentityType(identity.clone()).into());
|
||||
};
|
||||
|
||||
let settings = Self::settings(operation_config);
|
||||
let settings = Self::settings(&operation_config);
|
||||
let signing_params =
|
||||
Self::signing_params(settings, credentials, operation_config, request_time);
|
||||
Self::signing_params(settings, credentials, &operation_config, request_time)?;
|
||||
|
||||
let (signing_instructions, _signature) = {
|
||||
// A body that is already in memory can be signed directly. A body that is not in memory
|
||||
|
@ -250,6 +366,9 @@ pub mod sigv4 {
|
|||
use super::*;
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sigv4::http_request::SigningSettings;
|
||||
use aws_types::region::SigningRegion;
|
||||
use aws_types::SigningService;
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tracing_test::traced_test;
|
||||
|
||||
|
@ -270,8 +389,8 @@ pub mod sigv4 {
|
|||
"test",
|
||||
);
|
||||
let operation_config = SigV4OperationSigningConfig {
|
||||
region: SigningRegion::from_static("test"),
|
||||
service: SigningService::from_static("test"),
|
||||
region: Some(SigningRegion::from_static("test")),
|
||||
service: Some(SigningService::from_static("test")),
|
||||
signing_options: SigningOptions {
|
||||
double_uri_encode: true,
|
||||
content_sha256_header: true,
|
||||
|
@ -283,14 +402,74 @@ pub mod sigv4 {
|
|||
payload_override: None,
|
||||
},
|
||||
};
|
||||
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
|
||||
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
|
||||
.unwrap();
|
||||
assert!(!logs_contain(EXPIRATION_WARNING));
|
||||
|
||||
let mut settings = SigningSettings::default();
|
||||
settings.expires_in = Some(creds_expire_in + Duration::from_secs(10));
|
||||
|
||||
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
|
||||
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
|
||||
.unwrap();
|
||||
assert!(logs_contain(EXPIRATION_WARNING));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_config_overrides_region_and_service() {
|
||||
let mut cfg = ConfigBag::base();
|
||||
cfg.put(SigV4OperationSigningConfig {
|
||||
region: Some(SigningRegion::from(Region::new("override-this-region"))),
|
||||
service: Some(SigningService::from_static("override-this-service")),
|
||||
signing_options: Default::default(),
|
||||
});
|
||||
let config = Document::Object({
|
||||
let mut out = HashMap::new();
|
||||
out.insert("name".to_string(), "sigv4".to_string().into());
|
||||
out.insert(
|
||||
"signingName".to_string(),
|
||||
"qldb-override".to_string().into(),
|
||||
);
|
||||
out.insert(
|
||||
"signingRegion".to_string(),
|
||||
"us-east-override".to_string().into(),
|
||||
);
|
||||
out
|
||||
});
|
||||
let config = AuthSchemeEndpointConfig::new(Some(&config));
|
||||
|
||||
let result =
|
||||
SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");
|
||||
|
||||
assert_eq!(
|
||||
result.region,
|
||||
Some(SigningRegion::from(Region::new("us-east-override")))
|
||||
);
|
||||
assert_eq!(
|
||||
result.service,
|
||||
Some(SigningService::from_static("qldb-override"))
|
||||
);
|
||||
assert!(matches!(result, Cow::Owned(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_config_supports_fallback_when_region_or_service_are_unset() {
|
||||
let mut cfg = ConfigBag::base();
|
||||
cfg.put(SigV4OperationSigningConfig {
|
||||
region: Some(SigningRegion::from(Region::new("us-east-1"))),
|
||||
service: Some(SigningService::from_static("qldb")),
|
||||
signing_options: Default::default(),
|
||||
});
|
||||
let config = AuthSchemeEndpointConfig::empty();
|
||||
|
||||
let result =
|
||||
SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");
|
||||
|
||||
assert_eq!(
|
||||
result.region,
|
||||
Some(SigningRegion::from(Region::new("us-east-1")))
|
||||
);
|
||||
assert_eq!(result.service, Some(SigningService::from_static("qldb")));
|
||||
assert!(matches!(result, Cow::Borrowed(_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,8 @@ dependencies {
|
|||
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
|
||||
}
|
||||
|
||||
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"
|
||||
|
||||
val allCodegenTests = listOf(
|
||||
CodegenTest(
|
||||
"com.amazonaws.apigateway#BackplaneControlService",
|
||||
|
@ -46,6 +48,7 @@ val allCodegenTests = listOf(
|
|||
,
|
||||
"codegen": {
|
||||
"includeFluentClient": false,
|
||||
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
|
||||
},
|
||||
"customizationConfig": {
|
||||
"awsSdk": {
|
||||
|
@ -62,6 +65,7 @@ val allCodegenTests = listOf(
|
|||
,
|
||||
"codegen": {
|
||||
"includeFluentClient": false,
|
||||
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
|
||||
},
|
||||
"customizationConfig": {
|
||||
"awsSdk": {
|
||||
|
|
|
@ -1,8 +1,16 @@
|
|||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// The API Gateway model is coming from Smithy's protocol tests, and includes an `Accept` header test:
|
||||
// https://github.com/awslabs/smithy/blob/2f6553ff39e6bba9edc644ef5832661821785319/smithy-aws-protocol-tests/model/restJson1/services/apigateway.smithy#L30-L43
|
||||
|
||||
$version: "1.0"
|
||||
|
||||
namespace com.amazonaws.apigateway
|
||||
|
||||
use smithy.rules#endpointRuleSet
|
||||
|
||||
// Add an endpoint ruleset to the Smithy protocol test API Gateway model so that the code generator doesn't fail
|
||||
apply BackplaneControlService @endpointRuleSet({
|
||||
"version": "1.0",
|
||||
"rules": [{
|
||||
|
|
|
@ -14,6 +14,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.NoClientGenerics
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
|
||||
|
@ -96,6 +98,29 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun protocolTestGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseGenerator: ProtocolTestGenerator,
|
||||
): ProtocolTestGenerator = DefaultProtocolTestGenerator(
|
||||
codegenContext,
|
||||
baseGenerator.protocolSupport,
|
||||
baseGenerator.operationShape,
|
||||
renderClientCreation = { params ->
|
||||
rustTemplate(
|
||||
"""
|
||||
// If the test case was missing endpoint parameters, default a region so it doesn't fail
|
||||
let mut ${params.configBuilderName} = ${params.configBuilderName};
|
||||
if ${params.configBuilderName}.region.is_none() {
|
||||
${params.configBuilderName}.set_region(Some(crate::config::Region::new("us-east-1")));
|
||||
}
|
||||
let config = ${params.configBuilderName}.http_connector(${params.connectorName}).build();
|
||||
let ${params.clientName} = #{Client}::from_conf(config);
|
||||
""",
|
||||
"Client" to ClientRustModule.root.toType().resolve("Client"),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
private class AwsFluentClientExtensions(types: Types) {
|
||||
|
|
|
@ -171,7 +171,7 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
|
|||
)
|
||||
|
||||
ServiceConfig.BuilderStruct ->
|
||||
rustTemplate("region: Option<#{Region}>,", *codegenScope)
|
||||
rustTemplate("pub(crate) region: Option<#{Region}>,", *codegenScope)
|
||||
|
||||
ServiceConfig.BuilderImpl ->
|
||||
rustTemplate(
|
||||
|
@ -191,6 +191,12 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
|
|||
self.region = region.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the AWS region to use when making requests.
|
||||
pub fn set_region(&mut self, region: Option<#{Region}>) -> &mut Self {
|
||||
self.region = region;
|
||||
self
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
|
|
|
@ -73,6 +73,7 @@ class OperationRetryClassifiersFeature(
|
|||
"ClassifyRetry" to smithyRuntimeApi.resolve("client::retries::ClassifyRetry"),
|
||||
"RetryClassifiers" to smithyRuntimeApi.resolve("client::retries::RetryClassifiers"),
|
||||
"OperationError" to codegenContext.symbolProvider.symbolForOperationError(operationShape),
|
||||
"OrchestratorError" to smithyRuntimeApi.resolve("client::orchestrator::OrchestratorError"),
|
||||
"SdkError" to RuntimeType.smithyHttp(runtimeConfig).resolve("result::SdkError"),
|
||||
"ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedError"),
|
||||
)
|
||||
|
@ -89,9 +90,9 @@ class OperationRetryClassifiersFeature(
|
|||
}
|
||||
}
|
||||
impl #{ClassifyRetry} for HttpStatusCodeClassifier {
|
||||
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
|
||||
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
|
||||
self.0.classify_error(error)
|
||||
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
|
||||
// TODO(enableNewSmithyRuntime): classify the error with self.0
|
||||
None
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
@ -108,9 +109,9 @@ class OperationRetryClassifiersFeature(
|
|||
}
|
||||
}
|
||||
impl #{ClassifyRetry} for AwsErrorCodeClassifier {
|
||||
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
|
||||
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
|
||||
self.0.classify_error(error)
|
||||
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
|
||||
// TODO(enableNewSmithyRuntime): classify the error with self.0
|
||||
None
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
@ -127,9 +128,9 @@ class OperationRetryClassifiersFeature(
|
|||
}
|
||||
}
|
||||
impl #{ClassifyRetry} for ModeledAsRetryableClassifier {
|
||||
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
|
||||
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
|
||||
self.0.classify_error(error)
|
||||
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
|
||||
// TODO(enableNewSmithyRuntime): classify the error with self.0
|
||||
None
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
@ -146,9 +147,9 @@ class OperationRetryClassifiersFeature(
|
|||
}
|
||||
}
|
||||
impl #{ClassifyRetry} for AmzRetryAfterHeaderClassifier {
|
||||
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
|
||||
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
|
||||
self.0.classify_error(error)
|
||||
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
|
||||
// TODO(enableNewSmithyRuntime): classify the error with self.0
|
||||
None
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
@ -165,9 +166,9 @@ class OperationRetryClassifiersFeature(
|
|||
}
|
||||
}
|
||||
impl #{ClassifyRetry} for SmithyErrorClassifier {
|
||||
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
|
||||
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
|
||||
self.0.classify_error(error)
|
||||
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
|
||||
// TODO(enableNewSmithyRuntime): classify the error with self.0
|
||||
None
|
||||
}
|
||||
}
|
||||
""",
|
||||
|
|
|
@ -126,8 +126,8 @@ private class AuthOperationRuntimePluginCustomization(private val codegenContext
|
|||
val signingOptional = section.operationShape.hasTrait<OptionalAuthTrait>()
|
||||
rustTemplate(
|
||||
"""
|
||||
let signing_region = cfg.get::<#{SigningRegion}>().expect("region required for signing").clone();
|
||||
let signing_service = cfg.get::<#{SigningService}>().expect("service required for signing").clone();
|
||||
let signing_region = cfg.get::<#{SigningRegion}>().cloned();
|
||||
let signing_service = cfg.get::<#{SigningService}>().cloned();
|
||||
let mut signing_options = #{SigningOptions}::default();
|
||||
signing_options.double_uri_encode = $doubleUriEncode;
|
||||
signing_options.content_sha256_header = $contentSha256Header;
|
||||
|
|
|
@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
|
||||
|
@ -138,4 +139,11 @@ class ServiceSpecificDecorator(
|
|||
): List<ServiceRuntimePluginCustomization> = baseCustomizations.maybeApply(codegenContext.serviceShape) {
|
||||
delegateTo.serviceRuntimePluginCustomizations(codegenContext, baseCustomizations)
|
||||
}
|
||||
|
||||
override fun protocolTestGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseGenerator: ProtocolTestGenerator,
|
||||
): ProtocolTestGenerator = baseGenerator.maybeApply(codegenContext.serviceShape) {
|
||||
delegateTo.protocolTestGenerator(codegenContext, baseGenerator)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.escape
|
||||
|
@ -120,7 +120,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test:
|
|||
private val moduleName = ctx.moduleUseName()
|
||||
private val endpointCustomizations = ctx.rootDecorator.endpointCustomizations(ctx)
|
||||
private val model = ctx.model
|
||||
private val instantiator = clientInstantiator(ctx)
|
||||
private val instantiator = ClientInstantiator(ctx)
|
||||
|
||||
private fun EndpointTestOperationInput.operationId() =
|
||||
ShapeId.fromOptionalNamespace(ctx.serviceShape.id.namespace, operationName)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -60,6 +60,7 @@ val crateVersioner by lazy { aws.sdk.CrateVersioner.defaultFor(rootProject, prop
|
|||
|
||||
fun getRustMSRV(): String = properties.get("rust.msrv") ?: throw Exception("Rust MSRV missing")
|
||||
fun getPreviousReleaseVersionManifestPath(): String? = properties.get("aws.sdk.previous.release.versions.manifest")
|
||||
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"
|
||||
|
||||
fun loadServiceMembership(): Membership {
|
||||
val membershipOverride = properties.get("aws.services")?.let { parseMembership(it) }
|
||||
|
@ -102,7 +103,7 @@ fun generateSmithyBuild(services: AwsServices): String {
|
|||
"renameErrors": false,
|
||||
"debugMode": $debugMode,
|
||||
"eventStreamAllowList": [$eventStreamAllowListMembers],
|
||||
"enableNewSmithyRuntime": "middleware"
|
||||
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
|
||||
},
|
||||
"service": "${service.service}",
|
||||
"module": "$moduleName",
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# `./gradlew -Paws.fullsdk=true :aws:sdk:assemble` these tests are copied into their respective Service crates.
|
||||
[workspace]
|
||||
members = [
|
||||
"apigateway",
|
||||
"dynamodb",
|
||||
"ec2",
|
||||
"glacier",
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
[package]
|
||||
name = "apigateway-tests"
|
||||
version = "0.1.0"
|
||||
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/awslabs/smithy-rs"
|
||||
publish = false
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] }
|
||||
aws-sdk-apigateway = { path = "../../build/aws-sdk/sdk/apigateway" }
|
||||
aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] }
|
||||
aws-smithy-protocol-test = { path = "../../build/aws-sdk/sdk/aws-smithy-protocol-test"}
|
||||
tokio = { version = "1.23.1", features = ["full", "test-util"]}
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use aws_sdk_apigateway::config::{Credentials, Region};
|
||||
use aws_sdk_apigateway::{Client, Config};
|
||||
use aws_smithy_client::test_connection::capture_request;
|
||||
use aws_smithy_protocol_test::{assert_ok, validate_headers};
|
||||
|
||||
#[tokio::test]
|
||||
async fn accept_header_is_application_json() {
|
||||
let (conn, handler) = capture_request(None);
|
||||
let conf = Config::builder()
|
||||
.region(Region::new("us-east-1"))
|
||||
.credentials_provider(Credentials::for_tests())
|
||||
.http_connector(conn)
|
||||
.build();
|
||||
|
||||
let client = Client::from_conf(conf);
|
||||
let _result = client
|
||||
.delete_resource()
|
||||
.rest_api_id("some-rest-api-id")
|
||||
.resource_id("some-resource-id")
|
||||
.send()
|
||||
.await;
|
||||
let request = handler.expect_request();
|
||||
assert_ok(validate_headers(
|
||||
request.headers(),
|
||||
[("accept", "application/json")],
|
||||
));
|
||||
}
|
|
@ -252,7 +252,7 @@ fun Project.registerCargoCommandsTasks(
|
|||
dependsOn(dependentTasks)
|
||||
workingDir(outputDir)
|
||||
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
|
||||
commandLine("cargo", "test", "--all-features")
|
||||
commandLine("cargo", "test", "--all-features", "--no-fail-fast")
|
||||
}
|
||||
|
||||
this.tasks.register<Exec>(Cargo.DOCS.toString) {
|
||||
|
|
|
@ -24,7 +24,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGener
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations
|
||||
|
@ -311,22 +311,24 @@ class ClientCodegenVisitor(
|
|||
)
|
||||
|
||||
// render protocol tests into `operation.rs` (note operationWriter vs. inputWriter)
|
||||
ProtocolTestGenerator(
|
||||
codegenDecorator.protocolTestGenerator(
|
||||
codegenContext,
|
||||
protocolGeneratorFactory.support(),
|
||||
operationShape,
|
||||
this@operationWriter,
|
||||
).render()
|
||||
DefaultProtocolTestGenerator(
|
||||
codegenContext,
|
||||
protocolGeneratorFactory.support(),
|
||||
operationShape,
|
||||
),
|
||||
).render(this@operationWriter)
|
||||
}
|
||||
}
|
||||
|
||||
rustCrate.withModule(symbolProvider.moduleForOperationError(operationShape)) {
|
||||
OperationErrorGenerator(
|
||||
model,
|
||||
symbolProvider,
|
||||
operationShape,
|
||||
codegenDecorator.errorCustomizations(codegenContext, emptyList()),
|
||||
).render(this)
|
||||
rustCrate.withModule(symbolProvider.moduleForOperationError(operationShape)) {
|
||||
OperationErrorGenerator(
|
||||
model,
|
||||
symbolProvider,
|
||||
operationShape,
|
||||
codegenDecorator.errorCustomizations(codegenContext, emptyList()),
|
||||
).render(this)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,25 +15,28 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
|
||||
import software.amazon.smithy.rust.codegen.core.util.orNull
|
||||
|
||||
class EndpointPrefixGenerator(private val codegenContext: ClientCodegenContext, private val shape: OperationShape) :
|
||||
OperationCustomization() {
|
||||
override fun section(section: OperationSection): Writable = when (section) {
|
||||
is OperationSection.MutateRequest -> writable {
|
||||
companion object {
|
||||
fun endpointTraitBindings(codegenContext: ClientCodegenContext, shape: OperationShape): EndpointTraitBindings? =
|
||||
shape.getTrait(EndpointTrait::class.java).map { epTrait ->
|
||||
val endpointTraitBindings = EndpointTraitBindings(
|
||||
EndpointTraitBindings(
|
||||
codegenContext.model,
|
||||
codegenContext.symbolProvider,
|
||||
codegenContext.runtimeConfig,
|
||||
shape,
|
||||
epTrait,
|
||||
)
|
||||
}.orNull()
|
||||
}
|
||||
|
||||
override fun section(section: OperationSection): Writable = when (section) {
|
||||
is OperationSection.MutateRequest -> writable {
|
||||
endpointTraitBindings(codegenContext, shape)?.also { endpointTraitBindings ->
|
||||
withBlock("let endpoint_prefix = ", "?;") {
|
||||
endpointTraitBindings.render(
|
||||
this,
|
||||
"self",
|
||||
codegenContext.smithyRuntimeMode,
|
||||
)
|
||||
endpointTraitBindings.render(this, "self", codegenContext.smithyRuntimeMode)
|
||||
}
|
||||
rust("request.properties_mut().insert(endpoint_prefix);")
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRunti
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator
|
||||
|
@ -77,6 +78,14 @@ interface ClientCodegenDecorator : CoreCodegenDecorator<ClientCodegenContext> {
|
|||
operation: OperationShape,
|
||||
baseCustomizations: List<OperationRuntimePluginCustomization>,
|
||||
): List<OperationRuntimePluginCustomization> = baseCustomizations
|
||||
|
||||
/**
|
||||
* Hook to override the protocol test generator
|
||||
*/
|
||||
fun protocolTestGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseGenerator: ProtocolTestGenerator,
|
||||
): ProtocolTestGenerator = baseGenerator
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -143,6 +152,13 @@ open class CombinedClientCodegenDecorator(decorators: List<ClientCodegenDecorato
|
|||
decorator.operationRuntimePluginCustomizations(codegenContext, operation, customizations)
|
||||
}
|
||||
|
||||
override fun protocolTestGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseGenerator: ProtocolTestGenerator,
|
||||
): ProtocolTestGenerator = combineCustomizations(baseGenerator) { decorator, gen ->
|
||||
decorator.protocolTestGenerator(codegenContext, gen)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun fromClasspath(
|
||||
context: PluginContext,
|
||||
|
|
|
@ -10,10 +10,11 @@ import software.amazon.smithy.rulesengine.language.syntax.Identifier
|
|||
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters
|
||||
import software.amazon.smithy.rulesengine.traits.EndpointTestCase
|
||||
import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.docs
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.escape
|
||||
|
@ -22,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.util.PANIC
|
||||
import software.amazon.smithy.rust.codegen.core.util.dq
|
||||
|
@ -34,8 +34,7 @@ internal class EndpointTestGenerator(
|
|||
private val resolverType: RuntimeType,
|
||||
private val params: Parameters,
|
||||
private val endpointCustomizations: List<EndpointCustomization>,
|
||||
codegenContext: CodegenContext,
|
||||
|
||||
codegenContext: ClientCodegenContext,
|
||||
) {
|
||||
private val runtimeConfig = codegenContext.runtimeConfig
|
||||
private val serviceShape = codegenContext.serviceShape
|
||||
|
@ -50,7 +49,7 @@ internal class EndpointTestGenerator(
|
|||
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
|
||||
)
|
||||
|
||||
private val instantiator = clientInstantiator(codegenContext)
|
||||
private val instantiator = ClientInstantiator(codegenContext)
|
||||
|
||||
private fun EndpointTestCase.docs(): Writable {
|
||||
val self = this
|
||||
|
|
|
@ -6,8 +6,14 @@
|
|||
package software.amazon.smithy.rust.codegen.client.smithy.generators
|
||||
|
||||
import software.amazon.smithy.codegen.core.Symbol
|
||||
import software.amazon.smithy.model.node.Node
|
||||
import software.amazon.smithy.model.node.ObjectNode
|
||||
import software.amazon.smithy.model.shapes.MemberShape
|
||||
import software.amazon.smithy.model.shapes.OperationShape
|
||||
import software.amazon.smithy.model.shapes.StructureShape
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
|
@ -29,11 +35,27 @@ class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat
|
|||
override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true
|
||||
}
|
||||
|
||||
fun clientInstantiator(codegenContext: CodegenContext) =
|
||||
Instantiator(
|
||||
codegenContext.symbolProvider,
|
||||
codegenContext.model,
|
||||
codegenContext.runtimeConfig,
|
||||
ClientBuilderKindBehavior(codegenContext),
|
||||
::enumFromStringFn,
|
||||
)
|
||||
class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator(
|
||||
codegenContext.symbolProvider,
|
||||
codegenContext.model,
|
||||
codegenContext.runtimeConfig,
|
||||
ClientBuilderKindBehavior(codegenContext),
|
||||
::enumFromStringFn,
|
||||
) {
|
||||
fun renderFluentCall(
|
||||
writer: RustWriter,
|
||||
clientName: String,
|
||||
operationShape: OperationShape,
|
||||
inputShape: StructureShape,
|
||||
data: Node,
|
||||
headers: Map<String, String> = mapOf(),
|
||||
ctx: Ctx = Ctx(),
|
||||
) {
|
||||
val operationBuilderName =
|
||||
FluentClientGenerator.clientOperationFnName(operationShape, codegenContext.symbolProvider)
|
||||
|
||||
writer.rust("$clientName.$operationBuilderName()")
|
||||
|
||||
renderStructureMembers(writer, inputShape, data as ObjectNode, headers, ctx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,12 @@ class EndpointTraitBindings(
|
|||
*
|
||||
* The returned expression is a `Result<EndpointPrefix, UriError>`
|
||||
*/
|
||||
fun render(writer: RustWriter, input: String, smithyRuntimeMode: SmithyRuntimeMode) {
|
||||
fun render(
|
||||
writer: RustWriter,
|
||||
input: String,
|
||||
smithyRuntimeMode: SmithyRuntimeMode,
|
||||
generateValidation: Boolean = true,
|
||||
) {
|
||||
// the Rust format pattern to make the endpoint prefix e.g. "{}.foo"
|
||||
val formatLiteral = endpointTrait.prefixFormatString()
|
||||
if (endpointTrait.hostPrefix.labels.isEmpty()) {
|
||||
|
@ -68,28 +73,30 @@ class EndpointTraitBindings(
|
|||
// NOTE: this is dead code until we start respecting @required
|
||||
rust("let $field = &$input.$field;")
|
||||
}
|
||||
val contents = if (smithyRuntimeMode.generateOrchestrator) {
|
||||
// TODO(enableNewSmithyRuntime): Remove the allow attribute once all places need .into method
|
||||
"""
|
||||
if $field.is_empty() {
|
||||
##[allow(clippy::useless_conversion)]
|
||||
return Err(#{invalidFieldError:W}.into())
|
||||
if (generateValidation) {
|
||||
val contents = if (smithyRuntimeMode.generateOrchestrator) {
|
||||
// TODO(enableNewSmithyRuntime): Remove the allow attribute once all places need .into method
|
||||
"""
|
||||
if $field.is_empty() {
|
||||
##[allow(clippy::useless_conversion)]
|
||||
return Err(#{invalidFieldError:W}.into())
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
"""
|
||||
if $field.is_empty() {
|
||||
return Err(#{invalidFieldError:W})
|
||||
}
|
||||
"""
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
"""
|
||||
if $field.is_empty() {
|
||||
return Err(#{invalidFieldError:W})
|
||||
}
|
||||
"""
|
||||
rustTemplate(
|
||||
contents,
|
||||
"invalidFieldError" to OperationBuildError(runtimeConfig).invalidField(
|
||||
field,
|
||||
"$field was unset or empty but must be set as part of the endpoint prefix",
|
||||
),
|
||||
)
|
||||
}
|
||||
rustTemplate(
|
||||
contents,
|
||||
"invalidFieldError" to OperationBuildError(runtimeConfig).invalidField(
|
||||
field,
|
||||
"$field was unset or empty but must be set as part of the endpoint prefix",
|
||||
),
|
||||
)
|
||||
"${label.content} = $field"
|
||||
}
|
||||
rustTemplate(
|
||||
|
|
|
@ -20,9 +20,11 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
|
|||
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
|
@ -44,15 +46,47 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
|
|||
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
|
||||
import java.util.logging.Logger
|
||||
|
||||
data class ClientCreationParams(
|
||||
val codegenContext: ClientCodegenContext,
|
||||
val connectorName: String,
|
||||
val configBuilderName: String,
|
||||
val clientName: String,
|
||||
)
|
||||
|
||||
interface ProtocolTestGenerator {
|
||||
val codegenContext: ClientCodegenContext
|
||||
val protocolSupport: ProtocolSupport
|
||||
val operationShape: OperationShape
|
||||
|
||||
fun render(writer: RustWriter)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate protocol tests for an operation
|
||||
*/
|
||||
class ProtocolTestGenerator(
|
||||
private val codegenContext: ClientCodegenContext,
|
||||
private val protocolSupport: ProtocolSupport,
|
||||
private val operationShape: OperationShape,
|
||||
private val writer: RustWriter,
|
||||
) {
|
||||
class DefaultProtocolTestGenerator(
|
||||
override val codegenContext: ClientCodegenContext,
|
||||
override val protocolSupport: ProtocolSupport,
|
||||
override val operationShape: OperationShape,
|
||||
|
||||
private val renderClientCreation: RustWriter.(ClientCreationParams) -> Unit = { params ->
|
||||
rustTemplate(
|
||||
"""
|
||||
let smithy_client = #{Builder}::new()
|
||||
.connector(${params.connectorName})
|
||||
.middleware(#{MapRequestLayer}::for_mapper(#{SmithyEndpointStage}::new()))
|
||||
.build();
|
||||
let ${params.clientName} = #{Client}::with_config(smithy_client, ${params.configBuilderName}.build());
|
||||
""",
|
||||
"Client" to ClientRustModule.root.toType().resolve("Client"),
|
||||
"Builder" to ClientRustModule.client.toType().resolve("Builder"),
|
||||
"SmithyEndpointStage" to RuntimeType.smithyHttp(codegenContext.runtimeConfig)
|
||||
.resolve("endpoint::middleware::SmithyEndpointStage"),
|
||||
"MapRequestLayer" to RuntimeType.smithyHttpTower(codegenContext.runtimeConfig)
|
||||
.resolve("map_request::MapRequestLayer"),
|
||||
)
|
||||
},
|
||||
) : ProtocolTestGenerator {
|
||||
private val logger = Logger.getLogger(javaClass.name)
|
||||
|
||||
private val inputShape = operationShape.inputShape(codegenContext.model)
|
||||
|
@ -60,7 +94,7 @@ class ProtocolTestGenerator(
|
|||
private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape)
|
||||
private val operationIndex = OperationIndex.of(codegenContext.model)
|
||||
|
||||
private val instantiator = clientInstantiator(codegenContext)
|
||||
private val instantiator = ClientInstantiator(codegenContext)
|
||||
|
||||
private val codegenScope = arrayOf(
|
||||
"SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
|
||||
|
@ -75,7 +109,7 @@ class ProtocolTestGenerator(
|
|||
TestCase()
|
||||
}
|
||||
|
||||
fun render() {
|
||||
override fun render(writer: RustWriter) {
|
||||
val requestTests = operationShape.getTrait<HttpRequestTestsTrait>()
|
||||
?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) }
|
||||
val responseTests = operationShape.getTrait<HttpResponseTestsTrait>()
|
||||
|
@ -150,6 +184,7 @@ class ProtocolTestGenerator(
|
|||
is Action.Response -> "_response"
|
||||
is Action.Request -> "_request"
|
||||
}
|
||||
Attribute.AllowUnusedMut.render(testModuleWriter)
|
||||
testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnName()") {
|
||||
block(this)
|
||||
}
|
||||
|
@ -166,33 +201,58 @@ class ProtocolTestGenerator(
|
|||
writable {
|
||||
val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext)
|
||||
params.getObjectMember("builtInParams").orNull()?.members?.forEach { (name, value) ->
|
||||
customizations.firstNotNullOf { it.setBuiltInOnServiceConfig(name.value, value, "builder") }(this)
|
||||
customizations.firstNotNullOf {
|
||||
it.setBuiltInOnServiceConfig(name.value, value, "config_builder")
|
||||
}(this)
|
||||
}
|
||||
}
|
||||
} ?: writable { }
|
||||
rustTemplate(
|
||||
"""
|
||||
let builder = #{config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com");
|
||||
let (conn, request_receiver) = #{capture_request}(None);
|
||||
let config_builder = #{config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com");
|
||||
#{customParams}
|
||||
let config = builder.build();
|
||||
|
||||
""",
|
||||
"capture_request" to CargoDependency.smithyClient(codegenContext.runtimeConfig)
|
||||
.toDevDependency()
|
||||
.withFeature("test-util")
|
||||
.toType()
|
||||
.resolve("test_connection::capture_request"),
|
||||
"config" to ClientRustModule.Config,
|
||||
"customParams" to customParams,
|
||||
)
|
||||
writeInline("let input =")
|
||||
instantiator.render(this, inputShape, httpRequestTestCase.params)
|
||||
renderClientCreation(this, ClientCreationParams(codegenContext, "conn", "config_builder", "client"))
|
||||
|
||||
writeInline("let result = ")
|
||||
instantiator.renderFluentCall(this, "client", operationShape, inputShape, httpRequestTestCase.params)
|
||||
rust(""".send().await;""")
|
||||
// Response parsing will always fail since we feed it an empty response body, so we don't care
|
||||
// if it fails, but it is helpful to print what that failure was for debugging
|
||||
rust("let _ = dbg!(result);")
|
||||
rust("""let http_request = request_receiver.expect_request();""")
|
||||
|
||||
rust(""".make_operation(&config).await.expect("operation failed to build");""")
|
||||
rust("let (http_request, parts) = input.into_request_response().0.into_parts();")
|
||||
with(httpRequestTestCase) {
|
||||
// Override the endpoint for tests that set a `host`, for example:
|
||||
// https://github.com/awslabs/smithy/blob/be68f3bbdfe5bf50a104b387094d40c8069f16b1/smithy-aws-protocol-tests/model/restJson1/endpoint-paths.smithy#L19
|
||||
host.orNull()?.also { host ->
|
||||
val withScheme = "http://$host"
|
||||
when (val bindings = EndpointPrefixGenerator.endpointTraitBindings(codegenContext, operationShape)) {
|
||||
null -> rust("let endpoint_prefix = None;")
|
||||
else -> {
|
||||
withBlock("let input = ", ";") {
|
||||
instantiator.render(this@renderHttpRequestTestCase, inputShape, httpRequestTestCase.params)
|
||||
}
|
||||
withBlock("let endpoint_prefix = Some({", "}.unwrap());") {
|
||||
bindings.render(this, "input", codegenContext.smithyRuntimeMode, generateValidation = false)
|
||||
}
|
||||
}
|
||||
}
|
||||
rustTemplate(
|
||||
"""
|
||||
let mut http_request = http_request;
|
||||
let ep = #{SmithyHttp}::endpoint::Endpoint::mutable(${withScheme.dq()}).expect("valid endpoint");
|
||||
ep.set_endpoint(http_request.uri_mut(), parts.acquire().get()).expect("valid endpoint");
|
||||
ep.set_endpoint(http_request.uri_mut(), endpoint_prefix.as_ref()).expect("valid endpoint");
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
|
|
|
@ -141,7 +141,7 @@ class ResponseDeserializerGenerator(
|
|||
} else {
|
||||
#{parse_response}(status, headers, body)
|
||||
};
|
||||
#{type_erase_result}(parse_result).into()
|
||||
#{type_erase_result}(parse_result)
|
||||
""",
|
||||
*codegenScope,
|
||||
"parse_error" to parserGenerator.parseErrorFn(operationShape, customizations),
|
||||
|
|
|
@ -47,7 +47,7 @@ internal class ClientInstantiatorTest {
|
|||
@Test
|
||||
fun `generate named enums`() {
|
||||
val shape = model.lookup<StringShape>("com.test#NamedEnum")
|
||||
val sut = clientInstantiator(codegenContext)
|
||||
val sut = ClientInstantiator(codegenContext)
|
||||
val data = Node.parse("t2.nano".dq())
|
||||
|
||||
val project = TestWorkspace.testProject(symbolProvider)
|
||||
|
@ -66,7 +66,7 @@ internal class ClientInstantiatorTest {
|
|||
@Test
|
||||
fun `generate unnamed enums`() {
|
||||
val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
|
||||
val sut = clientInstantiator(codegenContext)
|
||||
val sut = ClientInstantiator(codegenContext)
|
||||
val data = Node.parse("t2.nano".dq())
|
||||
|
||||
val project = TestWorkspace.testProject(symbolProvider)
|
||||
|
|
|
@ -333,6 +333,23 @@ open class Instantiator(
|
|||
* ```
|
||||
*/
|
||||
private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, headers: Map<String, String>, ctx: Ctx) {
|
||||
writer.rust("#T::builder()", symbolProvider.toSymbol(shape))
|
||||
|
||||
renderStructureMembers(writer, shape, data, headers, ctx)
|
||||
|
||||
writer.rust(".build()")
|
||||
if (builderKindBehavior.hasFallibleBuilder(shape)) {
|
||||
writer.rust(".unwrap()")
|
||||
}
|
||||
}
|
||||
|
||||
protected fun renderStructureMembers(
|
||||
writer: RustWriter,
|
||||
shape: StructureShape,
|
||||
data: ObjectNode,
|
||||
headers: Map<String, String>,
|
||||
ctx: Ctx,
|
||||
) {
|
||||
fun renderMemberHelper(memberShape: MemberShape, value: Node) {
|
||||
val setterName = builderKindBehavior.setterName(memberShape)
|
||||
writer.withBlock(".$setterName(", ")") {
|
||||
|
@ -340,7 +357,6 @@ open class Instantiator(
|
|||
}
|
||||
}
|
||||
|
||||
writer.rust("#T::builder()", symbolProvider.toSymbol(shape))
|
||||
if (defaultsForRequiredFields) {
|
||||
shape.allMembers.entries
|
||||
.filter { (name, memberShape) ->
|
||||
|
@ -374,11 +390,6 @@ open class Instantiator(
|
|||
?.let {
|
||||
renderMemberHelper(it.value, fillDefaultValue(model.expectShape(it.value.target)))
|
||||
}
|
||||
|
||||
writer.rust(".build()")
|
||||
if (builderKindBehavior.hasFallibleBuilder(shape)) {
|
||||
writer.rust(".unwrap()")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::client::identity::{Identity, IdentityResolver, IdentityResolvers};
|
|||
use crate::client::orchestrator::{BoxError, HttpRequest};
|
||||
use crate::config_bag::ConfigBag;
|
||||
use crate::type_erasure::{TypeErasedBox, TypedBox};
|
||||
use aws_smithy_types::Document;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
@ -34,6 +35,12 @@ impl AuthSchemeId {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<&'static str> for AuthSchemeId {
|
||||
fn from(scheme_id: &'static str) -> Self {
|
||||
Self::new(scheme_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthOptionResolverParams(TypeErasedBox);
|
||||
|
||||
|
@ -105,10 +112,34 @@ pub trait HttpRequestSigner: Send + Sync + fmt::Debug {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
identity: &Identity,
|
||||
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError>;
|
||||
}
|
||||
|
||||
/// Endpoint configuration for the selected auth scheme.
|
||||
///
|
||||
/// This struct gets added to the request state by the auth orchestrator.
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthSchemeEndpointConfig<'a>(Option<&'a Document>);
|
||||
|
||||
impl<'a> AuthSchemeEndpointConfig<'a> {
|
||||
/// Creates a new [`AuthSchemeEndpointConfig`].
|
||||
pub fn new(config: Option<&'a Document>) -> Self {
|
||||
Self(config)
|
||||
}
|
||||
|
||||
/// Creates an empty AuthSchemeEndpointConfig.
|
||||
pub fn empty() -> Self {
|
||||
Self(None)
|
||||
}
|
||||
|
||||
pub fn config(&self) -> Option<&'a Document> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub mod builders {
|
||||
use super::*;
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ use crate::type_erasure::{TypeErasedBox, TypedBox};
|
|||
use aws_smithy_async::future::now_or_later::NowOrLater;
|
||||
use aws_smithy_async::rt::sleep::AsyncSleep;
|
||||
use aws_smithy_http::body::SdkBody;
|
||||
use aws_smithy_http::endpoint::EndpointPrefix;
|
||||
use aws_smithy_types::endpoint::Endpoint;
|
||||
use std::fmt;
|
||||
use std::future::Future as StdFuture;
|
||||
use std::pin::Pin;
|
||||
|
@ -74,12 +74,7 @@ impl EndpointResolverParams {
|
|||
}
|
||||
|
||||
pub trait EndpointResolver: Send + Sync + fmt::Debug {
|
||||
fn resolve_and_apply_endpoint(
|
||||
&self,
|
||||
params: &EndpointResolverParams,
|
||||
endpoint_prefix: Option<&EndpointPrefix>,
|
||||
request: &mut HttpRequest,
|
||||
) -> Result<(), BoxError>;
|
||||
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Result<Endpoint, BoxError>;
|
||||
}
|
||||
|
||||
/// Time that the request is being made (so that time can be overridden in the [`ConfigBag`]).
|
||||
|
|
|
@ -8,7 +8,9 @@ use aws_smithy_runtime_api::client::auth::http::{
|
|||
HTTP_API_KEY_AUTH_SCHEME_ID, HTTP_BASIC_AUTH_SCHEME_ID, HTTP_BEARER_AUTH_SCHEME_ID,
|
||||
HTTP_DIGEST_AUTH_SCHEME_ID,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
|
||||
use aws_smithy_runtime_api::client::auth::{
|
||||
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::identity::http::{Login, Token};
|
||||
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
|
||||
use aws_smithy_runtime_api::client::orchestrator::{BoxError, HttpRequest};
|
||||
|
@ -76,6 +78,7 @@ impl HttpRequestSigner for ApiKeySigner {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
let api_key = identity
|
||||
|
@ -141,6 +144,7 @@ impl HttpRequestSigner for BasicAuthSigner {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
let login = identity
|
||||
|
@ -198,6 +202,7 @@ impl HttpRequestSigner for BearerAuthSigner {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
let token = identity
|
||||
|
@ -253,6 +258,7 @@ impl HttpRequestSigner for DigestAuthSigner {
|
|||
&self,
|
||||
_request: &mut HttpRequest,
|
||||
_identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
unimplemented!(
|
||||
|
@ -281,7 +287,12 @@ mod tests {
|
|||
.body(SdkBody::empty())
|
||||
.unwrap();
|
||||
signer
|
||||
.sign_request(&mut request, &identity, &config_bag)
|
||||
.sign_request(
|
||||
&mut request,
|
||||
&identity,
|
||||
AuthSchemeEndpointConfig::empty(),
|
||||
&config_bag,
|
||||
)
|
||||
.expect("success");
|
||||
assert_eq!(
|
||||
"SomeSchemeName some-token",
|
||||
|
@ -304,7 +315,12 @@ mod tests {
|
|||
.body(SdkBody::empty())
|
||||
.unwrap();
|
||||
signer
|
||||
.sign_request(&mut request, &identity, &config_bag)
|
||||
.sign_request(
|
||||
&mut request,
|
||||
&identity,
|
||||
AuthSchemeEndpointConfig::empty(),
|
||||
&config_bag,
|
||||
)
|
||||
.expect("success");
|
||||
assert!(request.headers().get("some-query-name").is_none());
|
||||
assert_eq!(
|
||||
|
@ -321,7 +337,12 @@ mod tests {
|
|||
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
|
||||
|
||||
signer
|
||||
.sign_request(&mut request, &identity, &config_bag)
|
||||
.sign_request(
|
||||
&mut request,
|
||||
&identity,
|
||||
AuthSchemeEndpointConfig::empty(),
|
||||
&config_bag,
|
||||
)
|
||||
.expect("success");
|
||||
assert_eq!(
|
||||
"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
|
||||
|
@ -337,7 +358,12 @@ mod tests {
|
|||
let identity = Identity::new(Token::new("some-token", None), None);
|
||||
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
|
||||
signer
|
||||
.sign_request(&mut request, &identity, &config_bag)
|
||||
.sign_request(
|
||||
&mut request,
|
||||
&identity,
|
||||
AuthSchemeEndpointConfig::empty(),
|
||||
&config_bag,
|
||||
)
|
||||
.expect("success");
|
||||
assert_eq!(
|
||||
"Bearer some-token",
|
||||
|
|
|
@ -3,11 +3,62 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use aws_smithy_runtime_api::client::auth::{AuthSchemeEndpointConfig, AuthSchemeId};
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors};
|
||||
use aws_smithy_runtime_api::config_bag::ConfigBag;
|
||||
use aws_smithy_types::endpoint::Endpoint;
|
||||
use aws_smithy_types::Document;
|
||||
use std::borrow::Cow;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AuthOrchestrationError {
|
||||
NoMatchingAuthScheme,
|
||||
BadAuthSchemeEndpointConfig(Cow<'static, str>),
|
||||
AuthSchemeEndpointConfigMismatch(String),
|
||||
}
|
||||
|
||||
impl AuthOrchestrationError {
|
||||
fn auth_scheme_endpoint_config_mismatch<'a>(
|
||||
auth_schemes: impl Iterator<Item = &'a Document>,
|
||||
) -> Self {
|
||||
Self::AuthSchemeEndpointConfigMismatch(
|
||||
auth_schemes
|
||||
.flat_map(|s| match s {
|
||||
Document::Object(map) => match map.get("name") {
|
||||
Some(Document::String(name)) => Some(name.as_str()),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthOrchestrationError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::NoMatchingAuthScheme => f.write_str(
|
||||
"no auth scheme matched auth options. This is a bug. Please file an issue.",
|
||||
),
|
||||
Self::BadAuthSchemeEndpointConfig(message) => f.write_str(message),
|
||||
Self::AuthSchemeEndpointConfigMismatch(supported_schemes) => {
|
||||
write!(f,
|
||||
"selected auth scheme / endpoint config mismatch. Couldn't find `sigv4` endpoint config for this endpoint. \
|
||||
The authentication schemes supported by this endpoint are: {:?}",
|
||||
supported_schemes
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for AuthOrchestrationError {}
|
||||
|
||||
pub(super) async fn orchestrate_auth(
|
||||
ctx: &mut InterceptorContext,
|
||||
cfg: &ConfigBag,
|
||||
|
@ -22,36 +73,62 @@ pub(super) async fn orchestrate_auth(
|
|||
identity_resolvers = ?identity_resolvers,
|
||||
"orchestrating auth",
|
||||
);
|
||||
|
||||
for &scheme_id in auth_options.as_ref() {
|
||||
if let Some(auth_scheme) = cfg.http_auth_schemes().scheme(scheme_id) {
|
||||
if let Some(identity_resolver) = auth_scheme.identity_resolver(identity_resolvers) {
|
||||
let request_signer = auth_scheme.request_signer();
|
||||
let endpoint = cfg
|
||||
.get::<Endpoint>()
|
||||
.expect("endpoint added to config bag by endpoint orchestrator");
|
||||
let auth_scheme_endpoint_config =
|
||||
extract_endpoint_auth_scheme_config(endpoint, scheme_id)?;
|
||||
|
||||
let identity = identity_resolver.resolve_identity(cfg).await?;
|
||||
let request = ctx.request_mut();
|
||||
request_signer.sign_request(request, &identity, cfg)?;
|
||||
request_signer.sign_request(
|
||||
request,
|
||||
&identity,
|
||||
auth_scheme_endpoint_config,
|
||||
cfg,
|
||||
)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(NoMatchingAuthScheme.into())
|
||||
Err(AuthOrchestrationError::NoMatchingAuthScheme.into())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NoMatchingAuthScheme;
|
||||
|
||||
impl fmt::Display for NoMatchingAuthScheme {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"no auth scheme matched auth options. This is a bug. Please file an issue."
|
||||
)
|
||||
}
|
||||
fn extract_endpoint_auth_scheme_config(
|
||||
endpoint: &Endpoint,
|
||||
scheme_id: AuthSchemeId,
|
||||
) -> Result<AuthSchemeEndpointConfig<'_>, AuthOrchestrationError> {
|
||||
let auth_schemes = match endpoint.properties().get("authSchemes") {
|
||||
Some(Document::Array(schemes)) => schemes,
|
||||
// no auth schemes:
|
||||
None => return Ok(AuthSchemeEndpointConfig::new(None)),
|
||||
_other => {
|
||||
return Err(AuthOrchestrationError::BadAuthSchemeEndpointConfig(
|
||||
"expected an array for `authSchemes` in endpoint config".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
let auth_scheme_config = auth_schemes
|
||||
.iter()
|
||||
.find(|doc| {
|
||||
let config_scheme_id = doc
|
||||
.as_object()
|
||||
.and_then(|object| object.get("name"))
|
||||
.and_then(Document::as_string);
|
||||
config_scheme_id == Some(scheme_id.as_str())
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
AuthOrchestrationError::auth_scheme_endpoint_config_mismatch(auth_schemes.iter())
|
||||
})?;
|
||||
Ok(AuthSchemeEndpointConfig::new(Some(auth_scheme_config)))
|
||||
}
|
||||
|
||||
impl std::error::Error for NoMatchingAuthScheme {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -64,6 +141,7 @@ mod tests {
|
|||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::{Future, HttpRequest};
|
||||
use aws_smithy_runtime_api::type_erasure::TypedBox;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic_case() {
|
||||
|
@ -83,6 +161,7 @@ mod tests {
|
|||
&self,
|
||||
request: &mut HttpRequest,
|
||||
_identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
request
|
||||
|
@ -134,6 +213,7 @@ mod tests {
|
|||
.auth_scheme(TEST_SCHEME_ID, TestAuthScheme { signer: TestSigner })
|
||||
.build(),
|
||||
);
|
||||
cfg.put(Endpoint::builder().url("dontcare").build());
|
||||
|
||||
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
|
||||
|
||||
|
@ -170,6 +250,7 @@ mod tests {
|
|||
.auth_scheme(HTTP_BEARER_AUTH_SCHEME_ID, BearerAuthScheme::new())
|
||||
.build(),
|
||||
);
|
||||
cfg.put(Endpoint::builder().url("dontcare").build());
|
||||
|
||||
// First, test the presence of a basic auth login and absence of a bearer token
|
||||
cfg.set_identity_resolvers(
|
||||
|
@ -203,4 +284,92 @@ mod tests {
|
|||
ctx.request().headers().get("Authorization").unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_endpoint_auth_scheme_config_no_config() {
|
||||
let endpoint = Endpoint::builder()
|
||||
.url("dontcare")
|
||||
.property("something-unrelated", Document::Null)
|
||||
.build();
|
||||
let config = extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
|
||||
.expect("success");
|
||||
assert!(config.config().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_endpoint_auth_scheme_config_wrong_type() {
|
||||
let endpoint = Endpoint::builder()
|
||||
.url("dontcare")
|
||||
.property("authSchemes", Document::String("bad".into()))
|
||||
.build();
|
||||
extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
|
||||
.expect_err("should fail because authSchemes is the wrong type");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_endpoint_auth_scheme_config_no_matching_scheme() {
|
||||
let endpoint = Endpoint::builder()
|
||||
.url("dontcare")
|
||||
.property(
|
||||
"authSchemes",
|
||||
vec![
|
||||
Document::Object({
|
||||
let mut out = HashMap::new();
|
||||
out.insert("name".to_string(), "wrong-scheme-id".to_string().into());
|
||||
out
|
||||
}),
|
||||
Document::Object({
|
||||
let mut out = HashMap::new();
|
||||
out.insert(
|
||||
"name".to_string(),
|
||||
"another-wrong-scheme-id".to_string().into(),
|
||||
);
|
||||
out
|
||||
}),
|
||||
],
|
||||
)
|
||||
.build();
|
||||
extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
|
||||
.expect_err("should fail because authSchemes doesn't include the desired scheme");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_endpoint_auth_scheme_config_successfully() {
|
||||
let endpoint = Endpoint::builder()
|
||||
.url("dontcare")
|
||||
.property(
|
||||
"authSchemes",
|
||||
vec![
|
||||
Document::Object({
|
||||
let mut out = HashMap::new();
|
||||
out.insert("name".to_string(), "wrong-scheme-id".to_string().into());
|
||||
out
|
||||
}),
|
||||
Document::Object({
|
||||
let mut out = HashMap::new();
|
||||
out.insert("name".to_string(), "test-scheme-id".to_string().into());
|
||||
out.insert(
|
||||
"magicString".to_string(),
|
||||
"magic string value".to_string().into(),
|
||||
);
|
||||
out
|
||||
}),
|
||||
],
|
||||
)
|
||||
.build();
|
||||
let config = extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
|
||||
.expect("should find test-scheme-id");
|
||||
assert_eq!(
|
||||
"magic string value",
|
||||
config
|
||||
.config()
|
||||
.expect("config is set")
|
||||
.as_object()
|
||||
.expect("it's an object")
|
||||
.get("magicString")
|
||||
.expect("magicString is set")
|
||||
.as_string()
|
||||
.expect("gimme the string, dammit!")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,15 @@
|
|||
|
||||
use aws_smithy_http::endpoint::error::ResolveEndpointError;
|
||||
use aws_smithy_http::endpoint::{
|
||||
apply_endpoint, EndpointPrefix, ResolveEndpoint, SharedEndpointResolver,
|
||||
apply_endpoint as apply_endpoint_to_request_uri, EndpointPrefix, ResolveEndpoint,
|
||||
SharedEndpointResolver,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
|
||||
use aws_smithy_runtime_api::client::orchestrator::{
|
||||
BoxError, ConfigBagAccessors, EndpointResolver, EndpointResolverParams, HttpRequest,
|
||||
};
|
||||
use aws_smithy_runtime_api::config_bag::ConfigBag;
|
||||
use aws_smithy_types::endpoint::Endpoint;
|
||||
use http::header::HeaderName;
|
||||
use http::{HeaderValue, Uri};
|
||||
use std::fmt::Debug;
|
||||
|
@ -36,14 +38,8 @@ impl StaticUriEndpointResolver {
|
|||
}
|
||||
|
||||
impl EndpointResolver for StaticUriEndpointResolver {
|
||||
fn resolve_and_apply_endpoint(
|
||||
&self,
|
||||
_params: &EndpointResolverParams,
|
||||
_endpoint_prefix: Option<&EndpointPrefix>,
|
||||
request: &mut HttpRequest,
|
||||
) -> Result<(), BoxError> {
|
||||
apply_endpoint(request.uri_mut(), &self.endpoint, None)?;
|
||||
Ok(())
|
||||
fn resolve_endpoint(&self, _params: &EndpointResolverParams) -> Result<Endpoint, BoxError> {
|
||||
Ok(Endpoint::builder().url(self.endpoint.to_string()).build())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,63 +77,64 @@ impl<Params> EndpointResolver for DefaultEndpointResolver<Params>
|
|||
where
|
||||
Params: Debug + Send + Sync + 'static,
|
||||
{
|
||||
fn resolve_and_apply_endpoint(
|
||||
&self,
|
||||
params: &EndpointResolverParams,
|
||||
endpoint_prefix: Option<&EndpointPrefix>,
|
||||
request: &mut HttpRequest,
|
||||
) -> Result<(), BoxError> {
|
||||
let endpoint = match params.get::<Params>() {
|
||||
Some(params) => self.inner.resolve_endpoint(params)?,
|
||||
None => {
|
||||
return Err(Box::new(ResolveEndpointError::message(
|
||||
"params of expected type was not present",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let uri: Uri = endpoint.url().parse().map_err(|err| {
|
||||
ResolveEndpointError::from_source("endpoint did not have a valid uri", err)
|
||||
})?;
|
||||
|
||||
apply_endpoint(request.uri_mut(), &uri, endpoint_prefix).map_err(|err| {
|
||||
ResolveEndpointError::message(format!(
|
||||
"failed to apply endpoint `{:?}` to request `{:?}`",
|
||||
uri, request,
|
||||
))
|
||||
.with_source(Some(err.into()))
|
||||
})?;
|
||||
|
||||
for (header_name, header_values) in endpoint.headers() {
|
||||
request.headers_mut().remove(header_name);
|
||||
for value in header_values {
|
||||
request.headers_mut().insert(
|
||||
HeaderName::from_str(header_name).map_err(|err| {
|
||||
ResolveEndpointError::message("invalid header name")
|
||||
.with_source(Some(err.into()))
|
||||
})?,
|
||||
HeaderValue::from_str(value).map_err(|err| {
|
||||
ResolveEndpointError::message("invalid header value")
|
||||
.with_source(Some(err.into()))
|
||||
})?,
|
||||
);
|
||||
}
|
||||
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Result<Endpoint, BoxError> {
|
||||
match params.get::<Params>() {
|
||||
Some(params) => Ok(self.inner.resolve_endpoint(params)?),
|
||||
None => Err(Box::new(ResolveEndpointError::message(
|
||||
"params of expected type was not present",
|
||||
))),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn orchestrate_endpoint(
|
||||
ctx: &mut InterceptorContext,
|
||||
cfg: &ConfigBag,
|
||||
cfg: &mut ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
let params = cfg.endpoint_resolver_params();
|
||||
let endpoint_prefix = cfg.get::<EndpointPrefix>();
|
||||
let request = ctx.request_mut();
|
||||
|
||||
let endpoint_resolver = cfg.endpoint_resolver();
|
||||
endpoint_resolver.resolve_and_apply_endpoint(params, endpoint_prefix, request)?;
|
||||
let endpoint = endpoint_resolver.resolve_endpoint(params)?;
|
||||
apply_endpoint(request, &endpoint, endpoint_prefix)?;
|
||||
|
||||
// Make the endpoint config available to interceptors
|
||||
cfg.put(endpoint);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_endpoint(
|
||||
request: &mut HttpRequest,
|
||||
endpoint: &Endpoint,
|
||||
endpoint_prefix: Option<&EndpointPrefix>,
|
||||
) -> Result<(), BoxError> {
|
||||
let uri: Uri = endpoint.url().parse().map_err(|err| {
|
||||
ResolveEndpointError::from_source("endpoint did not have a valid uri", err)
|
||||
})?;
|
||||
|
||||
apply_endpoint_to_request_uri(request.uri_mut(), &uri, endpoint_prefix).map_err(|err| {
|
||||
ResolveEndpointError::message(format!(
|
||||
"failed to apply endpoint `{:?}` to request `{:?}`",
|
||||
uri, request,
|
||||
))
|
||||
.with_source(Some(err.into()))
|
||||
})?;
|
||||
|
||||
for (header_name, header_values) in endpoint.headers() {
|
||||
request.headers_mut().remove(header_name);
|
||||
for value in header_values {
|
||||
request.headers_mut().insert(
|
||||
HeaderName::from_str(header_name).map_err(|err| {
|
||||
ResolveEndpointError::message("invalid header name")
|
||||
.with_source(Some(err.into()))
|
||||
})?,
|
||||
HeaderValue::from_str(value).map_err(|err| {
|
||||
ResolveEndpointError::message("invalid header value")
|
||||
.with_source(Some(err.into()))
|
||||
})?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use aws_smithy_runtime_api::client::auth::option_resolver::{
|
|||
StaticAuthOptionResolver, StaticAuthOptionResolverParams,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::auth::{
|
||||
AuthSchemeId, HttpAuthScheme, HttpAuthSchemes, HttpRequestSigner,
|
||||
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpAuthSchemes, HttpRequestSigner,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
|
||||
use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar;
|
||||
|
@ -83,6 +83,7 @@ impl HttpRequestSigner for AnonymousSigner {
|
|||
&self,
|
||||
_request: &mut HttpRequest,
|
||||
_identity: &Identity,
|
||||
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
|
||||
_config_bag: &ConfigBag,
|
||||
) -> Result<(), BoxError> {
|
||||
Ok(())
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
*/
|
||||
|
||||
use crate::Number;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/* ANCHOR: document */
|
||||
|
@ -14,7 +15,7 @@ use std::collections::HashMap;
|
|||
/// Open content is useful for modeling unstructured data that has no schema, data that can't be
|
||||
/// modeled using rigid types, or data that has a schema that evolves outside of the purview of a model.
|
||||
/// The serialization format of a document is an implementation detail of a protocol.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Document {
|
||||
/// JSON object
|
||||
Object(HashMap<String, Document>),
|
||||
|
@ -30,12 +31,135 @@ pub enum Document {
|
|||
Null,
|
||||
}
|
||||
|
||||
impl Document {
|
||||
/// Returns the inner map value if this `Document` is an object.
|
||||
pub fn as_object(&self) -> Option<&HashMap<String, Document>> {
|
||||
if let Self::Object(object) = self {
|
||||
Some(object)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the mutable inner map value if this `Document` is an object.
|
||||
pub fn as_object_mut(&mut self) -> Option<&mut HashMap<String, Document>> {
|
||||
if let Self::Object(object) = self {
|
||||
Some(object)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner array value if this `Document` is an array.
|
||||
pub fn as_array(&self) -> Option<&Vec<Document>> {
|
||||
if let Self::Array(array) = self {
|
||||
Some(array)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the mutable inner array value if this `Document` is an array.
|
||||
pub fn as_array_mut(&mut self) -> Option<&mut Vec<Document>> {
|
||||
if let Self::Array(array) = self {
|
||||
Some(array)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner number value if this `Document` is a number.
|
||||
pub fn as_number(&self) -> Option<&Number> {
|
||||
if let Self::Number(number) = self {
|
||||
Some(number)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner string value if this `Document` is a string.
|
||||
pub fn as_string(&self) -> Option<&str> {
|
||||
if let Self::String(string) = self {
|
||||
Some(string)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner boolean value if this `Document` is a boolean.
|
||||
pub fn as_bool(&self) -> Option<bool> {
|
||||
if let Self::Bool(boolean) = self {
|
||||
Some(*boolean)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `Some(())` if this `Document` is a null.
|
||||
pub fn as_null(&self) -> Option<()> {
|
||||
if let Self::Null = self {
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is an object.
|
||||
pub fn is_object(&self) -> bool {
|
||||
matches!(self, Self::Object(_))
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is an array.
|
||||
pub fn is_array(&self) -> bool {
|
||||
matches!(self, Self::Array(_))
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is a number.
|
||||
pub fn is_number(&self) -> bool {
|
||||
matches!(self, Self::Number(_))
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is a string.
|
||||
pub fn is_string(&self) -> bool {
|
||||
matches!(self, Self::String(_))
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is a bool.
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, Self::Bool(_))
|
||||
}
|
||||
|
||||
/// Returns `true` if this `Document` is a boolean.
|
||||
pub fn is_null(&self) -> bool {
|
||||
matches!(self, Self::Null)
|
||||
}
|
||||
}
|
||||
|
||||
/// The default value is `Document::Null`.
|
||||
impl Default for Document {
|
||||
fn default() -> Self {
|
||||
Self::Null
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bool> for Document {
|
||||
fn from(value: bool) -> Self {
|
||||
Document::Bool(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Document {
|
||||
fn from(value: &'a str) -> Self {
|
||||
Document::String(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<Cow<'a, str>> for Document {
|
||||
fn from(value: Cow<'a, str>) -> Self {
|
||||
Document::String(value.into_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Document {
|
||||
fn from(value: String) -> Self {
|
||||
Document::String(value)
|
||||
|
@ -71,3 +195,15 @@ impl From<i32> for Document {
|
|||
Document::Number(Number::NegInt(value as i64))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f64> for Document {
|
||||
fn from(value: f64) -> Self {
|
||||
Document::Number(Number::Float(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Number> for Document {
|
||||
fn from(value: Number) -> Self {
|
||||
Document::Number(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,17 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
set -eux
|
||||
C_YELLOW='\033[1;33m'
|
||||
C_RESET='\033[0m'
|
||||
|
||||
set -eu
|
||||
cd smithy-rs
|
||||
./gradlew aws:sdk-adhoc-test:test
|
||||
|
||||
# TODO(enableNewSmithyRuntime): Remove the middleware test run when cleaning up middleware
|
||||
echo -e "## ${C_YELLOW}Running SDK adhoc tests against the middleware implementation...${C_RESET}"
|
||||
./gradlew aws:sdk-adhoc-test:clean
|
||||
./gradlew aws:sdk-adhoc-test:check -Psmithy.runtime.mode=middleware
|
||||
|
||||
echo -e "## ${C_YELLOW}Running SDK adhoc tests against the orchestrator implementation...${C_RESET}"
|
||||
./gradlew aws:sdk-adhoc-test:clean
|
||||
./gradlew aws:sdk-adhoc-test:check -Psmithy.runtime.mode=orchestrator
|
||||
|
|
Loading…
Reference in New Issue