diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml
index f6dec54418..86fe2d7fe1 100644
--- a/CHANGELOG.next.toml
+++ b/CHANGELOG.next.toml
@@ -34,3 +34,19 @@ message = "`SsoCredentialsProvider`, `AssumeRoleProvider`, and `WebIdentityToken
references = ["smithy-rs#2720"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "ysaito1001"
+
+[[smithy-rs]]
+message = """
+
+Breaking change in how event stream signing works (click to expand more details)
+
+This change will only impact you if you are wiring up their own event stream signing/authentication scheme. If you're using `aws-sig-auth` to use AWS SigV4 event stream signing, then this change will **not** impact you.
+
+Previously, event stream signing was configured at codegen time by placing a `new_event_stream_signer` method on the `Config`. This function was called at serialization time to connect the signer to the streaming body. Now, instead, a special `DeferredSigner` is wired up at serialization time that relies on a signing implementation to be sent on a channel by the HTTP request signer. To do this, a `DeferredSignerSender` must be pulled out of the property bag, and its `send()` method called with the desired event stream signing implementation.
+
+See the changes in https://github.com/awslabs/smithy-rs/pull/2671 for an example of how this was done for SigV4.
+
+"""
+references = ["smithy-rs#2671"]
+meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
+author = "jdisanti"
diff --git a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs
index 74b8c65dd4..e6e06dac75 100644
--- a/aws/rust-runtime/aws-sig-auth/src/event_stream.rs
+++ b/aws/rust-runtime/aws-sig-auth/src/event_stream.rs
@@ -3,6 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/
+// TODO(enableNewSmithyRuntime): Remove this blanket allow once the old implementations are deleted
+#![allow(deprecated)]
+
use crate::middleware::Signature;
use aws_credential_types::Credentials;
use aws_sigv4::event_stream::{sign_empty_message, sign_message};
@@ -15,6 +18,115 @@ use std::time::SystemTime;
/// Event Stream SigV4 signing implementation.
#[derive(Debug)]
+pub struct SigV4MessageSigner {
+ last_signature: String,
+ credentials: Credentials,
+ signing_region: SigningRegion,
+ signing_service: SigningService,
+ time: Option,
+}
+
+impl SigV4MessageSigner {
+ pub fn new(
+ last_signature: String,
+ credentials: Credentials,
+ signing_region: SigningRegion,
+ signing_service: SigningService,
+ time: Option,
+ ) -> Self {
+ Self {
+ last_signature,
+ credentials,
+ signing_region,
+ signing_service,
+ time,
+ }
+ }
+
+ fn signing_params(&self) -> SigningParams<()> {
+ let mut builder = SigningParams::builder()
+ .access_key(self.credentials.access_key_id())
+ .secret_key(self.credentials.secret_access_key())
+ .region(self.signing_region.as_ref())
+ .service_name(self.signing_service.as_ref())
+ .time(self.time.unwrap_or_else(SystemTime::now))
+ .settings(());
+ builder.set_security_token(self.credentials.session_token());
+ builder.build().unwrap()
+ }
+}
+
+impl SignMessage for SigV4MessageSigner {
+ fn sign(&mut self, message: Message) -> Result {
+ let (signed_message, signature) = {
+ let params = self.signing_params();
+ sign_message(&message, &self.last_signature, ¶ms).into_parts()
+ };
+ self.last_signature = signature;
+ Ok(signed_message)
+ }
+
+ fn sign_empty(&mut self) -> Option> {
+ let (signed_message, signature) = {
+ let params = self.signing_params();
+ sign_empty_message(&self.last_signature, ¶ms).into_parts()
+ };
+ self.last_signature = signature;
+ Some(Ok(signed_message))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::event_stream::SigV4MessageSigner;
+ use aws_credential_types::Credentials;
+ use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage};
+ use aws_types::region::Region;
+ use aws_types::region::SigningRegion;
+ use aws_types::SigningService;
+ use std::time::{Duration, UNIX_EPOCH};
+
+ fn check_send_sync(value: T) -> T {
+ value
+ }
+
+ #[test]
+ fn sign_message() {
+ let region = Region::new("us-east-1");
+ let mut signer = check_send_sync(SigV4MessageSigner::new(
+ "initial-signature".into(),
+ Credentials::for_tests(),
+ SigningRegion::from(region),
+ SigningService::from_static("transcribe"),
+ Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
+ ));
+ let mut signatures = Vec::new();
+ for _ in 0..5 {
+ let signed = signer
+ .sign(Message::new(&b"identical message"[..]))
+ .unwrap();
+ if let HeaderValue::ByteArray(signature) = signed
+ .headers()
+ .iter()
+ .find(|h| h.name().as_str() == ":chunk-signature")
+ .unwrap()
+ .value()
+ {
+ signatures.push(signature.clone());
+ } else {
+ panic!("failed to get the :chunk-signature")
+ }
+ }
+ for i in 1..signatures.len() {
+ assert_ne!(signatures[i - 1], signatures[i]);
+ }
+ }
+}
+
+// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
+#[deprecated = "use aws_sig_auth::event_stream::SigV4MessageSigner instead (this may require upgrading the smithy-rs code generator)"]
+#[derive(Debug)]
+/// Event Stream SigV4 signing implementation.
pub struct SigV4Signer {
properties: SharedPropertyBag,
last_signature: Option,
@@ -87,8 +199,9 @@ impl SignMessage for SigV4Signer {
}
}
+// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
#[cfg(test)]
-mod tests {
+mod old_tests {
use crate::event_stream::SigV4Signer;
use crate::middleware::Signature;
use aws_credential_types::Credentials;
diff --git a/aws/rust-runtime/aws-sig-auth/src/middleware.rs b/aws/rust-runtime/aws-sig-auth/src/middleware.rs
index d7ec53454c..c8c4794f0b 100644
--- a/aws/rust-runtime/aws-sig-auth/src/middleware.rs
+++ b/aws/rust-runtime/aws-sig-auth/src/middleware.rs
@@ -20,8 +20,15 @@ use crate::signer::{
OperationSigningConfig, RequestConfig, SigV4Signer, SigningError, SigningRequirements,
};
+#[cfg(feature = "sign-eventstream")]
+use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
+#[cfg(feature = "sign-eventstream")]
+use aws_smithy_eventstream::frame::DeferredSignerSender;
+
+// TODO(enableNewSmithyRuntime): Delete `Signature` when switching to the orchestrator
/// Container for the request signature for use in the property bag.
#[non_exhaustive]
+#[derive(Debug, Clone)]
pub struct Signature(String);
impl Signature {
@@ -181,6 +188,22 @@ impl MapRequest for SigV4SigningStage {
.signer
.sign(operation_config, &request_config, &creds, &mut req)
.map_err(SigningStageErrorKind::SigningFailure)?;
+
+ // If this is an event stream operation, set up the event stream signer
+ #[cfg(feature = "sign-eventstream")]
+ if let Some(signer_sender) = config.get::() {
+ let time_override = config.get::().copied();
+ signer_sender
+ .send(Box::new(EventStreamSigV4Signer::new(
+ signature.as_ref().into(),
+ creds,
+ request_config.region.clone(),
+ request_config.service.clone(),
+ time_override,
+ )) as _)
+ .expect("failed to send deferred signer");
+ }
+
config.insert(signature);
Ok(req)
})
@@ -234,6 +257,49 @@ mod test {
assert!(signature.is_some());
}
+ #[cfg(feature = "sign-eventstream")]
+ #[test]
+ fn sends_event_stream_signer_for_event_stream_operations() {
+ use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
+ use aws_smithy_eventstream::frame::{DeferredSigner, SignMessage};
+ use std::time::SystemTime;
+
+ let (mut deferred_signer, deferred_signer_sender) = DeferredSigner::new();
+ let req = http::Request::builder()
+ .uri("https://test-service.test-region.amazonaws.com/")
+ .body(SdkBody::from(""))
+ .unwrap();
+ let region = Region::new("us-east-1");
+ let req = operation::Request::new(req)
+ .augment(|req, properties| {
+ properties.insert(region.clone());
+ properties.insert::(UNIX_EPOCH + Duration::new(1611160427, 0));
+ properties.insert(SigningService::from_static("kinesis"));
+ properties.insert(OperationSigningConfig::default_config());
+ properties.insert(Credentials::for_tests());
+ properties.insert(SigningRegion::from(region.clone()));
+ properties.insert(deferred_signer_sender);
+ Result::<_, Infallible>::Ok(req)
+ })
+ .expect("succeeds");
+
+ let signer = SigV4SigningStage::new(SigV4Signer::new());
+ let _ = signer.apply(req).unwrap();
+
+ let mut signer_for_comparison = EventStreamSigV4Signer::new(
+ // This is the expected SigV4 signature for the HTTP request above
+ "abac477b4afabf5651079e7b9a0aa6a1a3e356a7418a81d974cdae9d4c8e5441".into(),
+ Credentials::for_tests(),
+ SigningRegion::from(region),
+ SigningService::from_static("kinesis"),
+ Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
+ );
+
+ let expected_signed_empty = signer_for_comparison.sign_empty().unwrap().unwrap();
+ let actual_signed_empty = deferred_signer.sign_empty().unwrap().unwrap();
+ assert_eq!(expected_signed_empty, actual_signed_empty);
+ }
+
// check that the endpoint middleware followed by signing middleware produce the expected result
#[test]
fn endpoint_plus_signer() {
diff --git a/aws/rust-runtime/aws-sig-auth/src/signer.rs b/aws/rust-runtime/aws-sig-auth/src/signer.rs
index a1d36c97ca..d71c6ecf42 100644
--- a/aws/rust-runtime/aws-sig-auth/src/signer.rs
+++ b/aws/rust-runtime/aws-sig-auth/src/signer.rs
@@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
-use crate::middleware::Signature;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::{
sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableRequest,
@@ -15,6 +14,7 @@ use aws_types::SigningService;
use std::fmt;
use std::time::{Duration, SystemTime};
+use crate::middleware::Signature;
pub use aws_sigv4::http_request::SignableBody;
pub type SigningError = aws_sigv4::http_request::SigningError;
diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt
index ee5f82aff6..4c9f6feffc 100644
--- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt
+++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt
@@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
+import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
@@ -34,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
-import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@@ -179,7 +179,7 @@ class AwsInputPresignedMethod(
MakeOperationGenerator(
codegenContext,
protocol,
- HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
+ ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
// Prefixed with underscore to avoid colliding with modeled functions
functionName = makeOperationFn,
public = false,
diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt
index 81400d6597..3c40b518eb 100644
--- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt
+++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt
@@ -16,7 +16,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
+import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
@@ -77,17 +77,17 @@ class SigV4SigningDecorator : ClientCodegenDecorator {
}
class SigV4SigningConfig(
- runtimeConfig: RuntimeConfig,
+ private val runtimeConfig: RuntimeConfig,
private val serviceHasEventStream: Boolean,
private val sigV4Trait: SigV4Trait,
-) : EventStreamSigningConfig(runtimeConfig) {
- private val codegenScope = arrayOf(
- "SigV4Signer" to AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).resolve("event_stream::SigV4Signer"),
- )
-
- override fun configImplSection(): Writable {
- return writable {
- rustTemplate(
+) : ConfigCustomization() {
+ override fun section(section: ServiceConfig): Writable = writable {
+ if (section is ServiceConfig.ConfigImpl) {
+ if (serviceHasEventStream) {
+ // enable the aws-sig-auth `sign-eventstream` feature
+ addDependency(AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).toSymbol())
+ }
+ rust(
"""
/// The signature version 4 service signing name to use in the credential scope when signing requests.
///
@@ -97,24 +97,7 @@ class SigV4SigningConfig(
${sigV4Trait.name.dq()}
}
""",
- *codegenScope,
)
- if (serviceHasEventStream) {
- rustTemplate(
- "#{signerFn:W}",
- "signerFn" to
- renderEventStreamSignerFn { propertiesName ->
- writable {
- rustTemplate(
- """
- #{SigV4Signer}::new($propertiesName)
- """,
- *codegenScope,
- )
- }
- },
- )
- }
}
}
}
@@ -209,5 +192,3 @@ class SigV4SigningFeature(
}
}
}
-
-fun RuntimeConfig.sigAuth() = awsRuntimeCrate("aws-sig-auth")
diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt
similarity index 96%
rename from aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt
rename to aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt
index 9de7c91f65..71bd5eaf6c 100644
--- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningCustomizationTest.kt
+++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecoratorTest.kt
@@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
-internal class SigV4SigningCustomizationTest {
+internal class SigV4SigningDecoratorTest {
@Test
fun `generates a valid config`() {
val project = stubConfigProject(
diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt
index faa3a01c5d..08ba90d140 100644
--- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt
+++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt
@@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpAuth
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpConnectorConfigDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
-import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator
@@ -62,7 +61,6 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
FluentClientDecorator(),
EndpointsDecorator(),
EndpointParamsDecorator(),
- NoOpEventStreamSigningDecorator(),
ApiKeyAuthDecorator(),
HttpAuthDecorator(),
HttpConnectorConfigDecorator(),
diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt
deleted file mode 100644
index 7d924ee75a..0000000000
--- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/NoOpEventStreamSigningDecorator.kt
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package software.amazon.smithy.rust.codegen.client.smithy.customize
-
-import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
-import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
-import software.amazon.smithy.rust.codegen.core.rustlang.writable
-import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
-import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
-
-/**
- * The NoOpEventStreamSigningDecorator:
- * - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer
- */
-open class NoOpEventStreamSigningDecorator : ClientCodegenDecorator {
- override val name: String = "NoOpEventStreamSigning"
- override val order: Byte = Byte.MIN_VALUE
-
- private fun applies(codegenContext: CodegenContext, baseCustomizations: List): Boolean =
- codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) &&
- // and if there is no other `EventStreamSigningConfig`, apply this one
- !baseCustomizations.any { it is EventStreamSigningConfig }
-
- override fun configCustomizations(
- codegenContext: ClientCodegenContext,
- baseCustomizations: List,
- ): List {
- if (!applies(codegenContext, baseCustomizations)) {
- return baseCustomizations
- }
- return baseCustomizations + NoOpEventStreamSigningConfig(
- codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model),
- codegenContext.runtimeConfig,
- )
- }
-}
-
-class NoOpEventStreamSigningConfig(
- private val serviceHasEventStream: Boolean,
- runtimeConfig: RuntimeConfig,
-) : EventStreamSigningConfig(runtimeConfig) {
-
- private val codegenScope = arrayOf(
- "NoOpSigner" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::NoOpSigner"),
- )
-
- override fun configImplSection() = renderEventStreamSignerFn {
- writable {
- if (serviceHasEventStream) {
- rustTemplate(
- """
- #{NoOpSigner}{}
- """,
- *codegenScope,
- )
- }
- }
- }
-}
diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt
deleted file mode 100644
index 35da7d63b0..0000000000
--- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/EventStreamSigningConfig.kt
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package software.amazon.smithy.rust.codegen.client.smithy.generators.config
-
-import software.amazon.smithy.rust.codegen.core.rustlang.Writable
-import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
-import software.amazon.smithy.rust.codegen.core.rustlang.writable
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
-
-open class EventStreamSigningConfig(
- runtimeConfig: RuntimeConfig,
-) : ConfigCustomization() {
- private val codegenScope = arrayOf(
- "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
- "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
- )
-
- override fun section(section: ServiceConfig): Writable {
- return when (section) {
- is ServiceConfig.ConfigImpl -> configImplSection()
- else -> emptySection
- }
- }
-
- open fun configImplSection(): Writable = emptySection
-
- fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable = writable {
- rustTemplate(
- """
- /// Creates a new Event Stream `SignMessage` implementor.
- pub fn new_event_stream_signer(
- &self,
- _properties: #{SharedPropertyBag}
- ) -> impl #{SignMessage} {
- #{signer:W}
- }
- """,
- *codegenScope,
- "signer" to signerInstantiator("_properties"),
- )
- }
-}
diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt
index ae83e78cfe..69e0bdd480 100644
--- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt
+++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt
@@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
+import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -23,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
+import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
@@ -30,11 +32,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.outputShape
-// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
+// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` (replace with ClientProtocolGenerator)
class HttpBoundProtocolGenerator(
codegenContext: ClientCodegenContext,
protocol: Protocol,
- bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
+ bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
) : ClientProtocolGenerator(
codegenContext,
protocol,
@@ -49,6 +51,35 @@ class HttpBoundProtocolGenerator(
HttpBoundProtocolTraitImplGenerator(codegenContext, protocol),
)
+class ClientHttpBoundProtocolPayloadGenerator(
+ codegenContext: ClientCodegenContext,
+ protocol: Protocol,
+) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
+ codegenContext, protocol, HttpMessageType.REQUEST,
+ renderEventStreamBody = { writer, params ->
+ writer.rustTemplate(
+ """
+ {
+ let error_marshaller = #{errorMarshallerConstructorFn}();
+ let marshaller = #{marshallerConstructorFn}();
+ let (signer, signer_sender) = #{DeferredSigner}::new();
+ properties.acquire_mut().insert(signer_sender);
+ let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
+ ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
+ let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
+ body
+ }
+ """,
+ "hyper" to CargoDependency.HyperWithStream.toType(),
+ "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig),
+ "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
+ "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"),
+ "marshallerConstructorFn" to params.marshallerConstructorFn,
+ "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
+ )
+ },
+)
+
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
open class HttpBoundProtocolTraitImplGenerator(
codegenContext: ClientCodegenContext,
diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt
index 2cb21423fd..f9cf52375b 100644
--- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt
+++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt
@@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
-import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
@@ -52,7 +51,6 @@ class ClientCodegenVisitorTest {
ClientCustomizations(),
RequiredCustomizations(),
FluentClientDecorator(),
- NoOpEventStreamSigningDecorator(),
)
val visitor = ClientCodegenVisitor(ctx, codegenDecorator)
val baselineModel = visitor.baselineTransform(model)
diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt
index 3b5fc70a6f..73633a603e 100644
--- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt
+++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpVersionListGeneratorTest.kt
@@ -6,19 +6,9 @@
package software.amazon.smithy.rust.codegen.client.smithy.customizations
import org.junit.jupiter.api.Test
-import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
-import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
-import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
-import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
-import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
-import software.amazon.smithy.rust.codegen.core.rustlang.writable
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
-import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
@@ -172,7 +162,6 @@ internal class HttpVersionListGeneratorTest {
clientIntegrationTest(
model,
IntegrationTestParams(addModuleToEventStreamAllowList = true),
- additionalDecorators = listOf(FakeSigningDecorator()),
) { clientCodegenContext, rustCrate ->
val moduleName = clientCodegenContext.moduleUseName()
rustCrate.integrationTest("validate_eventstream_http") {
@@ -196,77 +185,3 @@ internal class HttpVersionListGeneratorTest {
}
}
}
-
-class FakeSigningDecorator : ClientCodegenDecorator {
- override val name: String = "fakesigning"
- override val order: Byte = 0
- override fun classpathDiscoverable(): Boolean = false
- override fun configCustomizations(
- codegenContext: ClientCodegenContext,
- baseCustomizations: List,
- ): List {
- return baseCustomizations.filterNot {
- it is EventStreamSigningConfig
- } + FakeSigningConfig(codegenContext.runtimeConfig)
- }
-}
-
-class FakeSigningConfig(
- runtimeConfig: RuntimeConfig,
-) : EventStreamSigningConfig(runtimeConfig) {
- private val codegenScope = arrayOf(
- "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
- "SignMessageError" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessageError"),
- "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
- "Message" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::Message"),
- )
-
- override fun section(section: ServiceConfig): Writable {
- return when (section) {
- is ServiceConfig.ConfigImpl -> writable {
- rustTemplate(
- """
- /// Creates a new Event Stream `SignMessage` implementor.
- pub fn new_event_stream_signer(
- &self,
- properties: #{SharedPropertyBag}
- ) -> FakeSigner {
- FakeSigner::new(properties)
- }
- """,
- *codegenScope,
- )
- }
-
- is ServiceConfig.Extras -> writable {
- rustTemplate(
- """
- /// Fake signing implementation.
- ##[derive(Debug)]
- pub struct FakeSigner;
-
- impl FakeSigner {
- /// Create a real `FakeSigner`
- pub fn new(_properties: #{SharedPropertyBag}) -> Self {
- Self {}
- }
- }
-
- impl #{SignMessage} for FakeSigner {
- fn sign(&mut self, message: #{Message}) -> Result<#{Message}, #{SignMessageError}> {
- Ok(message)
- }
-
- fn sign_empty(&mut self) -> Option> {
- Some(Ok(#{Message}::new(Vec::new())))
- }
- }
- """,
- *codegenScope,
- )
- }
-
- else -> emptySection
- }
- }
-}
diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt
index d0d271231a..b03f374403 100644
--- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt
+++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt
@@ -38,7 +38,7 @@ private class TestProtocolPayloadGenerator(private val body: String) : ProtocolP
override fun payloadMetadata(operationShape: OperationShape) =
ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false)
- override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
+ override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
writer.writeWithNoFormatting(body)
}
}
diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt
index 4f410bf03a..ceb385c6d1 100644
--- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt
+++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt
@@ -36,13 +36,13 @@ interface ProtocolPayloadGenerator {
/**
* Write the payload into [writer].
*
- * [self] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
+ * [shapeName] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
*
* This should be an expression that returns bytes:
* - a `Vec` for non-streaming operations; or
* - a `ByteStream` for streaming operations.
*/
- fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape)
+ fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape)
}
/**
diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt
index 6d4e7bf850..1e17beca4e 100644
--- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt
+++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt
@@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
-import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
@@ -42,10 +41,18 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape
+data class EventStreamBodyParams(
+ val outerName: String,
+ val memberName: String,
+ val marshallerConstructorFn: RuntimeType,
+ val errorMarshallerConstructorFn: RuntimeType,
+)
+
class HttpBoundProtocolPayloadGenerator(
codegenContext: CodegenContext,
private val protocol: Protocol,
private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST,
+ private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit,
) : ProtocolPayloadGenerator {
private val symbolProvider = codegenContext.symbolProvider
private val model = codegenContext.model
@@ -91,38 +98,38 @@ class HttpBoundProtocolPayloadGenerator(
}
}
- override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
+ override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
when (httpMessageType) {
- HttpMessageType.RESPONSE -> generateResponsePayload(writer, self, operationShape)
- HttpMessageType.REQUEST -> generateRequestPayload(writer, self, operationShape)
+ HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape)
+ HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape)
}
}
- private fun generateRequestPayload(writer: RustWriter, self: String, operationShape: OperationShape) {
+ private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer()
- generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape))
+ generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape))
} else {
- generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
+ generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
}
}
- private fun generateResponsePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
+ private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
if (payloadMemberName == null) {
val serializerGenerator = protocol.structuredDataSerializer()
- generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape))
+ generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape))
} else {
- generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
+ generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
}
}
private fun generatePayloadMemberSerializer(
writer: RustWriter,
- self: String,
+ shapeName: String,
operationShape: OperationShape,
payloadMemberName: String,
) {
@@ -131,7 +138,7 @@ class HttpBoundProtocolPayloadGenerator(
if (operationShape.isEventStream(model)) {
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
- writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self")
+ writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName)
} else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output")
@@ -144,16 +151,16 @@ class HttpBoundProtocolPayloadGenerator(
HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName)
HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName)
}
- writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator)
+ writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator)
}
}
- private fun generateStructureSerializer(writer: RustWriter, self: String, serializer: RuntimeType?) {
+ private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) {
if (serializer == null) {
writer.rust("\"\"")
} else {
writer.rust(
- "#T(&$self)?",
+ "#T(&$shapeName)?",
serializer,
)
}
@@ -193,47 +200,20 @@ class HttpBoundProtocolPayloadGenerator(
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
// parameters that are not `@eventHeader` or `@eventPayload`.
- when (target) {
- CodegenTarget.CLIENT ->
- rustTemplate(
- """
- {
- let error_marshaller = #{errorMarshallerConstructorFn}();
- let marshaller = #{marshallerConstructorFn}();
- let signer = _config.new_event_stream_signer(properties.clone());
- let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
- $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
- let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
- body
- }
- """,
- *codegenScope,
- "marshallerConstructorFn" to marshallerConstructorFn,
- "errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
- )
- CodegenTarget.SERVER -> {
- rustTemplate(
- """
- {
- let error_marshaller = #{errorMarshallerConstructorFn}();
- let marshaller = #{marshallerConstructorFn}();
- let signer = #{NoOpSigner}{};
- let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
- $outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
- adapter
- }
- """,
- *codegenScope,
- "marshallerConstructorFn" to marshallerConstructorFn,
- "errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
- )
- }
- }
+ renderEventStreamBody(
+ this,
+ EventStreamBodyParams(
+ outerName,
+ memberName,
+ marshallerConstructorFn,
+ errorMarshallerConstructorFn,
+ ),
+ )
}
private fun RustWriter.serializeViaPayload(
payloadMetadata: ProtocolPayloadGenerator.PayloadMetadata,
- self: String,
+ shapeName: String,
member: MemberShape,
serializerGenerator: StructuredDataSerializerGenerator,
) {
@@ -281,7 +261,7 @@ class HttpBoundProtocolPayloadGenerator(
}
}
}
- rust("#T($ref $self.${symbolProvider.toMemberName(member)})?", serializer)
+ rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer)
}
private fun RustWriter.renderPayload(
diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt
index 2530f04404..89accafd74 100644
--- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt
+++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt
@@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
+import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
@@ -48,12 +49,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
+import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
+import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
@@ -114,6 +117,31 @@ class ServerHttpBoundProtocolGenerator(
}
}
+class ServerHttpBoundProtocolPayloadGenerator(
+ codegenContext: CodegenContext,
+ protocol: Protocol,
+) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
+ codegenContext, protocol, HttpMessageType.RESPONSE,
+ renderEventStreamBody = { writer, params ->
+ writer.rustTemplate(
+ """
+ {
+ let error_marshaller = #{errorMarshallerConstructorFn}();
+ let marshaller = #{marshallerConstructorFn}();
+ let signer = #{NoOpSigner}{};
+ let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
+ ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
+ adapter
+ }
+ """,
+ "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
+ "NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"),
+ "marshallerConstructorFn" to params.marshallerConstructorFn,
+ "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
+ )
+ },
+)
+
/*
* Generate all operation input parsers and output serializers for streaming and
* non-streaming types.
@@ -504,12 +532,12 @@ class ServerHttpBoundProtocolTraitImplGenerator(
?: serverRenderHttpResponseCode(httpTraitStatusCode)(this)
operationShape.outputShape(model).findStreamingMember(model)?.let {
- val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
+ val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) {
payloadGenerator.generatePayload(this, "output", operationShape)
}
} ?: run {
- val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
+ val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
withBlockTemplate("let payload = ", ";") {
payloadGenerator.generatePayload(this, "output", operationShape)
}
diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs
index a3d9c29a00..edf48e60de 100644
--- a/rust-runtime/aws-smithy-eventstream/src/frame.rs
+++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs
@@ -14,6 +14,7 @@ use std::convert::{TryFrom, TryInto};
use std::error::Error as StdError;
use std::fmt;
use std::mem::size_of;
+use std::sync::{mpsc, Mutex};
const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::() as u32;
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
@@ -34,6 +35,95 @@ pub trait SignMessage: fmt::Debug {
fn sign_empty(&mut self) -> Option>;
}
+/// A sender that gets placed in the request config to wire up an event stream signer after signing.
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct DeferredSignerSender(Mutex>>);
+
+impl DeferredSignerSender {
+ /// Creates a new `DeferredSignerSender`
+ fn new(tx: mpsc::Sender>) -> Self {
+ Self(Mutex::new(tx))
+ }
+
+ /// Sends a signer on the channel
+ pub fn send(
+ &self,
+ signer: Box,
+ ) -> Result<(), mpsc::SendError>> {
+ self.0.lock().unwrap().send(signer)
+ }
+}
+
+/// Deferred event stream signer to allow a signer to be wired up later.
+///
+/// HTTP request signing takes place after serialization, and the event stream
+/// message stream body is established during serialization. Since event stream
+/// signing may need context from the initial HTTP signing operation, this
+/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle.
+///
+/// This signer basically just establishes a MPSC channel so that the sender can
+/// be placed in the request's config. Then the HTTP signer implementation can
+/// retrieve the sender from that config and send an actual signing implementation
+/// with all the context needed.
+///
+/// When an event stream implementation needs to sign a message, the first call to
+/// sign will acquire a signing implementation off of the channel and cache it
+/// for the remainder of the operation.
+#[derive(Debug)]
+pub struct DeferredSigner {
+ rx: Option>>>,
+ signer: Option>,
+}
+
+impl DeferredSigner {
+ pub fn new() -> (Self, DeferredSignerSender) {
+ let (tx, rx) = mpsc::channel();
+ (
+ Self {
+ rx: Some(Mutex::new(rx)),
+ signer: None,
+ },
+ DeferredSignerSender::new(tx),
+ )
+ }
+
+ fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
+ // Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough
+ if self.signer.is_some() {
+ return self.signer.as_mut().unwrap().as_mut();
+ } else {
+ self.signer = Some(
+ self.rx
+ .take()
+ .expect("only taken once")
+ .lock()
+ .unwrap()
+ .try_recv()
+ .ok()
+ // TODO(enableNewSmithyRuntime): When the middleware implementation is removed,
+ // this should panic rather than default to the `NoOpSigner`. The reason it defaults
+ // is because middleware-based generic clients don't have any default middleware,
+ // so there is no way to send a `NoOpSigner` by default when there is no other
+ // auth scheme. The orchestrator auth setup is a lot more robust and will make
+ // this problem trivial.
+ .unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
+ );
+ self.acquire()
+ }
+ }
+}
+
+impl SignMessage for DeferredSigner {
+ fn sign(&mut self, message: Message) -> Result {
+ self.acquire().sign(message)
+ }
+
+ fn sign_empty(&mut self) -> Option> {
+ self.acquire().sign_empty()
+ }
+}
+
#[derive(Debug)]
pub struct NoOpSigner {}
impl SignMessage for NoOpSigner {
@@ -848,3 +938,60 @@ mod message_frame_decoder_tests {
}
}
}
+
+#[cfg(test)]
+mod deferred_signer_tests {
+ use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
+ use bytes::Bytes;
+
+ fn check_send_sync(value: T) -> T {
+ value
+ }
+
+ #[test]
+ fn deferred_signer() {
+ #[derive(Default, Debug)]
+ struct TestSigner {
+ call_num: i32,
+ }
+ impl SignMessage for TestSigner {
+ fn sign(
+ &mut self,
+ message: crate::frame::Message,
+ ) -> Result {
+ self.call_num += 1;
+ Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num))))
+ }
+
+ fn sign_empty(
+ &mut self,
+ ) -> Option> {
+ None
+ }
+ }
+
+ let (mut signer, sender) = check_send_sync(DeferredSigner::new());
+
+ sender
+ .send(Box::new(TestSigner::default()))
+ .expect("success");
+
+ let message = signer.sign(Message::new(Bytes::new())).expect("success");
+ assert_eq!(1, message.headers()[0].value().as_int32().unwrap());
+
+ let message = signer.sign(Message::new(Bytes::new())).expect("success");
+ assert_eq!(2, message.headers()[0].value().as_int32().unwrap());
+
+ assert!(signer.sign_empty().is_none());
+ }
+
+ #[test]
+ fn deferred_signer_defaults_to_noop_signer() {
+ let (mut signer, _sender) = DeferredSigner::new();
+ assert_eq!(
+ Message::new(Bytes::new()),
+ signer.sign(Message::new(Bytes::new())).unwrap()
+ );
+ assert!(signer.sign_empty().is_none());
+ }
+}