Make the SDK ad hoc tests pass against the orchestrator (#2708)

## Motivation and Context
This PR refactors the client protocol test generator machinery to use a
client instead of calling `make_operation` directly, and then fixes the
ad hoc tests for the orchestrator.

The ad hoc tests revealed that overriding the signing region/service via
endpoint config was lost when porting SigV4 signing to the orchestrator,
so this PR updates the SigV4 `HttpRequestSigner` implementation to
restore this functionality. It is doing this in the signer directly
rather than via an interceptor since it should only run this logic when
SigV4 is the selected auth scheme.

Other notable changes:
- Adds `--no-fail-fast` arg to `cargoTest` targets so that all Rust
tests run in CI rather than stopping on the first failure
- Changes `EndpointResolver::resolve_and_apply_endpoint` to just
`resolve_endpoint` so that the orchestrator can place the endpoint
config into the request state, which is required for the signer to make
use of it
- Adds a `set_region` method to SDK service configs
- Deletes the API Gateway model and integration test from the SDK smoke
test since it is covered by the ad hoc tests
- Adds a comment explaining where the API Gateway model comes from in
the ad hoc tests
- Adds a `smithy.runtime.mode` Gradle property to `aws:sdk` and
`aws:sdk-adhoc-test` to trivially switch between middleware and
orchestrator when testing/generating locally

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2023-05-19 09:58:28 -07:00 committed by GitHub
parent 41774b8405
commit 9bfe936fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 924 additions and 13121 deletions

View File

@ -11,12 +11,18 @@ pub mod sigv4 {
SignableRequest, SignatureLocation, SigningParams, SigningSettings,
UriPathNormalizationMode,
};
use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_types::region::SigningRegion;
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion};
use aws_types::SigningService;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::time::{Duration, SystemTime};
const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \
@ -25,6 +31,53 @@ pub mod sigv4 {
/// Auth scheme ID for SigV4.
pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4");
struct EndpointAuthSchemeConfig {
signing_region_override: Option<SigningRegion>,
signing_service_override: Option<SigningService>,
}
#[derive(Debug)]
enum SigV4SigningError {
MissingOperationSigningConfig,
MissingSigningRegion,
MissingSigningService,
WrongIdentityType(Identity),
BadTypeInEndpointAuthSchemeConfig(&'static str),
}
impl fmt::Display for SigV4SigningError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SigV4SigningError::*;
let mut w = |s| f.write_str(s);
match self {
MissingOperationSigningConfig => w("missing operation signing config for SigV4"),
MissingSigningRegion => w("missing signing region for SigV4 signing"),
MissingSigningService => w("missing signing service for SigV4 signing"),
WrongIdentityType(identity) => {
write!(f, "wrong identity type for SigV4: {identity:?}")
}
BadTypeInEndpointAuthSchemeConfig(field_name) => {
write!(
f,
"unexpected type for `{field_name}` in endpoint auth scheme config",
)
}
}
}
}
impl StdError for SigV4SigningError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Self::MissingOperationSigningConfig => None,
Self::MissingSigningRegion => None,
Self::MissingSigningService => None,
Self::WrongIdentityType(_) => None,
Self::BadTypeInEndpointAuthSchemeConfig(_) => None,
}
}
}
/// SigV4 auth scheme.
#[derive(Debug, Default)]
pub struct SigV4HttpAuthScheme {
@ -111,9 +164,9 @@ pub mod sigv4 {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SigV4OperationSigningConfig {
/// AWS Region to sign for.
pub region: SigningRegion,
pub region: Option<SigningRegion>,
/// AWS Service to sign for.
pub service: SigningService,
pub service: Option<SigningService>,
/// Signing options.
pub signing_options: SigningOptions,
}
@ -165,7 +218,7 @@ pub mod sigv4 {
credentials: &'a Credentials,
operation_config: &'a SigV4OperationSigningConfig,
request_timestamp: SystemTime,
) -> SigningParams<'a> {
) -> Result<SigningParams<'a>, SigV4SigningError> {
if let Some(expires_in) = settings.expires_in {
if let Some(creds_expires_time) = credentials.expiry() {
let presigned_expires_time = request_timestamp + expires_in;
@ -178,12 +231,75 @@ pub mod sigv4 {
let mut builder = SigningParams::builder()
.access_key(credentials.access_key_id())
.secret_key(credentials.secret_access_key())
.region(operation_config.region.as_ref())
.service_name(operation_config.service.as_ref())
.region(
operation_config
.region
.as_ref()
.ok_or(SigV4SigningError::MissingSigningRegion)?
.as_ref(),
)
.service_name(
operation_config
.service
.as_ref()
.ok_or(SigV4SigningError::MissingSigningService)?
.as_ref(),
)
.time(request_timestamp)
.settings(settings);
builder.set_security_token(credentials.session_token());
builder.build().expect("all required fields set")
Ok(builder.build().expect("all required fields set"))
}
fn extract_operation_config<'a>(
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'a>,
config_bag: &'a ConfigBag,
) -> Result<Cow<'a, SigV4OperationSigningConfig>, SigV4SigningError> {
let operation_config = config_bag
.get::<SigV4OperationSigningConfig>()
.ok_or(SigV4SigningError::MissingOperationSigningConfig)?;
let EndpointAuthSchemeConfig {
signing_region_override,
signing_service_override,
} = Self::extract_endpoint_auth_scheme_config(auth_scheme_endpoint_config)?;
match (signing_region_override, signing_service_override) {
(None, None) => Ok(Cow::Borrowed(operation_config)),
(region, service) => {
let mut operation_config = operation_config.clone();
if region.is_some() {
operation_config.region = region;
}
if service.is_some() {
operation_config.service = service;
}
Ok(Cow::Owned(operation_config))
}
}
}
fn extract_endpoint_auth_scheme_config(
endpoint_config: AuthSchemeEndpointConfig<'_>,
) -> Result<EndpointAuthSchemeConfig, SigV4SigningError> {
let (mut signing_region_override, mut signing_service_override) = (None, None);
if let Some(config) = endpoint_config.config().and_then(Document::as_object) {
use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType;
signing_region_override = match config.get("signingRegion") {
Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
None => None,
_ => return Err(UnexpectedType("signingRegion")),
};
signing_service_override = match config.get("signingName") {
Some(Document::String(s)) => Some(SigningService::from(s.to_string())),
None => None,
_ => return Err(UnexpectedType("signingName")),
};
}
Ok(EndpointAuthSchemeConfig {
signing_region_override,
signing_service_override,
})
}
}
@ -192,11 +308,11 @@ pub mod sigv4 {
&self,
request: &mut HttpRequest,
identity: &Identity,
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let operation_config = config_bag
.get::<SigV4OperationSigningConfig>()
.ok_or("missing operation signing config for SigV4")?;
let operation_config =
Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?;
let request_time = config_bag.request_time().unwrap_or_default().system_time();
let credentials = if let Some(creds) = identity.data::<Credentials>() {
@ -205,12 +321,12 @@ pub mod sigv4 {
tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials");
return Ok(());
} else {
return Err(format!("wrong identity type for SigV4: {identity:?}").into());
return Err(SigV4SigningError::WrongIdentityType(identity.clone()).into());
};
let settings = Self::settings(operation_config);
let settings = Self::settings(&operation_config);
let signing_params =
Self::signing_params(settings, credentials, operation_config, request_time);
Self::signing_params(settings, credentials, &operation_config, request_time)?;
let (signing_instructions, _signature) = {
// A body that is already in memory can be signed directly. A body that is not in memory
@ -250,6 +366,9 @@ pub mod sigv4 {
use super::*;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::SigningSettings;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use tracing_test::traced_test;
@ -270,8 +389,8 @@ pub mod sigv4 {
"test",
);
let operation_config = SigV4OperationSigningConfig {
region: SigningRegion::from_static("test"),
service: SigningService::from_static("test"),
region: Some(SigningRegion::from_static("test")),
service: Some(SigningService::from_static("test")),
signing_options: SigningOptions {
double_uri_encode: true,
content_sha256_header: true,
@ -283,14 +402,74 @@ pub mod sigv4 {
payload_override: None,
},
};
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
.unwrap();
assert!(!logs_contain(EXPIRATION_WARNING));
let mut settings = SigningSettings::default();
settings.expires_in = Some(creds_expire_in + Duration::from_secs(10));
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
.unwrap();
assert!(logs_contain(EXPIRATION_WARNING));
}
#[test]
fn endpoint_config_overrides_region_and_service() {
let mut cfg = ConfigBag::base();
cfg.put(SigV4OperationSigningConfig {
region: Some(SigningRegion::from(Region::new("override-this-region"))),
service: Some(SigningService::from_static("override-this-service")),
signing_options: Default::default(),
});
let config = Document::Object({
let mut out = HashMap::new();
out.insert("name".to_string(), "sigv4".to_string().into());
out.insert(
"signingName".to_string(),
"qldb-override".to_string().into(),
);
out.insert(
"signingRegion".to_string(),
"us-east-override".to_string().into(),
);
out
});
let config = AuthSchemeEndpointConfig::new(Some(&config));
let result =
SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");
assert_eq!(
result.region,
Some(SigningRegion::from(Region::new("us-east-override")))
);
assert_eq!(
result.service,
Some(SigningService::from_static("qldb-override"))
);
assert!(matches!(result, Cow::Owned(_)));
}
#[test]
fn endpoint_config_supports_fallback_when_region_or_service_are_unset() {
let mut cfg = ConfigBag::base();
cfg.put(SigV4OperationSigningConfig {
region: Some(SigningRegion::from(Region::new("us-east-1"))),
service: Some(SigningService::from_static("qldb")),
signing_options: Default::default(),
});
let config = AuthSchemeEndpointConfig::empty();
let result =
SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");
assert_eq!(
result.region,
Some(SigningRegion::from(Region::new("us-east-1")))
);
assert_eq!(result.service, Some(SigningService::from_static("qldb")));
assert!(matches!(result, Cow::Borrowed(_)));
}
}
}

View File

@ -37,6 +37,8 @@ dependencies {
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
}
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"
val allCodegenTests = listOf(
CodegenTest(
"com.amazonaws.apigateway#BackplaneControlService",
@ -46,6 +48,7 @@ val allCodegenTests = listOf(
,
"codegen": {
"includeFluentClient": false,
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
},
"customizationConfig": {
"awsSdk": {
@ -62,6 +65,7 @@ val allCodegenTests = listOf(
,
"codegen": {
"includeFluentClient": false,
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
},
"customizationConfig": {
"awsSdk": {

View File

@ -1,8 +1,16 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// The API Gateway model is coming from Smithy's protocol tests, and includes an `Accept` header test:
// https://github.com/awslabs/smithy/blob/2f6553ff39e6bba9edc644ef5832661821785319/smithy-aws-protocol-tests/model/restJson1/services/apigateway.smithy#L30-L43
$version: "1.0"
namespace com.amazonaws.apigateway
use smithy.rules#endpointRuleSet
// Add an endpoint ruleset to the Smithy protocol test API Gateway model so that the code generator doesn't fail
apply BackplaneControlService @endpointRuleSet({
"version": "1.0",
"rules": [{

View File

@ -14,6 +14,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.NoClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
@ -96,6 +98,29 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
}
}
}
override fun protocolTestGenerator(
codegenContext: ClientCodegenContext,
baseGenerator: ProtocolTestGenerator,
): ProtocolTestGenerator = DefaultProtocolTestGenerator(
codegenContext,
baseGenerator.protocolSupport,
baseGenerator.operationShape,
renderClientCreation = { params ->
rustTemplate(
"""
// If the test case was missing endpoint parameters, default a region so it doesn't fail
let mut ${params.configBuilderName} = ${params.configBuilderName};
if ${params.configBuilderName}.region.is_none() {
${params.configBuilderName}.set_region(Some(crate::config::Region::new("us-east-1")));
}
let config = ${params.configBuilderName}.http_connector(${params.connectorName}).build();
let ${params.clientName} = #{Client}::from_conf(config);
""",
"Client" to ClientRustModule.root.toType().resolve("Client"),
)
},
)
}
private class AwsFluentClientExtensions(types: Types) {

View File

@ -171,7 +171,7 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
)
ServiceConfig.BuilderStruct ->
rustTemplate("region: Option<#{Region}>,", *codegenScope)
rustTemplate("pub(crate) region: Option<#{Region}>,", *codegenScope)
ServiceConfig.BuilderImpl ->
rustTemplate(
@ -191,6 +191,12 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
self.region = region.into();
self
}
/// Sets the AWS region to use when making requests.
pub fn set_region(&mut self, region: Option<#{Region}>) -> &mut Self {
self.region = region;
self
}
""",
*codegenScope,
)

View File

@ -73,6 +73,7 @@ class OperationRetryClassifiersFeature(
"ClassifyRetry" to smithyRuntimeApi.resolve("client::retries::ClassifyRetry"),
"RetryClassifiers" to smithyRuntimeApi.resolve("client::retries::RetryClassifiers"),
"OperationError" to codegenContext.symbolProvider.symbolForOperationError(operationShape),
"OrchestratorError" to smithyRuntimeApi.resolve("client::orchestrator::OrchestratorError"),
"SdkError" to RuntimeType.smithyHttp(runtimeConfig).resolve("result::SdkError"),
"ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedError"),
)
@ -89,9 +90,9 @@ class OperationRetryClassifiersFeature(
}
}
impl #{ClassifyRetry} for HttpStatusCodeClassifier {
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
self.0.classify_error(error)
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
// TODO(enableNewSmithyRuntime): classify the error with self.0
None
}
}
""",
@ -108,9 +109,9 @@ class OperationRetryClassifiersFeature(
}
}
impl #{ClassifyRetry} for AwsErrorCodeClassifier {
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
self.0.classify_error(error)
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
// TODO(enableNewSmithyRuntime): classify the error with self.0
None
}
}
""",
@ -127,9 +128,9 @@ class OperationRetryClassifiersFeature(
}
}
impl #{ClassifyRetry} for ModeledAsRetryableClassifier {
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
self.0.classify_error(error)
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
// TODO(enableNewSmithyRuntime): classify the error with self.0
None
}
}
""",
@ -146,9 +147,9 @@ class OperationRetryClassifiersFeature(
}
}
impl #{ClassifyRetry} for AmzRetryAfterHeaderClassifier {
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
self.0.classify_error(error)
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
// TODO(enableNewSmithyRuntime): classify the error with self.0
None
}
}
""",
@ -165,9 +166,9 @@ class OperationRetryClassifiersFeature(
}
}
impl #{ClassifyRetry} for SmithyErrorClassifier {
fn classify_retry(&self, error: &#{ErasedError}) -> Option<#{RetryReason}> {
let error = error.downcast_ref::<#{SdkError}<#{OperationError}>>().expect("The error type is always known");
self.0.classify_error(error)
fn classify_retry(&self, _error: &#{OrchestratorError}<#{ErasedError}>) -> Option<#{RetryReason}> {
// TODO(enableNewSmithyRuntime): classify the error with self.0
None
}
}
""",

View File

@ -126,8 +126,8 @@ private class AuthOperationRuntimePluginCustomization(private val codegenContext
val signingOptional = section.operationShape.hasTrait<OptionalAuthTrait>()
rustTemplate(
"""
let signing_region = cfg.get::<#{SigningRegion}>().expect("region required for signing").clone();
let signing_service = cfg.get::<#{SigningService}>().expect("service required for signing").clone();
let signing_region = cfg.get::<#{SigningRegion}>().cloned();
let signing_service = cfg.get::<#{SigningService}>().cloned();
let mut signing_options = #{SigningOptions}::default();
signing_options.double_uri_encode = $doubleUriEncode;
signing_options.content_sha256_header = $contentSha256Header;

View File

@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
@ -138,4 +139,11 @@ class ServiceSpecificDecorator(
): List<ServiceRuntimePluginCustomization> = baseCustomizations.maybeApply(codegenContext.serviceShape) {
delegateTo.serviceRuntimePluginCustomizations(codegenContext, baseCustomizations)
}
override fun protocolTestGenerator(
codegenContext: ClientCodegenContext,
baseGenerator: ProtocolTestGenerator,
): ProtocolTestGenerator = baseGenerator.maybeApply(codegenContext.serviceShape) {
delegateTo.protocolTestGenerator(codegenContext, baseGenerator)
}
}

View File

@ -14,7 +14,7 @@ import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.escape
@ -120,7 +120,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test:
private val moduleName = ctx.moduleUseName()
private val endpointCustomizations = ctx.rootDecorator.endpointCustomizations(ctx)
private val model = ctx.model
private val instantiator = clientInstantiator(ctx)
private val instantiator = ClientInstantiator(ctx)
private fun EndpointTestOperationInput.operationId() =
ShapeId.fromOptionalNamespace(ctx.serviceShape.id.namespace, operationName)

File diff suppressed because it is too large Load Diff

View File

@ -60,6 +60,7 @@ val crateVersioner by lazy { aws.sdk.CrateVersioner.defaultFor(rootProject, prop
fun getRustMSRV(): String = properties.get("rust.msrv") ?: throw Exception("Rust MSRV missing")
fun getPreviousReleaseVersionManifestPath(): String? = properties.get("aws.sdk.previous.release.versions.manifest")
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"
fun loadServiceMembership(): Membership {
val membershipOverride = properties.get("aws.services")?.let { parseMembership(it) }
@ -102,7 +103,7 @@ fun generateSmithyBuild(services: AwsServices): String {
"renameErrors": false,
"debugMode": $debugMode,
"eventStreamAllowList": [$eventStreamAllowListMembers],
"enableNewSmithyRuntime": "middleware"
"enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
},
"service": "${service.service}",
"module": "$moduleName",

View File

@ -2,7 +2,6 @@
# `./gradlew -Paws.fullsdk=true :aws:sdk:assemble` these tests are copied into their respective Service crates.
[workspace]
members = [
"apigateway",
"dynamodb",
"ec2",
"glacier",

View File

@ -1,17 +0,0 @@
[package]
name = "apigateway-tests"
version = "0.1.0"
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>"]
edition = "2021"
license = "Apache-2.0"
repository = "https://github.com/awslabs/smithy-rs"
publish = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] }
aws-sdk-apigateway = { path = "../../build/aws-sdk/sdk/apigateway" }
aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] }
aws-smithy-protocol-test = { path = "../../build/aws-sdk/sdk/aws-smithy-protocol-test"}
tokio = { version = "1.23.1", features = ["full", "test-util"]}

View File

@ -1,32 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
use aws_sdk_apigateway::config::{Credentials, Region};
use aws_sdk_apigateway::{Client, Config};
use aws_smithy_client::test_connection::capture_request;
use aws_smithy_protocol_test::{assert_ok, validate_headers};
#[tokio::test]
async fn accept_header_is_application_json() {
let (conn, handler) = capture_request(None);
let conf = Config::builder()
.region(Region::new("us-east-1"))
.credentials_provider(Credentials::for_tests())
.http_connector(conn)
.build();
let client = Client::from_conf(conf);
let _result = client
.delete_resource()
.rest_api_id("some-rest-api-id")
.resource_id("some-resource-id")
.send()
.await;
let request = handler.expect_request();
assert_ok(validate_headers(
request.headers(),
[("accept", "application/json")],
));
}

View File

@ -252,7 +252,7 @@ fun Project.registerCargoCommandsTasks(
dependsOn(dependentTasks)
workingDir(outputDir)
environment("RUSTFLAGS", "--cfg aws_sdk_unstable")
commandLine("cargo", "test", "--all-features")
commandLine("cargo", "test", "--all-features", "--no-fail-fast")
}
this.tasks.register<Exec>(Cargo.DOCS.toString) {

View File

@ -24,7 +24,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGener
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader
import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations
@ -311,22 +311,24 @@ class ClientCodegenVisitor(
)
// render protocol tests into `operation.rs` (note operationWriter vs. inputWriter)
ProtocolTestGenerator(
codegenDecorator.protocolTestGenerator(
codegenContext,
protocolGeneratorFactory.support(),
operationShape,
this@operationWriter,
).render()
DefaultProtocolTestGenerator(
codegenContext,
protocolGeneratorFactory.support(),
operationShape,
),
).render(this@operationWriter)
}
}
rustCrate.withModule(symbolProvider.moduleForOperationError(operationShape)) {
OperationErrorGenerator(
model,
symbolProvider,
operationShape,
codegenDecorator.errorCustomizations(codegenContext, emptyList()),
).render(this)
rustCrate.withModule(symbolProvider.moduleForOperationError(operationShape)) {
OperationErrorGenerator(
model,
symbolProvider,
operationShape,
codegenDecorator.errorCustomizations(codegenContext, emptyList()),
).render(this)
}
}
}
}

View File

@ -15,25 +15,28 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.util.orNull
class EndpointPrefixGenerator(private val codegenContext: ClientCodegenContext, private val shape: OperationShape) :
OperationCustomization() {
override fun section(section: OperationSection): Writable = when (section) {
is OperationSection.MutateRequest -> writable {
companion object {
fun endpointTraitBindings(codegenContext: ClientCodegenContext, shape: OperationShape): EndpointTraitBindings? =
shape.getTrait(EndpointTrait::class.java).map { epTrait ->
val endpointTraitBindings = EndpointTraitBindings(
EndpointTraitBindings(
codegenContext.model,
codegenContext.symbolProvider,
codegenContext.runtimeConfig,
shape,
epTrait,
)
}.orNull()
}
override fun section(section: OperationSection): Writable = when (section) {
is OperationSection.MutateRequest -> writable {
endpointTraitBindings(codegenContext, shape)?.also { endpointTraitBindings ->
withBlock("let endpoint_prefix = ", "?;") {
endpointTraitBindings.render(
this,
"self",
codegenContext.smithyRuntimeMode,
)
endpointTraitBindings.render(this, "self", codegenContext.smithyRuntimeMode)
}
rust("request.properties_mut().insert(endpoint_prefix);")
}

View File

@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRunti
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator
import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator
@ -77,6 +78,14 @@ interface ClientCodegenDecorator : CoreCodegenDecorator<ClientCodegenContext> {
operation: OperationShape,
baseCustomizations: List<OperationRuntimePluginCustomization>,
): List<OperationRuntimePluginCustomization> = baseCustomizations
/**
* Hook to override the protocol test generator
*/
fun protocolTestGenerator(
codegenContext: ClientCodegenContext,
baseGenerator: ProtocolTestGenerator,
): ProtocolTestGenerator = baseGenerator
}
/**
@ -143,6 +152,13 @@ open class CombinedClientCodegenDecorator(decorators: List<ClientCodegenDecorato
decorator.operationRuntimePluginCustomizations(codegenContext, operation, customizations)
}
override fun protocolTestGenerator(
codegenContext: ClientCodegenContext,
baseGenerator: ProtocolTestGenerator,
): ProtocolTestGenerator = combineCustomizations(baseGenerator) { decorator, gen ->
decorator.protocolTestGenerator(codegenContext, gen)
}
companion object {
fun fromClasspath(
context: PluginContext,

View File

@ -10,10 +10,11 @@ import software.amazon.smithy.rulesengine.language.syntax.Identifier
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters
import software.amazon.smithy.rulesengine.traits.EndpointTestCase
import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.escape
@ -22,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
@ -34,8 +34,7 @@ internal class EndpointTestGenerator(
private val resolverType: RuntimeType,
private val params: Parameters,
private val endpointCustomizations: List<EndpointCustomization>,
codegenContext: CodegenContext,
codegenContext: ClientCodegenContext,
) {
private val runtimeConfig = codegenContext.runtimeConfig
private val serviceShape = codegenContext.serviceShape
@ -50,7 +49,7 @@ internal class EndpointTestGenerator(
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
)
private val instantiator = clientInstantiator(codegenContext)
private val instantiator = ClientInstantiator(codegenContext)
private fun EndpointTestCase.docs(): Writable {
val self = this

View File

@ -6,8 +6,14 @@
package software.amazon.smithy.rust.codegen.client.smithy.generators
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@ -29,11 +35,27 @@ class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat
override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true
}
fun clientInstantiator(codegenContext: CodegenContext) =
Instantiator(
codegenContext.symbolProvider,
codegenContext.model,
codegenContext.runtimeConfig,
ClientBuilderKindBehavior(codegenContext),
::enumFromStringFn,
)
class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator(
codegenContext.symbolProvider,
codegenContext.model,
codegenContext.runtimeConfig,
ClientBuilderKindBehavior(codegenContext),
::enumFromStringFn,
) {
fun renderFluentCall(
writer: RustWriter,
clientName: String,
operationShape: OperationShape,
inputShape: StructureShape,
data: Node,
headers: Map<String, String> = mapOf(),
ctx: Ctx = Ctx(),
) {
val operationBuilderName =
FluentClientGenerator.clientOperationFnName(operationShape, codegenContext.symbolProvider)
writer.rust("$clientName.$operationBuilderName()")
renderStructureMembers(writer, inputShape, data as ObjectNode, headers, ctx)
}
}

View File

@ -45,7 +45,12 @@ class EndpointTraitBindings(
*
* The returned expression is a `Result<EndpointPrefix, UriError>`
*/
fun render(writer: RustWriter, input: String, smithyRuntimeMode: SmithyRuntimeMode) {
fun render(
writer: RustWriter,
input: String,
smithyRuntimeMode: SmithyRuntimeMode,
generateValidation: Boolean = true,
) {
// the Rust format pattern to make the endpoint prefix e.g. "{}.foo"
val formatLiteral = endpointTrait.prefixFormatString()
if (endpointTrait.hostPrefix.labels.isEmpty()) {
@ -68,28 +73,30 @@ class EndpointTraitBindings(
// NOTE: this is dead code until we start respecting @required
rust("let $field = &$input.$field;")
}
val contents = if (smithyRuntimeMode.generateOrchestrator) {
// TODO(enableNewSmithyRuntime): Remove the allow attribute once all places need .into method
"""
if $field.is_empty() {
##[allow(clippy::useless_conversion)]
return Err(#{invalidFieldError:W}.into())
if (generateValidation) {
val contents = if (smithyRuntimeMode.generateOrchestrator) {
// TODO(enableNewSmithyRuntime): Remove the allow attribute once all places need .into method
"""
if $field.is_empty() {
##[allow(clippy::useless_conversion)]
return Err(#{invalidFieldError:W}.into())
}
"""
} else {
"""
if $field.is_empty() {
return Err(#{invalidFieldError:W})
}
"""
}
"""
} else {
"""
if $field.is_empty() {
return Err(#{invalidFieldError:W})
}
"""
rustTemplate(
contents,
"invalidFieldError" to OperationBuildError(runtimeConfig).invalidField(
field,
"$field was unset or empty but must be set as part of the endpoint prefix",
),
)
}
rustTemplate(
contents,
"invalidFieldError" to OperationBuildError(runtimeConfig).invalidField(
field,
"$field was unset or empty but must be set as part of the endpoint prefix",
),
)
"${label.content} = $field"
}
rustTemplate(

View File

@ -20,9 +20,11 @@ import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
@ -44,15 +46,47 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import java.util.logging.Logger
data class ClientCreationParams(
val codegenContext: ClientCodegenContext,
val connectorName: String,
val configBuilderName: String,
val clientName: String,
)
interface ProtocolTestGenerator {
val codegenContext: ClientCodegenContext
val protocolSupport: ProtocolSupport
val operationShape: OperationShape
fun render(writer: RustWriter)
}
/**
* Generate protocol tests for an operation
*/
class ProtocolTestGenerator(
private val codegenContext: ClientCodegenContext,
private val protocolSupport: ProtocolSupport,
private val operationShape: OperationShape,
private val writer: RustWriter,
) {
class DefaultProtocolTestGenerator(
override val codegenContext: ClientCodegenContext,
override val protocolSupport: ProtocolSupport,
override val operationShape: OperationShape,
private val renderClientCreation: RustWriter.(ClientCreationParams) -> Unit = { params ->
rustTemplate(
"""
let smithy_client = #{Builder}::new()
.connector(${params.connectorName})
.middleware(#{MapRequestLayer}::for_mapper(#{SmithyEndpointStage}::new()))
.build();
let ${params.clientName} = #{Client}::with_config(smithy_client, ${params.configBuilderName}.build());
""",
"Client" to ClientRustModule.root.toType().resolve("Client"),
"Builder" to ClientRustModule.client.toType().resolve("Builder"),
"SmithyEndpointStage" to RuntimeType.smithyHttp(codegenContext.runtimeConfig)
.resolve("endpoint::middleware::SmithyEndpointStage"),
"MapRequestLayer" to RuntimeType.smithyHttpTower(codegenContext.runtimeConfig)
.resolve("map_request::MapRequestLayer"),
)
},
) : ProtocolTestGenerator {
private val logger = Logger.getLogger(javaClass.name)
private val inputShape = operationShape.inputShape(codegenContext.model)
@ -60,7 +94,7 @@ class ProtocolTestGenerator(
private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape)
private val operationIndex = OperationIndex.of(codegenContext.model)
private val instantiator = clientInstantiator(codegenContext)
private val instantiator = ClientInstantiator(codegenContext)
private val codegenScope = arrayOf(
"SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
@ -75,7 +109,7 @@ class ProtocolTestGenerator(
TestCase()
}
fun render() {
override fun render(writer: RustWriter) {
val requestTests = operationShape.getTrait<HttpRequestTestsTrait>()
?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) }
val responseTests = operationShape.getTrait<HttpResponseTestsTrait>()
@ -150,6 +184,7 @@ class ProtocolTestGenerator(
is Action.Response -> "_response"
is Action.Request -> "_request"
}
Attribute.AllowUnusedMut.render(testModuleWriter)
testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnName()") {
block(this)
}
@ -166,33 +201,58 @@ class ProtocolTestGenerator(
writable {
val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext)
params.getObjectMember("builtInParams").orNull()?.members?.forEach { (name, value) ->
customizations.firstNotNullOf { it.setBuiltInOnServiceConfig(name.value, value, "builder") }(this)
customizations.firstNotNullOf {
it.setBuiltInOnServiceConfig(name.value, value, "config_builder")
}(this)
}
}
} ?: writable { }
rustTemplate(
"""
let builder = #{config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com");
let (conn, request_receiver) = #{capture_request}(None);
let config_builder = #{config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com");
#{customParams}
let config = builder.build();
""",
"capture_request" to CargoDependency.smithyClient(codegenContext.runtimeConfig)
.toDevDependency()
.withFeature("test-util")
.toType()
.resolve("test_connection::capture_request"),
"config" to ClientRustModule.Config,
"customParams" to customParams,
)
writeInline("let input =")
instantiator.render(this, inputShape, httpRequestTestCase.params)
renderClientCreation(this, ClientCreationParams(codegenContext, "conn", "config_builder", "client"))
writeInline("let result = ")
instantiator.renderFluentCall(this, "client", operationShape, inputShape, httpRequestTestCase.params)
rust(""".send().await;""")
// Response parsing will always fail since we feed it an empty response body, so we don't care
// if it fails, but it is helpful to print what that failure was for debugging
rust("let _ = dbg!(result);")
rust("""let http_request = request_receiver.expect_request();""")
rust(""".make_operation(&config).await.expect("operation failed to build");""")
rust("let (http_request, parts) = input.into_request_response().0.into_parts();")
with(httpRequestTestCase) {
// Override the endpoint for tests that set a `host`, for example:
// https://github.com/awslabs/smithy/blob/be68f3bbdfe5bf50a104b387094d40c8069f16b1/smithy-aws-protocol-tests/model/restJson1/endpoint-paths.smithy#L19
host.orNull()?.also { host ->
val withScheme = "http://$host"
when (val bindings = EndpointPrefixGenerator.endpointTraitBindings(codegenContext, operationShape)) {
null -> rust("let endpoint_prefix = None;")
else -> {
withBlock("let input = ", ";") {
instantiator.render(this@renderHttpRequestTestCase, inputShape, httpRequestTestCase.params)
}
withBlock("let endpoint_prefix = Some({", "}.unwrap());") {
bindings.render(this, "input", codegenContext.smithyRuntimeMode, generateValidation = false)
}
}
}
rustTemplate(
"""
let mut http_request = http_request;
let ep = #{SmithyHttp}::endpoint::Endpoint::mutable(${withScheme.dq()}).expect("valid endpoint");
ep.set_endpoint(http_request.uri_mut(), parts.acquire().get()).expect("valid endpoint");
ep.set_endpoint(http_request.uri_mut(), endpoint_prefix.as_ref()).expect("valid endpoint");
""",
*codegenScope,
)

View File

@ -141,7 +141,7 @@ class ResponseDeserializerGenerator(
} else {
#{parse_response}(status, headers, body)
};
#{type_erase_result}(parse_result).into()
#{type_erase_result}(parse_result)
""",
*codegenScope,
"parse_error" to parserGenerator.parseErrorFn(operationShape, customizations),

View File

@ -47,7 +47,7 @@ internal class ClientInstantiatorTest {
@Test
fun `generate named enums`() {
val shape = model.lookup<StringShape>("com.test#NamedEnum")
val sut = clientInstantiator(codegenContext)
val sut = ClientInstantiator(codegenContext)
val data = Node.parse("t2.nano".dq())
val project = TestWorkspace.testProject(symbolProvider)
@ -66,7 +66,7 @@ internal class ClientInstantiatorTest {
@Test
fun `generate unnamed enums`() {
val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
val sut = clientInstantiator(codegenContext)
val sut = ClientInstantiator(codegenContext)
val data = Node.parse("t2.nano".dq())
val project = TestWorkspace.testProject(symbolProvider)

View File

@ -333,6 +333,23 @@ open class Instantiator(
* ```
*/
private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, headers: Map<String, String>, ctx: Ctx) {
writer.rust("#T::builder()", symbolProvider.toSymbol(shape))
renderStructureMembers(writer, shape, data, headers, ctx)
writer.rust(".build()")
if (builderKindBehavior.hasFallibleBuilder(shape)) {
writer.rust(".unwrap()")
}
}
protected fun renderStructureMembers(
writer: RustWriter,
shape: StructureShape,
data: ObjectNode,
headers: Map<String, String>,
ctx: Ctx,
) {
fun renderMemberHelper(memberShape: MemberShape, value: Node) {
val setterName = builderKindBehavior.setterName(memberShape)
writer.withBlock(".$setterName(", ")") {
@ -340,7 +357,6 @@ open class Instantiator(
}
}
writer.rust("#T::builder()", symbolProvider.toSymbol(shape))
if (defaultsForRequiredFields) {
shape.allMembers.entries
.filter { (name, memberShape) ->
@ -374,11 +390,6 @@ open class Instantiator(
?.let {
renderMemberHelper(it.value, fillDefaultValue(model.expectShape(it.value.target)))
}
writer.rust(".build()")
if (builderKindBehavior.hasFallibleBuilder(shape)) {
writer.rust(".unwrap()")
}
}
/**

View File

@ -7,6 +7,7 @@ use crate::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use crate::client::orchestrator::{BoxError, HttpRequest};
use crate::config_bag::ConfigBag;
use crate::type_erasure::{TypeErasedBox, TypedBox};
use aws_smithy_types::Document;
use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;
@ -34,6 +35,12 @@ impl AuthSchemeId {
}
}
impl From<&'static str> for AuthSchemeId {
fn from(scheme_id: &'static str) -> Self {
Self::new(scheme_id)
}
}
#[derive(Debug)]
pub struct AuthOptionResolverParams(TypeErasedBox);
@ -105,10 +112,34 @@ pub trait HttpRequestSigner: Send + Sync + fmt::Debug {
&self,
request: &mut HttpRequest,
identity: &Identity,
auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
config_bag: &ConfigBag,
) -> Result<(), BoxError>;
}
/// Endpoint configuration for the selected auth scheme.
///
/// This struct gets added to the request state by the auth orchestrator.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct AuthSchemeEndpointConfig<'a>(Option<&'a Document>);
impl<'a> AuthSchemeEndpointConfig<'a> {
/// Creates a new [`AuthSchemeEndpointConfig`].
pub fn new(config: Option<&'a Document>) -> Self {
Self(config)
}
/// Creates an empty AuthSchemeEndpointConfig.
pub fn empty() -> Self {
Self(None)
}
pub fn config(&self) -> Option<&'a Document> {
self.0
}
}
pub mod builders {
use super::*;

View File

@ -16,7 +16,7 @@ use crate::type_erasure::{TypeErasedBox, TypedBox};
use aws_smithy_async::future::now_or_later::NowOrLater;
use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::EndpointPrefix;
use aws_smithy_types::endpoint::Endpoint;
use std::fmt;
use std::future::Future as StdFuture;
use std::pin::Pin;
@ -74,12 +74,7 @@ impl EndpointResolverParams {
}
pub trait EndpointResolver: Send + Sync + fmt::Debug {
fn resolve_and_apply_endpoint(
&self,
params: &EndpointResolverParams,
endpoint_prefix: Option<&EndpointPrefix>,
request: &mut HttpRequest,
) -> Result<(), BoxError>;
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Result<Endpoint, BoxError>;
}
/// Time that the request is being made (so that time can be overridden in the [`ConfigBag`]).

View File

@ -8,7 +8,9 @@ use aws_smithy_runtime_api::client::auth::http::{
HTTP_API_KEY_AUTH_SCHEME_ID, HTTP_BASIC_AUTH_SCHEME_ID, HTTP_BEARER_AUTH_SCHEME_ID,
HTTP_DIGEST_AUTH_SCHEME_ID,
};
use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
use aws_smithy_runtime_api::client::auth::{
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::identity::http::{Login, Token};
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use aws_smithy_runtime_api::client::orchestrator::{BoxError, HttpRequest};
@ -76,6 +78,7 @@ impl HttpRequestSigner for ApiKeySigner {
&self,
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let api_key = identity
@ -141,6 +144,7 @@ impl HttpRequestSigner for BasicAuthSigner {
&self,
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let login = identity
@ -198,6 +202,7 @@ impl HttpRequestSigner for BearerAuthSigner {
&self,
request: &mut HttpRequest,
identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
let token = identity
@ -253,6 +258,7 @@ impl HttpRequestSigner for DigestAuthSigner {
&self,
_request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
unimplemented!(
@ -281,7 +287,12 @@ mod tests {
.body(SdkBody::empty())
.unwrap();
signer
.sign_request(&mut request, &identity, &config_bag)
.sign_request(
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&config_bag,
)
.expect("success");
assert_eq!(
"SomeSchemeName some-token",
@ -304,7 +315,12 @@ mod tests {
.body(SdkBody::empty())
.unwrap();
signer
.sign_request(&mut request, &identity, &config_bag)
.sign_request(
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&config_bag,
)
.expect("success");
assert!(request.headers().get("some-query-name").is_none());
assert_eq!(
@ -321,7 +337,12 @@ mod tests {
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
signer
.sign_request(&mut request, &identity, &config_bag)
.sign_request(
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&config_bag,
)
.expect("success");
assert_eq!(
"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
@ -337,7 +358,12 @@ mod tests {
let identity = Identity::new(Token::new("some-token", None), None);
let mut request = http::Request::builder().body(SdkBody::empty()).unwrap();
signer
.sign_request(&mut request, &identity, &config_bag)
.sign_request(
&mut request,
&identity,
AuthSchemeEndpointConfig::empty(),
&config_bag,
)
.expect("success");
assert_eq!(
"Bearer some-token",

View File

@ -3,11 +3,62 @@
* SPDX-License-Identifier: Apache-2.0
*/
use aws_smithy_runtime_api::client::auth::{AuthSchemeEndpointConfig, AuthSchemeId};
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::Document;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
#[derive(Debug)]
enum AuthOrchestrationError {
NoMatchingAuthScheme,
BadAuthSchemeEndpointConfig(Cow<'static, str>),
AuthSchemeEndpointConfigMismatch(String),
}
impl AuthOrchestrationError {
fn auth_scheme_endpoint_config_mismatch<'a>(
auth_schemes: impl Iterator<Item = &'a Document>,
) -> Self {
Self::AuthSchemeEndpointConfigMismatch(
auth_schemes
.flat_map(|s| match s {
Document::Object(map) => match map.get("name") {
Some(Document::String(name)) => Some(name.as_str()),
_ => None,
},
_ => None,
})
.collect::<Vec<_>>()
.join(", "),
)
}
}
impl fmt::Display for AuthOrchestrationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::NoMatchingAuthScheme => f.write_str(
"no auth scheme matched auth options. This is a bug. Please file an issue.",
),
Self::BadAuthSchemeEndpointConfig(message) => f.write_str(message),
Self::AuthSchemeEndpointConfigMismatch(supported_schemes) => {
write!(f,
"selected auth scheme / endpoint config mismatch. Couldn't find `sigv4` endpoint config for this endpoint. \
The authentication schemes supported by this endpoint are: {:?}",
supported_schemes
)
}
}
}
}
impl StdError for AuthOrchestrationError {}
pub(super) async fn orchestrate_auth(
ctx: &mut InterceptorContext,
cfg: &ConfigBag,
@ -22,36 +73,62 @@ pub(super) async fn orchestrate_auth(
identity_resolvers = ?identity_resolvers,
"orchestrating auth",
);
for &scheme_id in auth_options.as_ref() {
if let Some(auth_scheme) = cfg.http_auth_schemes().scheme(scheme_id) {
if let Some(identity_resolver) = auth_scheme.identity_resolver(identity_resolvers) {
let request_signer = auth_scheme.request_signer();
let endpoint = cfg
.get::<Endpoint>()
.expect("endpoint added to config bag by endpoint orchestrator");
let auth_scheme_endpoint_config =
extract_endpoint_auth_scheme_config(endpoint, scheme_id)?;
let identity = identity_resolver.resolve_identity(cfg).await?;
let request = ctx.request_mut();
request_signer.sign_request(request, &identity, cfg)?;
request_signer.sign_request(
request,
&identity,
auth_scheme_endpoint_config,
cfg,
)?;
return Ok(());
}
}
}
Err(NoMatchingAuthScheme.into())
Err(AuthOrchestrationError::NoMatchingAuthScheme.into())
}
#[derive(Debug)]
struct NoMatchingAuthScheme;
impl fmt::Display for NoMatchingAuthScheme {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"no auth scheme matched auth options. This is a bug. Please file an issue."
)
}
fn extract_endpoint_auth_scheme_config(
endpoint: &Endpoint,
scheme_id: AuthSchemeId,
) -> Result<AuthSchemeEndpointConfig<'_>, AuthOrchestrationError> {
let auth_schemes = match endpoint.properties().get("authSchemes") {
Some(Document::Array(schemes)) => schemes,
// no auth schemes:
None => return Ok(AuthSchemeEndpointConfig::new(None)),
_other => {
return Err(AuthOrchestrationError::BadAuthSchemeEndpointConfig(
"expected an array for `authSchemes` in endpoint config".into(),
))
}
};
let auth_scheme_config = auth_schemes
.iter()
.find(|doc| {
let config_scheme_id = doc
.as_object()
.and_then(|object| object.get("name"))
.and_then(Document::as_string);
config_scheme_id == Some(scheme_id.as_str())
})
.ok_or_else(|| {
AuthOrchestrationError::auth_scheme_endpoint_config_mismatch(auth_schemes.iter())
})?;
Ok(AuthSchemeEndpointConfig::new(Some(auth_scheme_config)))
}
impl std::error::Error for NoMatchingAuthScheme {}
#[cfg(test)]
mod tests {
use super::*;
@ -64,6 +141,7 @@ mod tests {
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{Future, HttpRequest};
use aws_smithy_runtime_api::type_erasure::TypedBox;
use std::collections::HashMap;
#[tokio::test]
async fn basic_case() {
@ -83,6 +161,7 @@ mod tests {
&self,
request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
request
@ -134,6 +213,7 @@ mod tests {
.auth_scheme(TEST_SCHEME_ID, TestAuthScheme { signer: TestSigner })
.build(),
);
cfg.put(Endpoint::builder().url("dontcare").build());
orchestrate_auth(&mut ctx, &cfg).await.expect("success");
@ -170,6 +250,7 @@ mod tests {
.auth_scheme(HTTP_BEARER_AUTH_SCHEME_ID, BearerAuthScheme::new())
.build(),
);
cfg.put(Endpoint::builder().url("dontcare").build());
// First, test the presence of a basic auth login and absence of a bearer token
cfg.set_identity_resolvers(
@ -203,4 +284,92 @@ mod tests {
ctx.request().headers().get("Authorization").unwrap()
);
}
#[test]
fn extract_endpoint_auth_scheme_config_no_config() {
let endpoint = Endpoint::builder()
.url("dontcare")
.property("something-unrelated", Document::Null)
.build();
let config = extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
.expect("success");
assert!(config.config().is_none());
}
#[test]
fn extract_endpoint_auth_scheme_config_wrong_type() {
let endpoint = Endpoint::builder()
.url("dontcare")
.property("authSchemes", Document::String("bad".into()))
.build();
extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
.expect_err("should fail because authSchemes is the wrong type");
}
#[test]
fn extract_endpoint_auth_scheme_config_no_matching_scheme() {
let endpoint = Endpoint::builder()
.url("dontcare")
.property(
"authSchemes",
vec![
Document::Object({
let mut out = HashMap::new();
out.insert("name".to_string(), "wrong-scheme-id".to_string().into());
out
}),
Document::Object({
let mut out = HashMap::new();
out.insert(
"name".to_string(),
"another-wrong-scheme-id".to_string().into(),
);
out
}),
],
)
.build();
extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
.expect_err("should fail because authSchemes doesn't include the desired scheme");
}
#[test]
fn extract_endpoint_auth_scheme_config_successfully() {
let endpoint = Endpoint::builder()
.url("dontcare")
.property(
"authSchemes",
vec![
Document::Object({
let mut out = HashMap::new();
out.insert("name".to_string(), "wrong-scheme-id".to_string().into());
out
}),
Document::Object({
let mut out = HashMap::new();
out.insert("name".to_string(), "test-scheme-id".to_string().into());
out.insert(
"magicString".to_string(),
"magic string value".to_string().into(),
);
out
}),
],
)
.build();
let config = extract_endpoint_auth_scheme_config(&endpoint, "test-scheme-id".into())
.expect("should find test-scheme-id");
assert_eq!(
"magic string value",
config
.config()
.expect("config is set")
.as_object()
.expect("it's an object")
.get("magicString")
.expect("magicString is set")
.as_string()
.expect("gimme the string, dammit!")
);
}
}

View File

@ -5,13 +5,15 @@
use aws_smithy_http::endpoint::error::ResolveEndpointError;
use aws_smithy_http::endpoint::{
apply_endpoint, EndpointPrefix, ResolveEndpoint, SharedEndpointResolver,
apply_endpoint as apply_endpoint_to_request_uri, EndpointPrefix, ResolveEndpoint,
SharedEndpointResolver,
};
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{
BoxError, ConfigBagAccessors, EndpointResolver, EndpointResolverParams, HttpRequest,
};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
use http::header::HeaderName;
use http::{HeaderValue, Uri};
use std::fmt::Debug;
@ -36,14 +38,8 @@ impl StaticUriEndpointResolver {
}
impl EndpointResolver for StaticUriEndpointResolver {
fn resolve_and_apply_endpoint(
&self,
_params: &EndpointResolverParams,
_endpoint_prefix: Option<&EndpointPrefix>,
request: &mut HttpRequest,
) -> Result<(), BoxError> {
apply_endpoint(request.uri_mut(), &self.endpoint, None)?;
Ok(())
fn resolve_endpoint(&self, _params: &EndpointResolverParams) -> Result<Endpoint, BoxError> {
Ok(Endpoint::builder().url(self.endpoint.to_string()).build())
}
}
@ -81,63 +77,64 @@ impl<Params> EndpointResolver for DefaultEndpointResolver<Params>
where
Params: Debug + Send + Sync + 'static,
{
fn resolve_and_apply_endpoint(
&self,
params: &EndpointResolverParams,
endpoint_prefix: Option<&EndpointPrefix>,
request: &mut HttpRequest,
) -> Result<(), BoxError> {
let endpoint = match params.get::<Params>() {
Some(params) => self.inner.resolve_endpoint(params)?,
None => {
return Err(Box::new(ResolveEndpointError::message(
"params of expected type was not present",
)));
}
};
let uri: Uri = endpoint.url().parse().map_err(|err| {
ResolveEndpointError::from_source("endpoint did not have a valid uri", err)
})?;
apply_endpoint(request.uri_mut(), &uri, endpoint_prefix).map_err(|err| {
ResolveEndpointError::message(format!(
"failed to apply endpoint `{:?}` to request `{:?}`",
uri, request,
))
.with_source(Some(err.into()))
})?;
for (header_name, header_values) in endpoint.headers() {
request.headers_mut().remove(header_name);
for value in header_values {
request.headers_mut().insert(
HeaderName::from_str(header_name).map_err(|err| {
ResolveEndpointError::message("invalid header name")
.with_source(Some(err.into()))
})?,
HeaderValue::from_str(value).map_err(|err| {
ResolveEndpointError::message("invalid header value")
.with_source(Some(err.into()))
})?,
);
}
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> Result<Endpoint, BoxError> {
match params.get::<Params>() {
Some(params) => Ok(self.inner.resolve_endpoint(params)?),
None => Err(Box::new(ResolveEndpointError::message(
"params of expected type was not present",
))),
}
Ok(())
}
}
pub(super) fn orchestrate_endpoint(
ctx: &mut InterceptorContext,
cfg: &ConfigBag,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let params = cfg.endpoint_resolver_params();
let endpoint_prefix = cfg.get::<EndpointPrefix>();
let request = ctx.request_mut();
let endpoint_resolver = cfg.endpoint_resolver();
endpoint_resolver.resolve_and_apply_endpoint(params, endpoint_prefix, request)?;
let endpoint = endpoint_resolver.resolve_endpoint(params)?;
apply_endpoint(request, &endpoint, endpoint_prefix)?;
// Make the endpoint config available to interceptors
cfg.put(endpoint);
Ok(())
}
fn apply_endpoint(
request: &mut HttpRequest,
endpoint: &Endpoint,
endpoint_prefix: Option<&EndpointPrefix>,
) -> Result<(), BoxError> {
let uri: Uri = endpoint.url().parse().map_err(|err| {
ResolveEndpointError::from_source("endpoint did not have a valid uri", err)
})?;
apply_endpoint_to_request_uri(request.uri_mut(), &uri, endpoint_prefix).map_err(|err| {
ResolveEndpointError::message(format!(
"failed to apply endpoint `{:?}` to request `{:?}`",
uri, request,
))
.with_source(Some(err.into()))
})?;
for (header_name, header_values) in endpoint.headers() {
request.headers_mut().remove(header_name);
for value in header_values {
request.headers_mut().insert(
HeaderName::from_str(header_name).map_err(|err| {
ResolveEndpointError::message("invalid header name")
.with_source(Some(err.into()))
})?,
HeaderValue::from_str(value).map_err(|err| {
ResolveEndpointError::message("invalid header value")
.with_source(Some(err.into()))
})?,
);
}
}
Ok(())
}

View File

@ -10,7 +10,7 @@ use aws_smithy_runtime_api::client::auth::option_resolver::{
StaticAuthOptionResolver, StaticAuthOptionResolverParams,
};
use aws_smithy_runtime_api::client::auth::{
AuthSchemeId, HttpAuthScheme, HttpAuthSchemes, HttpRequestSigner,
AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpAuthSchemes, HttpRequestSigner,
};
use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar;
@ -83,6 +83,7 @@ impl HttpRequestSigner for AnonymousSigner {
&self,
_request: &mut HttpRequest,
_identity: &Identity,
_auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
_config_bag: &ConfigBag,
) -> Result<(), BoxError> {
Ok(())

View File

@ -4,6 +4,7 @@
*/
use crate::Number;
use std::borrow::Cow;
use std::collections::HashMap;
/* ANCHOR: document */
@ -14,7 +15,7 @@ use std::collections::HashMap;
/// Open content is useful for modeling unstructured data that has no schema, data that can't be
/// modeled using rigid types, or data that has a schema that evolves outside of the purview of a model.
/// The serialization format of a document is an implementation detail of a protocol.
#[derive(Debug, Clone, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub enum Document {
/// JSON object
Object(HashMap<String, Document>),
@ -30,12 +31,135 @@ pub enum Document {
Null,
}
impl Document {
/// Returns the inner map value if this `Document` is an object.
pub fn as_object(&self) -> Option<&HashMap<String, Document>> {
if let Self::Object(object) = self {
Some(object)
} else {
None
}
}
/// Returns the mutable inner map value if this `Document` is an object.
pub fn as_object_mut(&mut self) -> Option<&mut HashMap<String, Document>> {
if let Self::Object(object) = self {
Some(object)
} else {
None
}
}
/// Returns the inner array value if this `Document` is an array.
pub fn as_array(&self) -> Option<&Vec<Document>> {
if let Self::Array(array) = self {
Some(array)
} else {
None
}
}
/// Returns the mutable inner array value if this `Document` is an array.
pub fn as_array_mut(&mut self) -> Option<&mut Vec<Document>> {
if let Self::Array(array) = self {
Some(array)
} else {
None
}
}
/// Returns the inner number value if this `Document` is a number.
pub fn as_number(&self) -> Option<&Number> {
if let Self::Number(number) = self {
Some(number)
} else {
None
}
}
/// Returns the inner string value if this `Document` is a string.
pub fn as_string(&self) -> Option<&str> {
if let Self::String(string) = self {
Some(string)
} else {
None
}
}
/// Returns the inner boolean value if this `Document` is a boolean.
pub fn as_bool(&self) -> Option<bool> {
if let Self::Bool(boolean) = self {
Some(*boolean)
} else {
None
}
}
/// Returns `Some(())` if this `Document` is a null.
pub fn as_null(&self) -> Option<()> {
if let Self::Null = self {
Some(())
} else {
None
}
}
/// Returns `true` if this `Document` is an object.
pub fn is_object(&self) -> bool {
matches!(self, Self::Object(_))
}
/// Returns `true` if this `Document` is an array.
pub fn is_array(&self) -> bool {
matches!(self, Self::Array(_))
}
/// Returns `true` if this `Document` is a number.
pub fn is_number(&self) -> bool {
matches!(self, Self::Number(_))
}
/// Returns `true` if this `Document` is a string.
pub fn is_string(&self) -> bool {
matches!(self, Self::String(_))
}
/// Returns `true` if this `Document` is a bool.
pub fn is_bool(&self) -> bool {
matches!(self, Self::Bool(_))
}
/// Returns `true` if this `Document` is a boolean.
pub fn is_null(&self) -> bool {
matches!(self, Self::Null)
}
}
/// The default value is `Document::Null`.
impl Default for Document {
fn default() -> Self {
Self::Null
}
}
impl From<bool> for Document {
fn from(value: bool) -> Self {
Document::Bool(value)
}
}
impl<'a> From<&'a str> for Document {
fn from(value: &'a str) -> Self {
Document::String(value.to_string())
}
}
impl<'a> From<Cow<'a, str>> for Document {
fn from(value: Cow<'a, str>) -> Self {
Document::String(value.into_owned())
}
}
impl From<String> for Document {
fn from(value: String) -> Self {
Document::String(value)
@ -71,3 +195,15 @@ impl From<i32> for Document {
Document::Number(Number::NegInt(value as i64))
}
}
impl From<f64> for Document {
fn from(value: f64) -> Self {
Document::Number(Number::Float(value))
}
}
impl From<Number> for Document {
fn from(value: Number) -> Self {
Document::Number(value)
}
}

View File

@ -4,6 +4,17 @@
# SPDX-License-Identifier: Apache-2.0
#
set -eux
C_YELLOW='\033[1;33m'
C_RESET='\033[0m'
set -eu
cd smithy-rs
./gradlew aws:sdk-adhoc-test:test
# TODO(enableNewSmithyRuntime): Remove the middleware test run when cleaning up middleware
echo -e "## ${C_YELLOW}Running SDK adhoc tests against the middleware implementation...${C_RESET}"
./gradlew aws:sdk-adhoc-test:clean
./gradlew aws:sdk-adhoc-test:check -Psmithy.runtime.mode=middleware
echo -e "## ${C_YELLOW}Running SDK adhoc tests against the orchestrator implementation...${C_RESET}"
./gradlew aws:sdk-adhoc-test:clean
./gradlew aws:sdk-adhoc-test:check -Psmithy.runtime.mode=orchestrator