Simplify event stream message signer configuration (#2671)

## Motivation and Context

This PR creates a `DeferredSigner` implementation that allows for the
event stream message signer to be wired up by the signing implementation
later in the request lifecycle rather than by adding an event stream
signer method to the config.

Refactoring this brings the middleware client implementation closer to
how the orchestrator implementation will work, which unblocks the work
required to make event streams work in the orchestrator.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
This commit is contained in:
John DiSanti 2023-05-26 08:49:56 -07:00 committed by GitHub
parent d083c6f271
commit 8773a70428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 456 additions and 295 deletions

View File

@ -34,3 +34,19 @@ message = "`SsoCredentialsProvider`, `AssumeRoleProvider`, and `WebIdentityToken
references = ["smithy-rs#2720"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "ysaito1001"
[[smithy-rs]]
message = """
<details>
<summary>Breaking change in how event stream signing works (click to expand more details)</summary>
This change will only impact you if you are wiring up their own event stream signing/authentication scheme. If you're using `aws-sig-auth` to use AWS SigV4 event stream signing, then this change will **not** impact you.
Previously, event stream signing was configured at codegen time by placing a `new_event_stream_signer` method on the `Config`. This function was called at serialization time to connect the signer to the streaming body. Now, instead, a special `DeferredSigner` is wired up at serialization time that relies on a signing implementation to be sent on a channel by the HTTP request signer. To do this, a `DeferredSignerSender` must be pulled out of the property bag, and its `send()` method called with the desired event stream signing implementation.
See the changes in https://github.com/awslabs/smithy-rs/pull/2671 for an example of how this was done for SigV4.
</details>
"""
references = ["smithy-rs#2671"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

View File

@ -3,6 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/
// TODO(enableNewSmithyRuntime): Remove this blanket allow once the old implementations are deleted
#![allow(deprecated)]
use crate::middleware::Signature;
use aws_credential_types::Credentials;
use aws_sigv4::event_stream::{sign_empty_message, sign_message};
@ -15,6 +18,115 @@ use std::time::SystemTime;
/// Event Stream SigV4 signing implementation.
#[derive(Debug)]
pub struct SigV4MessageSigner {
last_signature: String,
credentials: Credentials,
signing_region: SigningRegion,
signing_service: SigningService,
time: Option<SystemTime>,
}
impl SigV4MessageSigner {
pub fn new(
last_signature: String,
credentials: Credentials,
signing_region: SigningRegion,
signing_service: SigningService,
time: Option<SystemTime>,
) -> Self {
Self {
last_signature,
credentials,
signing_region,
signing_service,
time,
}
}
fn signing_params(&self) -> SigningParams<()> {
let mut builder = SigningParams::builder()
.access_key(self.credentials.access_key_id())
.secret_key(self.credentials.secret_access_key())
.region(self.signing_region.as_ref())
.service_name(self.signing_service.as_ref())
.time(self.time.unwrap_or_else(SystemTime::now))
.settings(());
builder.set_security_token(self.credentials.session_token());
builder.build().unwrap()
}
}
impl SignMessage for SigV4MessageSigner {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
let (signed_message, signature) = {
let params = self.signing_params();
sign_message(&message, &self.last_signature, &params).into_parts()
};
self.last_signature = signature;
Ok(signed_message)
}
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
let (signed_message, signature) = {
let params = self.signing_params();
sign_empty_message(&self.last_signature, &params).into_parts()
};
self.last_signature = signature;
Some(Ok(signed_message))
}
}
#[cfg(test)]
mod tests {
use crate::event_stream::SigV4MessageSigner;
use aws_credential_types::Credentials;
use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage};
use aws_types::region::Region;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use std::time::{Duration, UNIX_EPOCH};
fn check_send_sync<T: Send + Sync>(value: T) -> T {
value
}
#[test]
fn sign_message() {
let region = Region::new("us-east-1");
let mut signer = check_send_sync(SigV4MessageSigner::new(
"initial-signature".into(),
Credentials::for_tests(),
SigningRegion::from(region),
SigningService::from_static("transcribe"),
Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
));
let mut signatures = Vec::new();
for _ in 0..5 {
let signed = signer
.sign(Message::new(&b"identical message"[..]))
.unwrap();
if let HeaderValue::ByteArray(signature) = signed
.headers()
.iter()
.find(|h| h.name().as_str() == ":chunk-signature")
.unwrap()
.value()
{
signatures.push(signature.clone());
} else {
panic!("failed to get the :chunk-signature")
}
}
for i in 1..signatures.len() {
assert_ne!(signatures[i - 1], signatures[i]);
}
}
}
// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
#[deprecated = "use aws_sig_auth::event_stream::SigV4MessageSigner instead (this may require upgrading the smithy-rs code generator)"]
#[derive(Debug)]
/// Event Stream SigV4 signing implementation.
pub struct SigV4Signer {
properties: SharedPropertyBag,
last_signature: Option<String>,
@ -87,8 +199,9 @@ impl SignMessage for SigV4Signer {
}
}
// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
#[cfg(test)]
mod tests {
mod old_tests {
use crate::event_stream::SigV4Signer;
use crate::middleware::Signature;
use aws_credential_types::Credentials;

View File

@ -20,8 +20,15 @@ use crate::signer::{
OperationSigningConfig, RequestConfig, SigV4Signer, SigningError, SigningRequirements,
};
#[cfg(feature = "sign-eventstream")]
use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
#[cfg(feature = "sign-eventstream")]
use aws_smithy_eventstream::frame::DeferredSignerSender;
// TODO(enableNewSmithyRuntime): Delete `Signature` when switching to the orchestrator
/// Container for the request signature for use in the property bag.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct Signature(String);
impl Signature {
@ -181,6 +188,22 @@ impl MapRequest for SigV4SigningStage {
.signer
.sign(operation_config, &request_config, &creds, &mut req)
.map_err(SigningStageErrorKind::SigningFailure)?;
// If this is an event stream operation, set up the event stream signer
#[cfg(feature = "sign-eventstream")]
if let Some(signer_sender) = config.get::<DeferredSignerSender>() {
let time_override = config.get::<SystemTime>().copied();
signer_sender
.send(Box::new(EventStreamSigV4Signer::new(
signature.as_ref().into(),
creds,
request_config.region.clone(),
request_config.service.clone(),
time_override,
)) as _)
.expect("failed to send deferred signer");
}
config.insert(signature);
Ok(req)
})
@ -234,6 +257,49 @@ mod test {
assert!(signature.is_some());
}
#[cfg(feature = "sign-eventstream")]
#[test]
fn sends_event_stream_signer_for_event_stream_operations() {
use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
use aws_smithy_eventstream::frame::{DeferredSigner, SignMessage};
use std::time::SystemTime;
let (mut deferred_signer, deferred_signer_sender) = DeferredSigner::new();
let req = http::Request::builder()
.uri("https://test-service.test-region.amazonaws.com/")
.body(SdkBody::from(""))
.unwrap();
let region = Region::new("us-east-1");
let req = operation::Request::new(req)
.augment(|req, properties| {
properties.insert(region.clone());
properties.insert::<SystemTime>(UNIX_EPOCH + Duration::new(1611160427, 0));
properties.insert(SigningService::from_static("kinesis"));
properties.insert(OperationSigningConfig::default_config());
properties.insert(Credentials::for_tests());
properties.insert(SigningRegion::from(region.clone()));
properties.insert(deferred_signer_sender);
Result::<_, Infallible>::Ok(req)
})
.expect("succeeds");
let signer = SigV4SigningStage::new(SigV4Signer::new());
let _ = signer.apply(req).unwrap();
let mut signer_for_comparison = EventStreamSigV4Signer::new(
// This is the expected SigV4 signature for the HTTP request above
"abac477b4afabf5651079e7b9a0aa6a1a3e356a7418a81d974cdae9d4c8e5441".into(),
Credentials::for_tests(),
SigningRegion::from(region),
SigningService::from_static("kinesis"),
Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
);
let expected_signed_empty = signer_for_comparison.sign_empty().unwrap().unwrap();
let actual_signed_empty = deferred_signer.sign_empty().unwrap().unwrap();
assert_eq!(expected_signed_empty, actual_signed_empty);
}
// check that the endpoint middleware followed by signing middleware produce the expected result
#[test]
fn endpoint_plus_signer() {

View File

@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
use crate::middleware::Signature;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::{
sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableRequest,
@ -15,6 +14,7 @@ use aws_types::SigningService;
use std::fmt;
use std::time::{Duration, SystemTime};
use crate::middleware::Signature;
pub use aws_sigv4::http_request::SignableBody;
pub type SigningError = aws_sigv4::http_request::SigningError;

View File

@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
@ -34,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@ -179,7 +179,7 @@ class AwsInputPresignedMethod(
MakeOperationGenerator(
codegenContext,
protocol,
HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
// Prefixed with underscore to avoid colliding with modeled functions
functionName = makeOperationFn,
public = false,

View File

@ -16,7 +16,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
@ -77,17 +77,17 @@ class SigV4SigningDecorator : ClientCodegenDecorator {
}
class SigV4SigningConfig(
runtimeConfig: RuntimeConfig,
private val runtimeConfig: RuntimeConfig,
private val serviceHasEventStream: Boolean,
private val sigV4Trait: SigV4Trait,
) : EventStreamSigningConfig(runtimeConfig) {
private val codegenScope = arrayOf(
"SigV4Signer" to AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).resolve("event_stream::SigV4Signer"),
)
override fun configImplSection(): Writable {
return writable {
rustTemplate(
) : ConfigCustomization() {
override fun section(section: ServiceConfig): Writable = writable {
if (section is ServiceConfig.ConfigImpl) {
if (serviceHasEventStream) {
// enable the aws-sig-auth `sign-eventstream` feature
addDependency(AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).toSymbol())
}
rust(
"""
/// The signature version 4 service signing name to use in the credential scope when signing requests.
///
@ -97,24 +97,7 @@ class SigV4SigningConfig(
${sigV4Trait.name.dq()}
}
""",
*codegenScope,
)
if (serviceHasEventStream) {
rustTemplate(
"#{signerFn:W}",
"signerFn" to
renderEventStreamSignerFn { propertiesName ->
writable {
rustTemplate(
"""
#{SigV4Signer}::new($propertiesName)
""",
*codegenScope,
)
}
},
)
}
}
}
}
@ -209,5 +192,3 @@ class SigV4SigningFeature(
}
}
}
fun RuntimeConfig.sigAuth() = awsRuntimeCrate("aws-sig-auth")

View File

@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
internal class SigV4SigningCustomizationTest {
internal class SigV4SigningDecoratorTest {
@Test
fun `generates a valid config`() {
val project = stubConfigProject(

View File

@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpAuth
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpConnectorConfigDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator
@ -62,7 +61,6 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
FluentClientDecorator(),
EndpointsDecorator(),
EndpointParamsDecorator(),
NoOpEventStreamSigningDecorator(),
ApiKeyAuthDecorator(),
HttpAuthDecorator(),
HttpConnectorConfigDecorator(),

View File

@ -1,66 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.customize
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
/**
* The NoOpEventStreamSigningDecorator:
* - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer
*/
open class NoOpEventStreamSigningDecorator : ClientCodegenDecorator {
override val name: String = "NoOpEventStreamSigning"
override val order: Byte = Byte.MIN_VALUE
private fun applies(codegenContext: CodegenContext, baseCustomizations: List<ConfigCustomization>): Boolean =
codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) &&
// and if there is no other `EventStreamSigningConfig`, apply this one
!baseCustomizations.any { it is EventStreamSigningConfig }
override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> {
if (!applies(codegenContext, baseCustomizations)) {
return baseCustomizations
}
return baseCustomizations + NoOpEventStreamSigningConfig(
codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model),
codegenContext.runtimeConfig,
)
}
}
class NoOpEventStreamSigningConfig(
private val serviceHasEventStream: Boolean,
runtimeConfig: RuntimeConfig,
) : EventStreamSigningConfig(runtimeConfig) {
private val codegenScope = arrayOf(
"NoOpSigner" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::NoOpSigner"),
)
override fun configImplSection() = renderEventStreamSignerFn {
writable {
if (serviceHasEventStream) {
rustTemplate(
"""
#{NoOpSigner}{}
""",
*codegenScope,
)
}
}
}
}

View File

@ -1,46 +0,0 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rust.codegen.client.smithy.generators.config
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
open class EventStreamSigningConfig(
runtimeConfig: RuntimeConfig,
) : ConfigCustomization() {
private val codegenScope = arrayOf(
"SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
"SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
)
override fun section(section: ServiceConfig): Writable {
return when (section) {
is ServiceConfig.ConfigImpl -> configImplSection()
else -> emptySection
}
}
open fun configImplSection(): Writable = emptySection
fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable = writable {
rustTemplate(
"""
/// Creates a new Event Stream `SignMessage` implementor.
pub fn new_event_stream_signer(
&self,
_properties: #{SharedPropertyBag}
) -> impl #{SignMessage} {
#{signer:W}
}
""",
*codegenScope,
"signer" to signerInstantiator("_properties"),
)
}
}

View File

@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@ -23,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
@ -30,11 +32,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.outputShape
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` (replace with ClientProtocolGenerator)
class HttpBoundProtocolGenerator(
codegenContext: ClientCodegenContext,
protocol: Protocol,
bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
) : ClientProtocolGenerator(
codegenContext,
protocol,
@ -49,6 +51,35 @@ class HttpBoundProtocolGenerator(
HttpBoundProtocolTraitImplGenerator(codegenContext, protocol),
)
class ClientHttpBoundProtocolPayloadGenerator(
codegenContext: ClientCodegenContext,
protocol: Protocol,
) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
codegenContext, protocol, HttpMessageType.REQUEST,
renderEventStreamBody = { writer, params ->
writer.rustTemplate(
"""
{
let error_marshaller = #{errorMarshallerConstructorFn}();
let marshaller = #{marshallerConstructorFn}();
let (signer, signer_sender) = #{DeferredSigner}::new();
properties.acquire_mut().insert(signer_sender);
let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
body
}
""",
"hyper" to CargoDependency.HyperWithStream.toType(),
"SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig),
"aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
"DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"),
"marshallerConstructorFn" to params.marshallerConstructorFn,
"errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
)
},
)
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
open class HttpBoundProtocolTraitImplGenerator(
codegenContext: ClientCodegenContext,

View File

@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
@ -52,7 +51,6 @@ class ClientCodegenVisitorTest {
ClientCustomizations(),
RequiredCustomizations(),
FluentClientDecorator(),
NoOpEventStreamSigningDecorator(),
)
val visitor = ClientCodegenVisitor(ctx, codegenDecorator)
val baselineModel = visitor.baselineTransform(model)

View File

@ -6,19 +6,9 @@
package software.amazon.smithy.rust.codegen.client.smithy.customizations
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
@ -172,7 +162,6 @@ internal class HttpVersionListGeneratorTest {
clientIntegrationTest(
model,
IntegrationTestParams(addModuleToEventStreamAllowList = true),
additionalDecorators = listOf(FakeSigningDecorator()),
) { clientCodegenContext, rustCrate ->
val moduleName = clientCodegenContext.moduleUseName()
rustCrate.integrationTest("validate_eventstream_http") {
@ -196,77 +185,3 @@ internal class HttpVersionListGeneratorTest {
}
}
}
class FakeSigningDecorator : ClientCodegenDecorator {
override val name: String = "fakesigning"
override val order: Byte = 0
override fun classpathDiscoverable(): Boolean = false
override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> {
return baseCustomizations.filterNot {
it is EventStreamSigningConfig
} + FakeSigningConfig(codegenContext.runtimeConfig)
}
}
class FakeSigningConfig(
runtimeConfig: RuntimeConfig,
) : EventStreamSigningConfig(runtimeConfig) {
private val codegenScope = arrayOf(
"SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
"SignMessageError" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessageError"),
"SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
"Message" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::Message"),
)
override fun section(section: ServiceConfig): Writable {
return when (section) {
is ServiceConfig.ConfigImpl -> writable {
rustTemplate(
"""
/// Creates a new Event Stream `SignMessage` implementor.
pub fn new_event_stream_signer(
&self,
properties: #{SharedPropertyBag}
) -> FakeSigner {
FakeSigner::new(properties)
}
""",
*codegenScope,
)
}
is ServiceConfig.Extras -> writable {
rustTemplate(
"""
/// Fake signing implementation.
##[derive(Debug)]
pub struct FakeSigner;
impl FakeSigner {
/// Create a real `FakeSigner`
pub fn new(_properties: #{SharedPropertyBag}) -> Self {
Self {}
}
}
impl #{SignMessage} for FakeSigner {
fn sign(&mut self, message: #{Message}) -> Result<#{Message}, #{SignMessageError}> {
Ok(message)
}
fn sign_empty(&mut self) -> Option<Result<#{Message}, #{SignMessageError}>> {
Some(Ok(#{Message}::new(Vec::new())))
}
}
""",
*codegenScope,
)
}
else -> emptySection
}
}
}

View File

@ -38,7 +38,7 @@ private class TestProtocolPayloadGenerator(private val body: String) : ProtocolP
override fun payloadMetadata(operationShape: OperationShape) =
ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false)
override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
writer.writeWithNoFormatting(body)
}
}

View File

@ -36,13 +36,13 @@ interface ProtocolPayloadGenerator {
/**
* Write the payload into [writer].
*
* [self] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
* [shapeName] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
*
* This should be an expression that returns bytes:
* - a `Vec<u8>` for non-streaming operations; or
* - a `ByteStream` for streaming operations.
*/
fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape)
fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape)
}
/**

View File

@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
@ -42,10 +41,18 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape
data class EventStreamBodyParams(
val outerName: String,
val memberName: String,
val marshallerConstructorFn: RuntimeType,
val errorMarshallerConstructorFn: RuntimeType,
)
class HttpBoundProtocolPayloadGenerator(
codegenContext: CodegenContext,
private val protocol: Protocol,
private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST,
private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit,
) : ProtocolPayloadGenerator {
private val symbolProvider = codegenContext.symbolProvider
private val model = codegenContext.model
@ -91,38 +98,38 @@ class HttpBoundProtocolPayloadGenerator(
}
}
override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
when (httpMessageType) {
HttpMessageType.RESPONSE -> generateResponsePayload(writer, self, operationShape)
HttpMessageType.REQUEST -> generateRequestPayload(writer, self, operationShape)
HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape)
HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape)
}
}
private fun generateRequestPayload(writer: RustWriter, self: String, operationShape: OperationShape) {
private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer()
generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape))
generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape))
} else {
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
}
}
private fun generateResponsePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer()
generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape))
generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape))
} else {
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
}
}
private fun generatePayloadMemberSerializer(
writer: RustWriter,
self: String,
shapeName: String,
operationShape: OperationShape,
payloadMemberName: String,
) {
@ -131,7 +138,7 @@ class HttpBoundProtocolPayloadGenerator(
if (operationShape.isEventStream(model)) {
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self")
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName)
} else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output")
@ -144,16 +151,16 @@ class HttpBoundProtocolPayloadGenerator(
HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName)
HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName)
}
writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator)
writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator)
}
}
private fun generateStructureSerializer(writer: RustWriter, self: String, serializer: RuntimeType?) {
private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) {
if (serializer == null) {
writer.rust("\"\"")
} else {
writer.rust(
"#T(&$self)?",
"#T(&$shapeName)?",
serializer,
)
}
@ -193,47 +200,20 @@ class HttpBoundProtocolPayloadGenerator(
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
// parameters that are not `@eventHeader` or `@eventPayload`.
when (target) {
CodegenTarget.CLIENT ->
rustTemplate(
"""
{
let error_marshaller = #{errorMarshallerConstructorFn}();
let marshaller = #{marshallerConstructorFn}();
let signer = _config.new_event_stream_signer(properties.clone());
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
$outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
body
}
""",
*codegenScope,
"marshallerConstructorFn" to marshallerConstructorFn,
"errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
)
CodegenTarget.SERVER -> {
rustTemplate(
"""
{
let error_marshaller = #{errorMarshallerConstructorFn}();
let marshaller = #{marshallerConstructorFn}();
let signer = #{NoOpSigner}{};
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
$outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
adapter
}
""",
*codegenScope,
"marshallerConstructorFn" to marshallerConstructorFn,
"errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
)
}
}
renderEventStreamBody(
this,
EventStreamBodyParams(
outerName,
memberName,
marshallerConstructorFn,
errorMarshallerConstructorFn,
),
)
}
private fun RustWriter.serializeViaPayload(
payloadMetadata: ProtocolPayloadGenerator.PayloadMetadata,
self: String,
shapeName: String,
member: MemberShape,
serializerGenerator: StructuredDataSerializerGenerator,
) {
@ -281,7 +261,7 @@ class HttpBoundProtocolPayloadGenerator(
}
}
}
rust("#T($ref $self.${symbolProvider.toMemberName(member)})?", serializer)
rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer)
}
private fun RustWriter.renderPayload(

View File

@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
@ -48,12 +49,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
@ -114,6 +117,31 @@ class ServerHttpBoundProtocolGenerator(
}
}
class ServerHttpBoundProtocolPayloadGenerator(
codegenContext: CodegenContext,
protocol: Protocol,
) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
codegenContext, protocol, HttpMessageType.RESPONSE,
renderEventStreamBody = { writer, params ->
writer.rustTemplate(
"""
{
let error_marshaller = #{errorMarshallerConstructorFn}();
let marshaller = #{marshallerConstructorFn}();
let signer = #{NoOpSigner}{};
let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
adapter
}
""",
"aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
"NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"),
"marshallerConstructorFn" to params.marshallerConstructorFn,
"errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
)
},
)
/*
* Generate all operation input parsers and output serializers for streaming and
* non-streaming types.
@ -504,12 +532,12 @@ class ServerHttpBoundProtocolTraitImplGenerator(
?: serverRenderHttpResponseCode(httpTraitStatusCode)(this)
operationShape.outputShape(model).findStreamingMember(model)?.let {
val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) {
payloadGenerator.generatePayload(this, "output", operationShape)
}
} ?: run {
val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
withBlockTemplate("let payload = ", ";") {
payloadGenerator.generatePayload(this, "output", operationShape)
}

View File

@ -14,6 +14,7 @@ use std::convert::{TryFrom, TryInto};
use std::error::Error as StdError;
use std::fmt;
use std::mem::size_of;
use std::sync::{mpsc, Mutex};
const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::<u32>() as u32;
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
@ -34,6 +35,95 @@ pub trait SignMessage: fmt::Debug {
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
}
/// A sender that gets placed in the request config to wire up an event stream signer after signing.
#[derive(Debug)]
#[non_exhaustive]
pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
impl DeferredSignerSender {
/// Creates a new `DeferredSignerSender`
fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
Self(Mutex::new(tx))
}
/// Sends a signer on the channel
pub fn send(
&self,
signer: Box<dyn SignMessage + Send + Sync>,
) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
self.0.lock().unwrap().send(signer)
}
}
/// Deferred event stream signer to allow a signer to be wired up later.
///
/// HTTP request signing takes place after serialization, and the event stream
/// message stream body is established during serialization. Since event stream
/// signing may need context from the initial HTTP signing operation, this
/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle.
///
/// This signer basically just establishes a MPSC channel so that the sender can
/// be placed in the request's config. Then the HTTP signer implementation can
/// retrieve the sender from that config and send an actual signing implementation
/// with all the context needed.
///
/// When an event stream implementation needs to sign a message, the first call to
/// sign will acquire a signing implementation off of the channel and cache it
/// for the remainder of the operation.
#[derive(Debug)]
pub struct DeferredSigner {
rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
signer: Option<Box<dyn SignMessage + Send + Sync>>,
}
impl DeferredSigner {
pub fn new() -> (Self, DeferredSignerSender) {
let (tx, rx) = mpsc::channel();
(
Self {
rx: Some(Mutex::new(rx)),
signer: None,
},
DeferredSignerSender::new(tx),
)
}
fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
// Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough
if self.signer.is_some() {
return self.signer.as_mut().unwrap().as_mut();
} else {
self.signer = Some(
self.rx
.take()
.expect("only taken once")
.lock()
.unwrap()
.try_recv()
.ok()
// TODO(enableNewSmithyRuntime): When the middleware implementation is removed,
// this should panic rather than default to the `NoOpSigner`. The reason it defaults
// is because middleware-based generic clients don't have any default middleware,
// so there is no way to send a `NoOpSigner` by default when there is no other
// auth scheme. The orchestrator auth setup is a lot more robust and will make
// this problem trivial.
.unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
);
self.acquire()
}
}
}
impl SignMessage for DeferredSigner {
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
self.acquire().sign(message)
}
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
self.acquire().sign_empty()
}
}
#[derive(Debug)]
pub struct NoOpSigner {}
impl SignMessage for NoOpSigner {
@ -848,3 +938,60 @@ mod message_frame_decoder_tests {
}
}
}
#[cfg(test)]
mod deferred_signer_tests {
use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
use bytes::Bytes;
fn check_send_sync<T: Send + Sync>(value: T) -> T {
value
}
#[test]
fn deferred_signer() {
#[derive(Default, Debug)]
struct TestSigner {
call_num: i32,
}
impl SignMessage for TestSigner {
fn sign(
&mut self,
message: crate::frame::Message,
) -> Result<crate::frame::Message, crate::frame::SignMessageError> {
self.call_num += 1;
Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num))))
}
fn sign_empty(
&mut self,
) -> Option<Result<crate::frame::Message, crate::frame::SignMessageError>> {
None
}
}
let (mut signer, sender) = check_send_sync(DeferredSigner::new());
sender
.send(Box::new(TestSigner::default()))
.expect("success");
let message = signer.sign(Message::new(Bytes::new())).expect("success");
assert_eq!(1, message.headers()[0].value().as_int32().unwrap());
let message = signer.sign(Message::new(Bytes::new())).expect("success");
assert_eq!(2, message.headers()[0].value().as_int32().unwrap());
assert!(signer.sign_empty().is_none());
}
#[test]
fn deferred_signer_defaults_to_noop_signer() {
let (mut signer, _sender) = DeferredSigner::new();
assert_eq!(
Message::new(Bytes::new()),
signer.sign(Message::new(Bytes::new())).unwrap()
);
assert!(signer.sign_empty().is_none());
}
}