diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index f6dec54418..86fe2d7fe1 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -34,3 +34,19 @@ message = "`SsoCredentialsProvider`, `AssumeRoleProvider`, and `WebIdentityToken references = ["smithy-rs#2720"] meta = { "breaking" = false, "tada" = false, "bug" = true } author = "ysaito1001" + +[[smithy-rs]] +message = """ +
+Breaking change in how event stream signing works (click to expand more details) + +This change will only impact you if you are wiring up their own event stream signing/authentication scheme. If you're using `aws-sig-auth` to use AWS SigV4 event stream signing, then this change will **not** impact you. + +Previously, event stream signing was configured at codegen time by placing a `new_event_stream_signer` method on the `Config`. This function was called at serialization time to connect the signer to the streaming body. Now, instead, a special `DeferredSigner` is wired up at serialization time that relies on a signing implementation to be sent on a channel by the HTTP request signer. To do this, a `DeferredSignerSender` must be pulled out of the property bag, and its `send()` method called with the desired event stream signing implementation. + +See the changes in https://github.com/awslabs/smithy-rs/pull/2671 for an example of how this was done for SigV4. +
+""" +references = ["smithy-rs#2671"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } +author = "jdisanti" diff --git a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs index 74b8c65dd4..e6e06dac75 100644 --- a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs +++ b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs @@ -3,6 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ +// TODO(enableNewSmithyRuntime): Remove this blanket allow once the old implementations are deleted +#![allow(deprecated)] + use crate::middleware::Signature; use aws_credential_types::Credentials; use aws_sigv4::event_stream::{sign_empty_message, sign_message}; @@ -15,6 +18,115 @@ use std::time::SystemTime; /// Event Stream SigV4 signing implementation. #[derive(Debug)] +pub struct SigV4MessageSigner { + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, +} + +impl SigV4MessageSigner { + pub fn new( + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, + ) -> Self { + Self { + last_signature, + credentials, + signing_region, + signing_service, + time, + } + } + + fn signing_params(&self) -> SigningParams<()> { + let mut builder = SigningParams::builder() + .access_key(self.credentials.access_key_id()) + .secret_key(self.credentials.secret_access_key()) + .region(self.signing_region.as_ref()) + .service_name(self.signing_service.as_ref()) + .time(self.time.unwrap_or_else(SystemTime::now)) + .settings(()); + builder.set_security_token(self.credentials.session_token()); + builder.build().unwrap() + } +} + +impl SignMessage for SigV4MessageSigner { + fn sign(&mut self, message: Message) -> Result { + let (signed_message, signature) = { + let params = self.signing_params(); + sign_message(&message, &self.last_signature, ¶ms).into_parts() + }; + self.last_signature = signature; + Ok(signed_message) + } + + fn sign_empty(&mut self) -> Option> { + let (signed_message, signature) = { + let params = self.signing_params(); + sign_empty_message(&self.last_signature, ¶ms).into_parts() + }; + self.last_signature = signature; + Some(Ok(signed_message)) + } +} + +#[cfg(test)] +mod tests { + use crate::event_stream::SigV4MessageSigner; + use aws_credential_types::Credentials; + use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage}; + use aws_types::region::Region; + use aws_types::region::SigningRegion; + use aws_types::SigningService; + use std::time::{Duration, UNIX_EPOCH}; + + fn check_send_sync(value: T) -> T { + value + } + + #[test] + fn sign_message() { + let region = Region::new("us-east-1"); + let mut signer = check_send_sync(SigV4MessageSigner::new( + "initial-signature".into(), + Credentials::for_tests(), + SigningRegion::from(region), + SigningService::from_static("transcribe"), + Some(UNIX_EPOCH + Duration::new(1611160427, 0)), + )); + let mut signatures = Vec::new(); + for _ in 0..5 { + let signed = signer + .sign(Message::new(&b"identical message"[..])) + .unwrap(); + if let HeaderValue::ByteArray(signature) = signed + .headers() + .iter() + .find(|h| h.name().as_str() == ":chunk-signature") + .unwrap() + .value() + { + signatures.push(signature.clone()); + } else { + panic!("failed to get the :chunk-signature") + } + } + for i in 1..signatures.len() { + assert_ne!(signatures[i - 1], signatures[i]); + } + } +} + +// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases. +#[deprecated = "use aws_sig_auth::event_stream::SigV4MessageSigner instead (this may require upgrading the smithy-rs code generator)"] +#[derive(Debug)] +/// Event Stream SigV4 signing implementation. pub struct SigV4Signer { properties: SharedPropertyBag, last_signature: Option, @@ -87,8 +199,9 @@ impl SignMessage for SigV4Signer { } } +// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases. #[cfg(test)] -mod tests { +mod old_tests { use crate::event_stream::SigV4Signer; use crate::middleware::Signature; use aws_credential_types::Credentials; diff --git a/aws/rust-runtime/aws-sig-auth/src/middleware.rs b/aws/rust-runtime/aws-sig-auth/src/middleware.rs index d7ec53454c..c8c4794f0b 100644 --- a/aws/rust-runtime/aws-sig-auth/src/middleware.rs +++ b/aws/rust-runtime/aws-sig-auth/src/middleware.rs @@ -20,8 +20,15 @@ use crate::signer::{ OperationSigningConfig, RequestConfig, SigV4Signer, SigningError, SigningRequirements, }; +#[cfg(feature = "sign-eventstream")] +use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer; +#[cfg(feature = "sign-eventstream")] +use aws_smithy_eventstream::frame::DeferredSignerSender; + +// TODO(enableNewSmithyRuntime): Delete `Signature` when switching to the orchestrator /// Container for the request signature for use in the property bag. #[non_exhaustive] +#[derive(Debug, Clone)] pub struct Signature(String); impl Signature { @@ -181,6 +188,22 @@ impl MapRequest for SigV4SigningStage { .signer .sign(operation_config, &request_config, &creds, &mut req) .map_err(SigningStageErrorKind::SigningFailure)?; + + // If this is an event stream operation, set up the event stream signer + #[cfg(feature = "sign-eventstream")] + if let Some(signer_sender) = config.get::() { + let time_override = config.get::().copied(); + signer_sender + .send(Box::new(EventStreamSigV4Signer::new( + signature.as_ref().into(), + creds, + request_config.region.clone(), + request_config.service.clone(), + time_override, + )) as _) + .expect("failed to send deferred signer"); + } + config.insert(signature); Ok(req) }) @@ -234,6 +257,49 @@ mod test { assert!(signature.is_some()); } + #[cfg(feature = "sign-eventstream")] + #[test] + fn sends_event_stream_signer_for_event_stream_operations() { + use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer; + use aws_smithy_eventstream::frame::{DeferredSigner, SignMessage}; + use std::time::SystemTime; + + let (mut deferred_signer, deferred_signer_sender) = DeferredSigner::new(); + let req = http::Request::builder() + .uri("https://test-service.test-region.amazonaws.com/") + .body(SdkBody::from("")) + .unwrap(); + let region = Region::new("us-east-1"); + let req = operation::Request::new(req) + .augment(|req, properties| { + properties.insert(region.clone()); + properties.insert::(UNIX_EPOCH + Duration::new(1611160427, 0)); + properties.insert(SigningService::from_static("kinesis")); + properties.insert(OperationSigningConfig::default_config()); + properties.insert(Credentials::for_tests()); + properties.insert(SigningRegion::from(region.clone())); + properties.insert(deferred_signer_sender); + Result::<_, Infallible>::Ok(req) + }) + .expect("succeeds"); + + let signer = SigV4SigningStage::new(SigV4Signer::new()); + let _ = signer.apply(req).unwrap(); + + let mut signer_for_comparison = EventStreamSigV4Signer::new( + // This is the expected SigV4 signature for the HTTP request above + "abac477b4afabf5651079e7b9a0aa6a1a3e356a7418a81d974cdae9d4c8e5441".into(), + Credentials::for_tests(), + SigningRegion::from(region), + SigningService::from_static("kinesis"), + Some(UNIX_EPOCH + Duration::new(1611160427, 0)), + ); + + let expected_signed_empty = signer_for_comparison.sign_empty().unwrap().unwrap(); + let actual_signed_empty = deferred_signer.sign_empty().unwrap().unwrap(); + assert_eq!(expected_signed_empty, actual_signed_empty); + } + // check that the endpoint middleware followed by signing middleware produce the expected result #[test] fn endpoint_plus_signer() { diff --git a/aws/rust-runtime/aws-sig-auth/src/signer.rs b/aws/rust-runtime/aws-sig-auth/src/signer.rs index a1d36c97ca..d71c6ecf42 100644 --- a/aws/rust-runtime/aws-sig-auth/src/signer.rs +++ b/aws/rust-runtime/aws-sig-auth/src/signer.rs @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::middleware::Signature; use aws_credential_types::Credentials; use aws_sigv4::http_request::{ sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableRequest, @@ -15,6 +14,7 @@ use aws_types::SigningService; use std::fmt; use std::time::{Duration, SystemTime}; +use crate::middleware::Signature; pub use aws_sigv4::http_request::SignableBody; pub type SigningError = aws_sigv4::http_request::SigningError; diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index ee5f82aff6..4c9f6feffc 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator +import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs @@ -34,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.util.cloneOperation import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -179,7 +179,7 @@ class AwsInputPresignedMethod( MakeOperationGenerator( codegenContext, protocol, - HttpBoundProtocolPayloadGenerator(codegenContext, protocol), + ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol), // Prefixed with underscore to avoid colliding with modeled functions functionName = makeOperationFn, public = false, diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 81400d6597..3c40b518eb 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -16,7 +16,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate @@ -77,17 +77,17 @@ class SigV4SigningDecorator : ClientCodegenDecorator { } class SigV4SigningConfig( - runtimeConfig: RuntimeConfig, + private val runtimeConfig: RuntimeConfig, private val serviceHasEventStream: Boolean, private val sigV4Trait: SigV4Trait, -) : EventStreamSigningConfig(runtimeConfig) { - private val codegenScope = arrayOf( - "SigV4Signer" to AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).resolve("event_stream::SigV4Signer"), - ) - - override fun configImplSection(): Writable { - return writable { - rustTemplate( +) : ConfigCustomization() { + override fun section(section: ServiceConfig): Writable = writable { + if (section is ServiceConfig.ConfigImpl) { + if (serviceHasEventStream) { + // enable the aws-sig-auth `sign-eventstream` feature + addDependency(AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).toSymbol()) + } + rust( """ /// The signature version 4 service signing name to use in the credential scope when signing requests. /// @@ -97,24 +97,7 @@ class SigV4SigningConfig( ${sigV4Trait.name.dq()} } """, - *codegenScope, ) - if (serviceHasEventStream) { - rustTemplate( - "#{signerFn:W}", - "signerFn" to - renderEventStreamSignerFn { propertiesName -> - writable { - rustTemplate( - """ - #{SigV4Signer}::new($propertiesName) - """, - *codegenScope, - ) - } - }, - ) - } } } } @@ -209,5 +192,3 @@ class SigV4SigningFeature( } } } - -fun RuntimeConfig.sigAuth() = awsRuntimeCrate("aws-sig-auth") diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt similarity index 96% rename from aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt rename to aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt index 9de7c91f65..71bd5eaf6c 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt @@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest -internal class SigV4SigningCustomizationTest { +internal class SigV4SigningDecoratorTest { @Test fun `generates a valid config`() { val project = stubConfigProject( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index faa3a01c5d..08ba90d140 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpAuth import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpConnectorConfigDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator @@ -62,7 +61,6 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() { FluentClientDecorator(), EndpointsDecorator(), EndpointParamsDecorator(), - NoOpEventStreamSigningDecorator(), ApiKeyAuthDecorator(), HttpAuthDecorator(), HttpConnectorConfigDecorator(), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt deleted file mode 100644 index 7d924ee75a..0000000000 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.customize - -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations - -/** - * The NoOpEventStreamSigningDecorator: - * - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer - */ -open class NoOpEventStreamSigningDecorator : ClientCodegenDecorator { - override val name: String = "NoOpEventStreamSigning" - override val order: Byte = Byte.MIN_VALUE - - private fun applies(codegenContext: CodegenContext, baseCustomizations: List): Boolean = - codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) && - // and if there is no other `EventStreamSigningConfig`, apply this one - !baseCustomizations.any { it is EventStreamSigningConfig } - - override fun configCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - if (!applies(codegenContext, baseCustomizations)) { - return baseCustomizations - } - return baseCustomizations + NoOpEventStreamSigningConfig( - codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model), - codegenContext.runtimeConfig, - ) - } -} - -class NoOpEventStreamSigningConfig( - private val serviceHasEventStream: Boolean, - runtimeConfig: RuntimeConfig, -) : EventStreamSigningConfig(runtimeConfig) { - - private val codegenScope = arrayOf( - "NoOpSigner" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::NoOpSigner"), - ) - - override fun configImplSection() = renderEventStreamSignerFn { - writable { - if (serviceHasEventStream) { - rustTemplate( - """ - #{NoOpSigner}{} - """, - *codegenScope, - ) - } - } - } -} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt deleted file mode 100644 index 35da7d63b0..0000000000 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.generators.config - -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType - -open class EventStreamSigningConfig( - runtimeConfig: RuntimeConfig, -) : ConfigCustomization() { - private val codegenScope = arrayOf( - "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"), - "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), - ) - - override fun section(section: ServiceConfig): Writable { - return when (section) { - is ServiceConfig.ConfigImpl -> configImplSection() - else -> emptySection - } - } - - open fun configImplSection(): Writable = emptySection - - fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable = writable { - rustTemplate( - """ - /// Creates a new Event Stream `SignMessage` implementor. - pub fn new_event_stream_signer( - &self, - _properties: #{SharedPropertyBag} - ) -> impl #{SignMessage} { - #{signer:W} - } - """, - *codegenScope, - "signer" to signerInstantiator("_properties"), - ) - } -} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index ae83e78cfe..69e0bdd480 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -23,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -30,11 +32,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.outputShape -// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` +// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` (replace with ClientProtocolGenerator) class HttpBoundProtocolGenerator( codegenContext: ClientCodegenContext, protocol: Protocol, - bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol), + bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol), ) : ClientProtocolGenerator( codegenContext, protocol, @@ -49,6 +51,35 @@ class HttpBoundProtocolGenerator( HttpBoundProtocolTraitImplGenerator(codegenContext, protocol), ) +class ClientHttpBoundProtocolPayloadGenerator( + codegenContext: ClientCodegenContext, + protocol: Protocol, +) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( + codegenContext, protocol, HttpMessageType.REQUEST, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let (signer, signer_sender) = #{DeferredSigner}::new(); + properties.acquire_mut().insert(signer_sender); + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + body + } + """, + "hyper" to CargoDependency.HyperWithStream.toType(), + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, +) + // TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` open class HttpBoundProtocolTraitImplGenerator( codegenContext: ClientCodegenContext, diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt index 2cb21423fd..f9cf52375b 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -52,7 +51,6 @@ class ClientCodegenVisitorTest { ClientCustomizations(), RequiredCustomizations(), FluentClientDecorator(), - NoOpEventStreamSigningDecorator(), ) val visitor = ClientCodegenVisitor(ctx, codegenDecorator) val baselineModel = visitor.baselineTransform(model) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt index 3b5fc70a6f..73633a603e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt @@ -6,19 +6,9 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations import org.junit.jupiter.api.Test -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig -import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest @@ -172,7 +162,6 @@ internal class HttpVersionListGeneratorTest { clientIntegrationTest( model, IntegrationTestParams(addModuleToEventStreamAllowList = true), - additionalDecorators = listOf(FakeSigningDecorator()), ) { clientCodegenContext, rustCrate -> val moduleName = clientCodegenContext.moduleUseName() rustCrate.integrationTest("validate_eventstream_http") { @@ -196,77 +185,3 @@ internal class HttpVersionListGeneratorTest { } } } - -class FakeSigningDecorator : ClientCodegenDecorator { - override val name: String = "fakesigning" - override val order: Byte = 0 - override fun classpathDiscoverable(): Boolean = false - override fun configCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations.filterNot { - it is EventStreamSigningConfig - } + FakeSigningConfig(codegenContext.runtimeConfig) - } -} - -class FakeSigningConfig( - runtimeConfig: RuntimeConfig, -) : EventStreamSigningConfig(runtimeConfig) { - private val codegenScope = arrayOf( - "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"), - "SignMessageError" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessageError"), - "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), - "Message" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::Message"), - ) - - override fun section(section: ServiceConfig): Writable { - return when (section) { - is ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Creates a new Event Stream `SignMessage` implementor. - pub fn new_event_stream_signer( - &self, - properties: #{SharedPropertyBag} - ) -> FakeSigner { - FakeSigner::new(properties) - } - """, - *codegenScope, - ) - } - - is ServiceConfig.Extras -> writable { - rustTemplate( - """ - /// Fake signing implementation. - ##[derive(Debug)] - pub struct FakeSigner; - - impl FakeSigner { - /// Create a real `FakeSigner` - pub fn new(_properties: #{SharedPropertyBag}) -> Self { - Self {} - } - } - - impl #{SignMessage} for FakeSigner { - fn sign(&mut self, message: #{Message}) -> Result<#{Message}, #{SignMessageError}> { - Ok(message) - } - - fn sign_empty(&mut self) -> Option> { - Some(Ok(#{Message}::new(Vec::new()))) - } - } - """, - *codegenScope, - ) - } - - else -> emptySection - } - } -} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index d0d271231a..b03f374403 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -38,7 +38,7 @@ private class TestProtocolPayloadGenerator(private val body: String) : ProtocolP override fun payloadMetadata(operationShape: OperationShape) = ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false) - override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { writer.writeWithNoFormatting(body) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt index 4f410bf03a..ceb385c6d1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt @@ -36,13 +36,13 @@ interface ProtocolPayloadGenerator { /** * Write the payload into [writer]. * - * [self] is the name of the variable binding for the Rust struct that is to be serialized into the payload. + * [shapeName] is the name of the variable binding for the Rust struct that is to be serialized into the payload. * * This should be an expression that returns bytes: * - a `Vec` for non-streaming operations; or * - a `ByteStream` for streaming operations. */ - fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) + fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 6d4e7bf850..1e17beca4e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -42,10 +41,18 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape +data class EventStreamBodyParams( + val outerName: String, + val memberName: String, + val marshallerConstructorFn: RuntimeType, + val errorMarshallerConstructorFn: RuntimeType, +) + class HttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, private val protocol: Protocol, private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST, + private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit, ) : ProtocolPayloadGenerator { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model @@ -91,38 +98,38 @@ class HttpBoundProtocolPayloadGenerator( } } - override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload(writer, self, operationShape) - HttpMessageType.REQUEST -> generateRequestPayload(writer, self, operationShape) + HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape) + HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape) } } - private fun generateRequestPayload(writer: RustWriter, self: String, operationShape: OperationShape) { + private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape)) + generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) } } - private fun generateResponsePayload(writer: RustWriter, self: String, operationShape: OperationShape) { + private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape)) + generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) } } private fun generatePayloadMemberSerializer( writer: RustWriter, - self: String, + shapeName: String, operationShape: OperationShape, payloadMemberName: String, ) { @@ -131,7 +138,7 @@ class HttpBoundProtocolPayloadGenerator( if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self") + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName) } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output") @@ -144,16 +151,16 @@ class HttpBoundProtocolPayloadGenerator( HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) } - writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator) + writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator) } } - private fun generateStructureSerializer(writer: RustWriter, self: String, serializer: RuntimeType?) { + private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) { if (serializer == null) { writer.rust("\"\"") } else { writer.rust( - "#T(&$self)?", + "#T(&$shapeName)?", serializer, ) } @@ -193,47 +200,20 @@ class HttpBoundProtocolPayloadGenerator( // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the // parameters that are not `@eventHeader` or `@eventPayload`. - when (target) { - CodegenTarget.CLIENT -> - rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let signer = _config.new_event_stream_signer(properties.clone()); - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = - $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); - body - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, - ) - CodegenTarget.SERVER -> { - rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let signer = #{NoOpSigner}{}; - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> = - $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer); - adapter - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "errorMarshallerConstructorFn" to errorMarshallerConstructorFn, - ) - } - } + renderEventStreamBody( + this, + EventStreamBodyParams( + outerName, + memberName, + marshallerConstructorFn, + errorMarshallerConstructorFn, + ), + ) } private fun RustWriter.serializeViaPayload( payloadMetadata: ProtocolPayloadGenerator.PayloadMetadata, - self: String, + shapeName: String, member: MemberShape, serializerGenerator: StructuredDataSerializerGenerator, ) { @@ -281,7 +261,7 @@ class HttpBoundProtocolPayloadGenerator( } } } - rust("#T($ref $self.${symbolProvider.toMemberName(member)})?", serializer) + rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer) } private fun RustWriter.renderPayload( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 2530f04404..89accafd74 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization @@ -48,12 +49,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -114,6 +117,31 @@ class ServerHttpBoundProtocolGenerator( } } +class ServerHttpBoundProtocolPayloadGenerator( + codegenContext: CodegenContext, + protocol: Protocol, +) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( + codegenContext, protocol, HttpMessageType.RESPONSE, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let signer = #{NoOpSigner}{}; + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + adapter + } + """, + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, +) + /* * Generate all operation input parsers and output serializers for streaming and * non-streaming types. @@ -504,12 +532,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( ?: serverRenderHttpResponseCode(httpTraitStatusCode)(this) operationShape.outputShape(model).findStreamingMember(model)?.let { - val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { payloadGenerator.generatePayload(this, "output", operationShape) } } ?: run { - val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) withBlockTemplate("let payload = ", ";") { payloadGenerator.generatePayload(this, "output", operationShape) } diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index a3d9c29a00..edf48e60de 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -14,6 +14,7 @@ use std::convert::{TryFrom, TryInto}; use std::error::Error as StdError; use std::fmt; use std::mem::size_of; +use std::sync::{mpsc, Mutex}; const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::() as u32; const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize; @@ -34,6 +35,95 @@ pub trait SignMessage: fmt::Debug { fn sign_empty(&mut self) -> Option>; } +/// A sender that gets placed in the request config to wire up an event stream signer after signing. +#[derive(Debug)] +#[non_exhaustive] +pub struct DeferredSignerSender(Mutex>>); + +impl DeferredSignerSender { + /// Creates a new `DeferredSignerSender` + fn new(tx: mpsc::Sender>) -> Self { + Self(Mutex::new(tx)) + } + + /// Sends a signer on the channel + pub fn send( + &self, + signer: Box, + ) -> Result<(), mpsc::SendError>> { + self.0.lock().unwrap().send(signer) + } +} + +/// Deferred event stream signer to allow a signer to be wired up later. +/// +/// HTTP request signing takes place after serialization, and the event stream +/// message stream body is established during serialization. Since event stream +/// signing may need context from the initial HTTP signing operation, this +/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle. +/// +/// This signer basically just establishes a MPSC channel so that the sender can +/// be placed in the request's config. Then the HTTP signer implementation can +/// retrieve the sender from that config and send an actual signing implementation +/// with all the context needed. +/// +/// When an event stream implementation needs to sign a message, the first call to +/// sign will acquire a signing implementation off of the channel and cache it +/// for the remainder of the operation. +#[derive(Debug)] +pub struct DeferredSigner { + rx: Option>>>, + signer: Option>, +} + +impl DeferredSigner { + pub fn new() -> (Self, DeferredSignerSender) { + let (tx, rx) = mpsc::channel(); + ( + Self { + rx: Some(Mutex::new(rx)), + signer: None, + }, + DeferredSignerSender::new(tx), + ) + } + + fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) { + // Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough + if self.signer.is_some() { + return self.signer.as_mut().unwrap().as_mut(); + } else { + self.signer = Some( + self.rx + .take() + .expect("only taken once") + .lock() + .unwrap() + .try_recv() + .ok() + // TODO(enableNewSmithyRuntime): When the middleware implementation is removed, + // this should panic rather than default to the `NoOpSigner`. The reason it defaults + // is because middleware-based generic clients don't have any default middleware, + // so there is no way to send a `NoOpSigner` by default when there is no other + // auth scheme. The orchestrator auth setup is a lot more robust and will make + // this problem trivial. + .unwrap_or_else(|| Box::new(NoOpSigner {}) as _), + ); + self.acquire() + } + } +} + +impl SignMessage for DeferredSigner { + fn sign(&mut self, message: Message) -> Result { + self.acquire().sign(message) + } + + fn sign_empty(&mut self) -> Option> { + self.acquire().sign_empty() + } +} + #[derive(Debug)] pub struct NoOpSigner {} impl SignMessage for NoOpSigner { @@ -848,3 +938,60 @@ mod message_frame_decoder_tests { } } } + +#[cfg(test)] +mod deferred_signer_tests { + use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage}; + use bytes::Bytes; + + fn check_send_sync(value: T) -> T { + value + } + + #[test] + fn deferred_signer() { + #[derive(Default, Debug)] + struct TestSigner { + call_num: i32, + } + impl SignMessage for TestSigner { + fn sign( + &mut self, + message: crate::frame::Message, + ) -> Result { + self.call_num += 1; + Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num)))) + } + + fn sign_empty( + &mut self, + ) -> Option> { + None + } + } + + let (mut signer, sender) = check_send_sync(DeferredSigner::new()); + + sender + .send(Box::new(TestSigner::default())) + .expect("success"); + + let message = signer.sign(Message::new(Bytes::new())).expect("success"); + assert_eq!(1, message.headers()[0].value().as_int32().unwrap()); + + let message = signer.sign(Message::new(Bytes::new())).expect("success"); + assert_eq!(2, message.headers()[0].value().as_int32().unwrap()); + + assert!(signer.sign_empty().is_none()); + } + + #[test] + fn deferred_signer_defaults_to_noop_signer() { + let (mut signer, _sender) = DeferredSigner::new(); + assert_eq!( + Message::new(Bytes::new()), + signer.sign(Message::new(Bytes::new())).unwrap() + ); + assert!(signer.sign_empty().is_none()); + } +}