mirror of https://github.com/smithy-lang/smithy-rs
Simplify event stream message signer configuration (#2671)
## Motivation and Context This PR creates a `DeferredSigner` implementation that allows for the event stream message signer to be wired up by the signing implementation later in the request lifecycle rather than by adding an event stream signer method to the config. Refactoring this brings the middleware client implementation closer to how the orchestrator implementation will work, which unblocks the work required to make event streams work in the orchestrator. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
This commit is contained in:
parent
d083c6f271
commit
8773a70428
|
@ -34,3 +34,19 @@ message = "`SsoCredentialsProvider`, `AssumeRoleProvider`, and `WebIdentityToken
|
|||
references = ["smithy-rs#2720"]
|
||||
meta = { "breaking" = false, "tada" = false, "bug" = true }
|
||||
author = "ysaito1001"
|
||||
|
||||
[[smithy-rs]]
|
||||
message = """
|
||||
<details>
|
||||
<summary>Breaking change in how event stream signing works (click to expand more details)</summary>
|
||||
|
||||
This change will only impact you if you are wiring up their own event stream signing/authentication scheme. If you're using `aws-sig-auth` to use AWS SigV4 event stream signing, then this change will **not** impact you.
|
||||
|
||||
Previously, event stream signing was configured at codegen time by placing a `new_event_stream_signer` method on the `Config`. This function was called at serialization time to connect the signer to the streaming body. Now, instead, a special `DeferredSigner` is wired up at serialization time that relies on a signing implementation to be sent on a channel by the HTTP request signer. To do this, a `DeferredSignerSender` must be pulled out of the property bag, and its `send()` method called with the desired event stream signing implementation.
|
||||
|
||||
See the changes in https://github.com/awslabs/smithy-rs/pull/2671 for an example of how this was done for SigV4.
|
||||
</details>
|
||||
"""
|
||||
references = ["smithy-rs#2671"]
|
||||
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
|
||||
author = "jdisanti"
|
||||
|
|
|
@ -3,6 +3,9 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Remove this blanket allow once the old implementations are deleted
|
||||
#![allow(deprecated)]
|
||||
|
||||
use crate::middleware::Signature;
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sigv4::event_stream::{sign_empty_message, sign_message};
|
||||
|
@ -15,6 +18,115 @@ use std::time::SystemTime;
|
|||
|
||||
/// Event Stream SigV4 signing implementation.
|
||||
#[derive(Debug)]
|
||||
pub struct SigV4MessageSigner {
|
||||
last_signature: String,
|
||||
credentials: Credentials,
|
||||
signing_region: SigningRegion,
|
||||
signing_service: SigningService,
|
||||
time: Option<SystemTime>,
|
||||
}
|
||||
|
||||
impl SigV4MessageSigner {
|
||||
pub fn new(
|
||||
last_signature: String,
|
||||
credentials: Credentials,
|
||||
signing_region: SigningRegion,
|
||||
signing_service: SigningService,
|
||||
time: Option<SystemTime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
last_signature,
|
||||
credentials,
|
||||
signing_region,
|
||||
signing_service,
|
||||
time,
|
||||
}
|
||||
}
|
||||
|
||||
fn signing_params(&self) -> SigningParams<()> {
|
||||
let mut builder = SigningParams::builder()
|
||||
.access_key(self.credentials.access_key_id())
|
||||
.secret_key(self.credentials.secret_access_key())
|
||||
.region(self.signing_region.as_ref())
|
||||
.service_name(self.signing_service.as_ref())
|
||||
.time(self.time.unwrap_or_else(SystemTime::now))
|
||||
.settings(());
|
||||
builder.set_security_token(self.credentials.session_token());
|
||||
builder.build().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl SignMessage for SigV4MessageSigner {
|
||||
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
|
||||
let (signed_message, signature) = {
|
||||
let params = self.signing_params();
|
||||
sign_message(&message, &self.last_signature, ¶ms).into_parts()
|
||||
};
|
||||
self.last_signature = signature;
|
||||
Ok(signed_message)
|
||||
}
|
||||
|
||||
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
|
||||
let (signed_message, signature) = {
|
||||
let params = self.signing_params();
|
||||
sign_empty_message(&self.last_signature, ¶ms).into_parts()
|
||||
};
|
||||
self.last_signature = signature;
|
||||
Some(Ok(signed_message))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::event_stream::SigV4MessageSigner;
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage};
|
||||
use aws_types::region::Region;
|
||||
use aws_types::region::SigningRegion;
|
||||
use aws_types::SigningService;
|
||||
use std::time::{Duration, UNIX_EPOCH};
|
||||
|
||||
fn check_send_sync<T: Send + Sync>(value: T) -> T {
|
||||
value
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sign_message() {
|
||||
let region = Region::new("us-east-1");
|
||||
let mut signer = check_send_sync(SigV4MessageSigner::new(
|
||||
"initial-signature".into(),
|
||||
Credentials::for_tests(),
|
||||
SigningRegion::from(region),
|
||||
SigningService::from_static("transcribe"),
|
||||
Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
|
||||
));
|
||||
let mut signatures = Vec::new();
|
||||
for _ in 0..5 {
|
||||
let signed = signer
|
||||
.sign(Message::new(&b"identical message"[..]))
|
||||
.unwrap();
|
||||
if let HeaderValue::ByteArray(signature) = signed
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":chunk-signature")
|
||||
.unwrap()
|
||||
.value()
|
||||
{
|
||||
signatures.push(signature.clone());
|
||||
} else {
|
||||
panic!("failed to get the :chunk-signature")
|
||||
}
|
||||
}
|
||||
for i in 1..signatures.len() {
|
||||
assert_ne!(signatures[i - 1], signatures[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
|
||||
#[deprecated = "use aws_sig_auth::event_stream::SigV4MessageSigner instead (this may require upgrading the smithy-rs code generator)"]
|
||||
#[derive(Debug)]
|
||||
/// Event Stream SigV4 signing implementation.
|
||||
pub struct SigV4Signer {
|
||||
properties: SharedPropertyBag,
|
||||
last_signature: Option<String>,
|
||||
|
@ -87,8 +199,9 @@ impl SignMessage for SigV4Signer {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Delete this old implementation that was kept around to support patch releases.
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod old_tests {
|
||||
use crate::event_stream::SigV4Signer;
|
||||
use crate::middleware::Signature;
|
||||
use aws_credential_types::Credentials;
|
||||
|
|
|
@ -20,8 +20,15 @@ use crate::signer::{
|
|||
OperationSigningConfig, RequestConfig, SigV4Signer, SigningError, SigningRequirements,
|
||||
};
|
||||
|
||||
#[cfg(feature = "sign-eventstream")]
|
||||
use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
|
||||
#[cfg(feature = "sign-eventstream")]
|
||||
use aws_smithy_eventstream::frame::DeferredSignerSender;
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Delete `Signature` when switching to the orchestrator
|
||||
/// Container for the request signature for use in the property bag.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Signature(String);
|
||||
|
||||
impl Signature {
|
||||
|
@ -181,6 +188,22 @@ impl MapRequest for SigV4SigningStage {
|
|||
.signer
|
||||
.sign(operation_config, &request_config, &creds, &mut req)
|
||||
.map_err(SigningStageErrorKind::SigningFailure)?;
|
||||
|
||||
// If this is an event stream operation, set up the event stream signer
|
||||
#[cfg(feature = "sign-eventstream")]
|
||||
if let Some(signer_sender) = config.get::<DeferredSignerSender>() {
|
||||
let time_override = config.get::<SystemTime>().copied();
|
||||
signer_sender
|
||||
.send(Box::new(EventStreamSigV4Signer::new(
|
||||
signature.as_ref().into(),
|
||||
creds,
|
||||
request_config.region.clone(),
|
||||
request_config.service.clone(),
|
||||
time_override,
|
||||
)) as _)
|
||||
.expect("failed to send deferred signer");
|
||||
}
|
||||
|
||||
config.insert(signature);
|
||||
Ok(req)
|
||||
})
|
||||
|
@ -234,6 +257,49 @@ mod test {
|
|||
assert!(signature.is_some());
|
||||
}
|
||||
|
||||
#[cfg(feature = "sign-eventstream")]
|
||||
#[test]
|
||||
fn sends_event_stream_signer_for_event_stream_operations() {
|
||||
use crate::event_stream::SigV4MessageSigner as EventStreamSigV4Signer;
|
||||
use aws_smithy_eventstream::frame::{DeferredSigner, SignMessage};
|
||||
use std::time::SystemTime;
|
||||
|
||||
let (mut deferred_signer, deferred_signer_sender) = DeferredSigner::new();
|
||||
let req = http::Request::builder()
|
||||
.uri("https://test-service.test-region.amazonaws.com/")
|
||||
.body(SdkBody::from(""))
|
||||
.unwrap();
|
||||
let region = Region::new("us-east-1");
|
||||
let req = operation::Request::new(req)
|
||||
.augment(|req, properties| {
|
||||
properties.insert(region.clone());
|
||||
properties.insert::<SystemTime>(UNIX_EPOCH + Duration::new(1611160427, 0));
|
||||
properties.insert(SigningService::from_static("kinesis"));
|
||||
properties.insert(OperationSigningConfig::default_config());
|
||||
properties.insert(Credentials::for_tests());
|
||||
properties.insert(SigningRegion::from(region.clone()));
|
||||
properties.insert(deferred_signer_sender);
|
||||
Result::<_, Infallible>::Ok(req)
|
||||
})
|
||||
.expect("succeeds");
|
||||
|
||||
let signer = SigV4SigningStage::new(SigV4Signer::new());
|
||||
let _ = signer.apply(req).unwrap();
|
||||
|
||||
let mut signer_for_comparison = EventStreamSigV4Signer::new(
|
||||
// This is the expected SigV4 signature for the HTTP request above
|
||||
"abac477b4afabf5651079e7b9a0aa6a1a3e356a7418a81d974cdae9d4c8e5441".into(),
|
||||
Credentials::for_tests(),
|
||||
SigningRegion::from(region),
|
||||
SigningService::from_static("kinesis"),
|
||||
Some(UNIX_EPOCH + Duration::new(1611160427, 0)),
|
||||
);
|
||||
|
||||
let expected_signed_empty = signer_for_comparison.sign_empty().unwrap().unwrap();
|
||||
let actual_signed_empty = deferred_signer.sign_empty().unwrap().unwrap();
|
||||
assert_eq!(expected_signed_empty, actual_signed_empty);
|
||||
}
|
||||
|
||||
// check that the endpoint middleware followed by signing middleware produce the expected result
|
||||
#[test]
|
||||
fn endpoint_plus_signer() {
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
use crate::middleware::Signature;
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sigv4::http_request::{
|
||||
sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableRequest,
|
||||
|
@ -15,6 +14,7 @@ use aws_types::SigningService;
|
|||
use std::fmt;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use crate::middleware::Signature;
|
||||
pub use aws_sigv4::http_request::SignableBody;
|
||||
pub type SigningError = aws_sigv4::http_request::SigningError;
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegen
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.docs
|
||||
|
@ -34,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
|
||||
import software.amazon.smithy.rust.codegen.core.util.expectTrait
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasTrait
|
||||
|
@ -179,7 +179,7 @@ class AwsInputPresignedMethod(
|
|||
MakeOperationGenerator(
|
||||
codegenContext,
|
||||
protocol,
|
||||
HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
|
||||
ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
|
||||
// Prefixed with underscore to avoid colliding with modeled functions
|
||||
functionName = makeOperationFn,
|
||||
public = false,
|
||||
|
|
|
@ -16,7 +16,7 @@ import software.amazon.smithy.model.traits.OptionalAuthTrait
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
|
@ -77,17 +77,17 @@ class SigV4SigningDecorator : ClientCodegenDecorator {
|
|||
}
|
||||
|
||||
class SigV4SigningConfig(
|
||||
runtimeConfig: RuntimeConfig,
|
||||
private val runtimeConfig: RuntimeConfig,
|
||||
private val serviceHasEventStream: Boolean,
|
||||
private val sigV4Trait: SigV4Trait,
|
||||
) : EventStreamSigningConfig(runtimeConfig) {
|
||||
private val codegenScope = arrayOf(
|
||||
"SigV4Signer" to AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).resolve("event_stream::SigV4Signer"),
|
||||
)
|
||||
|
||||
override fun configImplSection(): Writable {
|
||||
return writable {
|
||||
rustTemplate(
|
||||
) : ConfigCustomization() {
|
||||
override fun section(section: ServiceConfig): Writable = writable {
|
||||
if (section is ServiceConfig.ConfigImpl) {
|
||||
if (serviceHasEventStream) {
|
||||
// enable the aws-sig-auth `sign-eventstream` feature
|
||||
addDependency(AwsRuntimeType.awsSigAuthEventStream(runtimeConfig).toSymbol())
|
||||
}
|
||||
rust(
|
||||
"""
|
||||
/// The signature version 4 service signing name to use in the credential scope when signing requests.
|
||||
///
|
||||
|
@ -97,24 +97,7 @@ class SigV4SigningConfig(
|
|||
${sigV4Trait.name.dq()}
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
if (serviceHasEventStream) {
|
||||
rustTemplate(
|
||||
"#{signerFn:W}",
|
||||
"signerFn" to
|
||||
renderEventStreamSignerFn { propertiesName ->
|
||||
writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
#{SigV4Signer}::new($propertiesName)
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -209,5 +192,3 @@ class SigV4SigningFeature(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun RuntimeConfig.sigAuth() = awsRuntimeCrate("aws-sig-auth")
|
||||
|
|
|
@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
|
|||
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
|
||||
|
||||
internal class SigV4SigningCustomizationTest {
|
||||
internal class SigV4SigningDecoratorTest {
|
||||
@Test
|
||||
fun `generates a valid config`() {
|
||||
val project = stubConfigProject(
|
|
@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpAuth
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpConnectorConfigDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator
|
||||
|
@ -62,7 +61,6 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
|
|||
FluentClientDecorator(),
|
||||
EndpointsDecorator(),
|
||||
EndpointParamsDecorator(),
|
||||
NoOpEventStreamSigningDecorator(),
|
||||
ApiKeyAuthDecorator(),
|
||||
HttpAuthDecorator(),
|
||||
HttpConnectorConfigDecorator(),
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.client.smithy.customize
|
||||
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
|
||||
|
||||
/**
|
||||
* The NoOpEventStreamSigningDecorator:
|
||||
* - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer
|
||||
*/
|
||||
open class NoOpEventStreamSigningDecorator : ClientCodegenDecorator {
|
||||
override val name: String = "NoOpEventStreamSigning"
|
||||
override val order: Byte = Byte.MIN_VALUE
|
||||
|
||||
private fun applies(codegenContext: CodegenContext, baseCustomizations: List<ConfigCustomization>): Boolean =
|
||||
codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) &&
|
||||
// and if there is no other `EventStreamSigningConfig`, apply this one
|
||||
!baseCustomizations.any { it is EventStreamSigningConfig }
|
||||
|
||||
override fun configCustomizations(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseCustomizations: List<ConfigCustomization>,
|
||||
): List<ConfigCustomization> {
|
||||
if (!applies(codegenContext, baseCustomizations)) {
|
||||
return baseCustomizations
|
||||
}
|
||||
return baseCustomizations + NoOpEventStreamSigningConfig(
|
||||
codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model),
|
||||
codegenContext.runtimeConfig,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class NoOpEventStreamSigningConfig(
|
||||
private val serviceHasEventStream: Boolean,
|
||||
runtimeConfig: RuntimeConfig,
|
||||
) : EventStreamSigningConfig(runtimeConfig) {
|
||||
|
||||
private val codegenScope = arrayOf(
|
||||
"NoOpSigner" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::NoOpSigner"),
|
||||
)
|
||||
|
||||
override fun configImplSection() = renderEventStreamSignerFn {
|
||||
writable {
|
||||
if (serviceHasEventStream) {
|
||||
rustTemplate(
|
||||
"""
|
||||
#{NoOpSigner}{}
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
/*
|
||||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package software.amazon.smithy.rust.codegen.client.smithy.generators.config
|
||||
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
|
||||
open class EventStreamSigningConfig(
|
||||
runtimeConfig: RuntimeConfig,
|
||||
) : ConfigCustomization() {
|
||||
private val codegenScope = arrayOf(
|
||||
"SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
|
||||
"SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
|
||||
)
|
||||
|
||||
override fun section(section: ServiceConfig): Writable {
|
||||
return when (section) {
|
||||
is ServiceConfig.ConfigImpl -> configImplSection()
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
|
||||
open fun configImplSection(): Writable = emptySection
|
||||
|
||||
fun renderEventStreamSignerFn(signerInstantiator: (String) -> Writable): Writable = writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
/// Creates a new Event Stream `SignMessage` implementor.
|
||||
pub fn new_event_stream_signer(
|
||||
&self,
|
||||
_properties: #{SharedPropertyBag}
|
||||
) -> impl #{SignMessage} {
|
||||
#{signer:W}
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
"signer" to signerInstantiator("_properties"),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli
|
|||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
|
||||
|
@ -23,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
|
@ -30,11 +32,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctio
|
|||
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
|
||||
import software.amazon.smithy.rust.codegen.core.util.outputShape
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
|
||||
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` (replace with ClientProtocolGenerator)
|
||||
class HttpBoundProtocolGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
protocol: Protocol,
|
||||
bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
|
||||
bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator(codegenContext, protocol),
|
||||
) : ClientProtocolGenerator(
|
||||
codegenContext,
|
||||
protocol,
|
||||
|
@ -49,6 +51,35 @@ class HttpBoundProtocolGenerator(
|
|||
HttpBoundProtocolTraitImplGenerator(codegenContext, protocol),
|
||||
)
|
||||
|
||||
class ClientHttpBoundProtocolPayloadGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
protocol: Protocol,
|
||||
) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
|
||||
codegenContext, protocol, HttpMessageType.REQUEST,
|
||||
renderEventStreamBody = { writer, params ->
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
{
|
||||
let error_marshaller = #{errorMarshallerConstructorFn}();
|
||||
let marshaller = #{marshallerConstructorFn}();
|
||||
let (signer, signer_sender) = #{DeferredSigner}::new();
|
||||
properties.acquire_mut().insert(signer_sender);
|
||||
let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
|
||||
${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
|
||||
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
|
||||
body
|
||||
}
|
||||
""",
|
||||
"hyper" to CargoDependency.HyperWithStream.toType(),
|
||||
"SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig),
|
||||
"aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
|
||||
"DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"),
|
||||
"marshallerConstructorFn" to params.marshallerConstructorFn,
|
||||
"errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime`
|
||||
open class HttpBoundProtocolTraitImplGenerator(
|
||||
codegenContext: ClientCodegenContext,
|
||||
|
|
|
@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test
|
|||
import software.amazon.smithy.model.shapes.ShapeId
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStreamSigningDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
|
||||
|
@ -52,7 +51,6 @@ class ClientCodegenVisitorTest {
|
|||
ClientCustomizations(),
|
||||
RequiredCustomizations(),
|
||||
FluentClientDecorator(),
|
||||
NoOpEventStreamSigningDecorator(),
|
||||
)
|
||||
val visitor = ClientCodegenVisitor(ctx, codegenDecorator)
|
||||
val baselineModel = visitor.baselineTransform(model)
|
||||
|
|
|
@ -6,19 +6,9 @@
|
|||
package software.amazon.smithy.rust.codegen.client.smithy.customizations
|
||||
|
||||
import org.junit.jupiter.api.Test
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.EventStreamSigningConfig
|
||||
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
|
||||
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
|
||||
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
|
||||
|
@ -172,7 +162,6 @@ internal class HttpVersionListGeneratorTest {
|
|||
clientIntegrationTest(
|
||||
model,
|
||||
IntegrationTestParams(addModuleToEventStreamAllowList = true),
|
||||
additionalDecorators = listOf(FakeSigningDecorator()),
|
||||
) { clientCodegenContext, rustCrate ->
|
||||
val moduleName = clientCodegenContext.moduleUseName()
|
||||
rustCrate.integrationTest("validate_eventstream_http") {
|
||||
|
@ -196,77 +185,3 @@ internal class HttpVersionListGeneratorTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
class FakeSigningDecorator : ClientCodegenDecorator {
|
||||
override val name: String = "fakesigning"
|
||||
override val order: Byte = 0
|
||||
override fun classpathDiscoverable(): Boolean = false
|
||||
override fun configCustomizations(
|
||||
codegenContext: ClientCodegenContext,
|
||||
baseCustomizations: List<ConfigCustomization>,
|
||||
): List<ConfigCustomization> {
|
||||
return baseCustomizations.filterNot {
|
||||
it is EventStreamSigningConfig
|
||||
} + FakeSigningConfig(codegenContext.runtimeConfig)
|
||||
}
|
||||
}
|
||||
|
||||
class FakeSigningConfig(
|
||||
runtimeConfig: RuntimeConfig,
|
||||
) : EventStreamSigningConfig(runtimeConfig) {
|
||||
private val codegenScope = arrayOf(
|
||||
"SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"),
|
||||
"SignMessageError" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessageError"),
|
||||
"SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"),
|
||||
"Message" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::Message"),
|
||||
)
|
||||
|
||||
override fun section(section: ServiceConfig): Writable {
|
||||
return when (section) {
|
||||
is ServiceConfig.ConfigImpl -> writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
/// Creates a new Event Stream `SignMessage` implementor.
|
||||
pub fn new_event_stream_signer(
|
||||
&self,
|
||||
properties: #{SharedPropertyBag}
|
||||
) -> FakeSigner {
|
||||
FakeSigner::new(properties)
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
is ServiceConfig.Extras -> writable {
|
||||
rustTemplate(
|
||||
"""
|
||||
/// Fake signing implementation.
|
||||
##[derive(Debug)]
|
||||
pub struct FakeSigner;
|
||||
|
||||
impl FakeSigner {
|
||||
/// Create a real `FakeSigner`
|
||||
pub fn new(_properties: #{SharedPropertyBag}) -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl #{SignMessage} for FakeSigner {
|
||||
fn sign(&mut self, message: #{Message}) -> Result<#{Message}, #{SignMessageError}> {
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
fn sign_empty(&mut self) -> Option<Result<#{Message}, #{SignMessageError}>> {
|
||||
Some(Ok(#{Message}::new(Vec::new())))
|
||||
}
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
)
|
||||
}
|
||||
|
||||
else -> emptySection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ private class TestProtocolPayloadGenerator(private val body: String) : ProtocolP
|
|||
override fun payloadMetadata(operationShape: OperationShape) =
|
||||
ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false)
|
||||
|
||||
override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
|
||||
override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
|
||||
writer.writeWithNoFormatting(body)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,13 +36,13 @@ interface ProtocolPayloadGenerator {
|
|||
/**
|
||||
* Write the payload into [writer].
|
||||
*
|
||||
* [self] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
|
||||
* [shapeName] is the name of the variable binding for the Rust struct that is to be serialized into the payload.
|
||||
*
|
||||
* This should be an expression that returns bytes:
|
||||
* - a `Vec<u8>` for non-streaming operations; or
|
||||
* - a `ByteStream` for streaming operations.
|
||||
*/
|
||||
fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape)
|
||||
fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rust
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
|
@ -42,10 +41,18 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream
|
|||
import software.amazon.smithy.rust.codegen.core.util.isStreaming
|
||||
import software.amazon.smithy.rust.codegen.core.util.outputShape
|
||||
|
||||
data class EventStreamBodyParams(
|
||||
val outerName: String,
|
||||
val memberName: String,
|
||||
val marshallerConstructorFn: RuntimeType,
|
||||
val errorMarshallerConstructorFn: RuntimeType,
|
||||
)
|
||||
|
||||
class HttpBoundProtocolPayloadGenerator(
|
||||
codegenContext: CodegenContext,
|
||||
private val protocol: Protocol,
|
||||
private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST,
|
||||
private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit,
|
||||
) : ProtocolPayloadGenerator {
|
||||
private val symbolProvider = codegenContext.symbolProvider
|
||||
private val model = codegenContext.model
|
||||
|
@ -91,38 +98,38 @@ class HttpBoundProtocolPayloadGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
override fun generatePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
|
||||
override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
|
||||
when (httpMessageType) {
|
||||
HttpMessageType.RESPONSE -> generateResponsePayload(writer, self, operationShape)
|
||||
HttpMessageType.REQUEST -> generateRequestPayload(writer, self, operationShape)
|
||||
HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape)
|
||||
HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape)
|
||||
}
|
||||
}
|
||||
|
||||
private fun generateRequestPayload(writer: RustWriter, self: String, operationShape: OperationShape) {
|
||||
private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
|
||||
val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
|
||||
|
||||
if (payloadMemberName == null) {
|
||||
val serializerGenerator = protocol.structuredDataSerializer()
|
||||
generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape))
|
||||
generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape))
|
||||
} else {
|
||||
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
|
||||
generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
|
||||
}
|
||||
}
|
||||
|
||||
private fun generateResponsePayload(writer: RustWriter, self: String, operationShape: OperationShape) {
|
||||
private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) {
|
||||
val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
|
||||
|
||||
if (payloadMemberName == null) {
|
||||
val serializerGenerator = protocol.structuredDataSerializer()
|
||||
generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape))
|
||||
generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape))
|
||||
} else {
|
||||
generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName)
|
||||
generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName)
|
||||
}
|
||||
}
|
||||
|
||||
private fun generatePayloadMemberSerializer(
|
||||
writer: RustWriter,
|
||||
self: String,
|
||||
shapeName: String,
|
||||
operationShape: OperationShape,
|
||||
payloadMemberName: String,
|
||||
) {
|
||||
|
@ -131,7 +138,7 @@ class HttpBoundProtocolPayloadGenerator(
|
|||
if (operationShape.isEventStream(model)) {
|
||||
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
|
||||
val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
|
||||
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self")
|
||||
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName)
|
||||
} else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
|
||||
val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
|
||||
writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output")
|
||||
|
@ -144,16 +151,16 @@ class HttpBoundProtocolPayloadGenerator(
|
|||
HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName)
|
||||
HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName)
|
||||
}
|
||||
writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator)
|
||||
writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator)
|
||||
}
|
||||
}
|
||||
|
||||
private fun generateStructureSerializer(writer: RustWriter, self: String, serializer: RuntimeType?) {
|
||||
private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) {
|
||||
if (serializer == null) {
|
||||
writer.rust("\"\"")
|
||||
} else {
|
||||
writer.rust(
|
||||
"#T(&$self)?",
|
||||
"#T(&$shapeName)?",
|
||||
serializer,
|
||||
)
|
||||
}
|
||||
|
@ -193,47 +200,20 @@ class HttpBoundProtocolPayloadGenerator(
|
|||
|
||||
// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
|
||||
// parameters that are not `@eventHeader` or `@eventPayload`.
|
||||
when (target) {
|
||||
CodegenTarget.CLIENT ->
|
||||
rustTemplate(
|
||||
"""
|
||||
{
|
||||
let error_marshaller = #{errorMarshallerConstructorFn}();
|
||||
let marshaller = #{marshallerConstructorFn}();
|
||||
let signer = _config.new_event_stream_signer(properties.clone());
|
||||
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
|
||||
$outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
|
||||
let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into();
|
||||
body
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
"marshallerConstructorFn" to marshallerConstructorFn,
|
||||
"errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
|
||||
)
|
||||
CodegenTarget.SERVER -> {
|
||||
rustTemplate(
|
||||
"""
|
||||
{
|
||||
let error_marshaller = #{errorMarshallerConstructorFn}();
|
||||
let marshaller = #{marshallerConstructorFn}();
|
||||
let signer = #{NoOpSigner}{};
|
||||
let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, _> =
|
||||
$outerName.$memberName.into_body_stream(marshaller, error_marshaller, signer);
|
||||
adapter
|
||||
}
|
||||
""",
|
||||
*codegenScope,
|
||||
"marshallerConstructorFn" to marshallerConstructorFn,
|
||||
"errorMarshallerConstructorFn" to errorMarshallerConstructorFn,
|
||||
)
|
||||
}
|
||||
}
|
||||
renderEventStreamBody(
|
||||
this,
|
||||
EventStreamBodyParams(
|
||||
outerName,
|
||||
memberName,
|
||||
marshallerConstructorFn,
|
||||
errorMarshallerConstructorFn,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
private fun RustWriter.serializeViaPayload(
|
||||
payloadMetadata: ProtocolPayloadGenerator.PayloadMetadata,
|
||||
self: String,
|
||||
shapeName: String,
|
||||
member: MemberShape,
|
||||
serializerGenerator: StructuredDataSerializerGenerator,
|
||||
) {
|
||||
|
@ -281,7 +261,7 @@ class HttpBoundProtocolPayloadGenerator(
|
|||
}
|
||||
}
|
||||
}
|
||||
rust("#T($ref $self.${symbolProvider.toMemberName(member)})?", serializer)
|
||||
rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer)
|
||||
}
|
||||
|
||||
private fun RustWriter.renderPayload(
|
||||
|
|
|
@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
|
|||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
|
||||
import software.amazon.smithy.rust.codegen.core.rustlang.writable
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
|
||||
|
@ -48,12 +49,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom
|
|||
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
|
||||
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
|
||||
|
@ -114,6 +117,31 @@ class ServerHttpBoundProtocolGenerator(
|
|||
}
|
||||
}
|
||||
|
||||
class ServerHttpBoundProtocolPayloadGenerator(
|
||||
codegenContext: CodegenContext,
|
||||
protocol: Protocol,
|
||||
) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator(
|
||||
codegenContext, protocol, HttpMessageType.RESPONSE,
|
||||
renderEventStreamBody = { writer, params ->
|
||||
writer.rustTemplate(
|
||||
"""
|
||||
{
|
||||
let error_marshaller = #{errorMarshallerConstructorFn}();
|
||||
let marshaller = #{marshallerConstructorFn}();
|
||||
let signer = #{NoOpSigner}{};
|
||||
let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> =
|
||||
${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer);
|
||||
adapter
|
||||
}
|
||||
""",
|
||||
"aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig),
|
||||
"NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"),
|
||||
"marshallerConstructorFn" to params.marshallerConstructorFn,
|
||||
"errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
/*
|
||||
* Generate all operation input parsers and output serializers for streaming and
|
||||
* non-streaming types.
|
||||
|
@ -504,12 +532,12 @@ class ServerHttpBoundProtocolTraitImplGenerator(
|
|||
?: serverRenderHttpResponseCode(httpTraitStatusCode)(this)
|
||||
|
||||
operationShape.outputShape(model).findStreamingMember(model)?.let {
|
||||
val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
|
||||
val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
|
||||
withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) {
|
||||
payloadGenerator.generatePayload(this, "output", operationShape)
|
||||
}
|
||||
} ?: run {
|
||||
val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
|
||||
val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol)
|
||||
withBlockTemplate("let payload = ", ";") {
|
||||
payloadGenerator.generatePayload(this, "output", operationShape)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ use std::convert::{TryFrom, TryInto};
|
|||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
use std::mem::size_of;
|
||||
use std::sync::{mpsc, Mutex};
|
||||
|
||||
const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::<u32>() as u32;
|
||||
const PRELUDE_LENGTH_BYTES_USIZE: usize = PRELUDE_LENGTH_BYTES as usize;
|
||||
|
@ -34,6 +35,95 @@ pub trait SignMessage: fmt::Debug {
|
|||
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>>;
|
||||
}
|
||||
|
||||
/// A sender that gets placed in the request config to wire up an event stream signer after signing.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct DeferredSignerSender(Mutex<mpsc::Sender<Box<dyn SignMessage + Send + Sync>>>);
|
||||
|
||||
impl DeferredSignerSender {
|
||||
/// Creates a new `DeferredSignerSender`
|
||||
fn new(tx: mpsc::Sender<Box<dyn SignMessage + Send + Sync>>) -> Self {
|
||||
Self(Mutex::new(tx))
|
||||
}
|
||||
|
||||
/// Sends a signer on the channel
|
||||
pub fn send(
|
||||
&self,
|
||||
signer: Box<dyn SignMessage + Send + Sync>,
|
||||
) -> Result<(), mpsc::SendError<Box<dyn SignMessage + Send + Sync>>> {
|
||||
self.0.lock().unwrap().send(signer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Deferred event stream signer to allow a signer to be wired up later.
|
||||
///
|
||||
/// HTTP request signing takes place after serialization, and the event stream
|
||||
/// message stream body is established during serialization. Since event stream
|
||||
/// signing may need context from the initial HTTP signing operation, this
|
||||
/// [`DeferredSigner`] is needed to wire up the signer later in the request lifecycle.
|
||||
///
|
||||
/// This signer basically just establishes a MPSC channel so that the sender can
|
||||
/// be placed in the request's config. Then the HTTP signer implementation can
|
||||
/// retrieve the sender from that config and send an actual signing implementation
|
||||
/// with all the context needed.
|
||||
///
|
||||
/// When an event stream implementation needs to sign a message, the first call to
|
||||
/// sign will acquire a signing implementation off of the channel and cache it
|
||||
/// for the remainder of the operation.
|
||||
#[derive(Debug)]
|
||||
pub struct DeferredSigner {
|
||||
rx: Option<Mutex<mpsc::Receiver<Box<dyn SignMessage + Send + Sync>>>>,
|
||||
signer: Option<Box<dyn SignMessage + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl DeferredSigner {
|
||||
pub fn new() -> (Self, DeferredSignerSender) {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
(
|
||||
Self {
|
||||
rx: Some(Mutex::new(rx)),
|
||||
signer: None,
|
||||
},
|
||||
DeferredSignerSender::new(tx),
|
||||
)
|
||||
}
|
||||
|
||||
fn acquire(&mut self) -> &mut (dyn SignMessage + Send + Sync) {
|
||||
// Can't use `if let Some(signer) = &mut self.signer` because the borrow checker isn't smart enough
|
||||
if self.signer.is_some() {
|
||||
return self.signer.as_mut().unwrap().as_mut();
|
||||
} else {
|
||||
self.signer = Some(
|
||||
self.rx
|
||||
.take()
|
||||
.expect("only taken once")
|
||||
.lock()
|
||||
.unwrap()
|
||||
.try_recv()
|
||||
.ok()
|
||||
// TODO(enableNewSmithyRuntime): When the middleware implementation is removed,
|
||||
// this should panic rather than default to the `NoOpSigner`. The reason it defaults
|
||||
// is because middleware-based generic clients don't have any default middleware,
|
||||
// so there is no way to send a `NoOpSigner` by default when there is no other
|
||||
// auth scheme. The orchestrator auth setup is a lot more robust and will make
|
||||
// this problem trivial.
|
||||
.unwrap_or_else(|| Box::new(NoOpSigner {}) as _),
|
||||
);
|
||||
self.acquire()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SignMessage for DeferredSigner {
|
||||
fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
|
||||
self.acquire().sign(message)
|
||||
}
|
||||
|
||||
fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
|
||||
self.acquire().sign_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NoOpSigner {}
|
||||
impl SignMessage for NoOpSigner {
|
||||
|
@ -848,3 +938,60 @@ mod message_frame_decoder_tests {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod deferred_signer_tests {
|
||||
use crate::frame::{DeferredSigner, Header, HeaderValue, Message, SignMessage};
|
||||
use bytes::Bytes;
|
||||
|
||||
fn check_send_sync<T: Send + Sync>(value: T) -> T {
|
||||
value
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deferred_signer() {
|
||||
#[derive(Default, Debug)]
|
||||
struct TestSigner {
|
||||
call_num: i32,
|
||||
}
|
||||
impl SignMessage for TestSigner {
|
||||
fn sign(
|
||||
&mut self,
|
||||
message: crate::frame::Message,
|
||||
) -> Result<crate::frame::Message, crate::frame::SignMessageError> {
|
||||
self.call_num += 1;
|
||||
Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num))))
|
||||
}
|
||||
|
||||
fn sign_empty(
|
||||
&mut self,
|
||||
) -> Option<Result<crate::frame::Message, crate::frame::SignMessageError>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
let (mut signer, sender) = check_send_sync(DeferredSigner::new());
|
||||
|
||||
sender
|
||||
.send(Box::new(TestSigner::default()))
|
||||
.expect("success");
|
||||
|
||||
let message = signer.sign(Message::new(Bytes::new())).expect("success");
|
||||
assert_eq!(1, message.headers()[0].value().as_int32().unwrap());
|
||||
|
||||
let message = signer.sign(Message::new(Bytes::new())).expect("success");
|
||||
assert_eq!(2, message.headers()[0].value().as_int32().unwrap());
|
||||
|
||||
assert!(signer.sign_empty().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deferred_signer_defaults_to_noop_signer() {
|
||||
let (mut signer, _sender) = DeferredSigner::new();
|
||||
assert_eq!(
|
||||
Message::new(Bytes::new()),
|
||||
signer.sign(Message::new(Bytes::new())).unwrap()
|
||||
);
|
||||
assert!(signer.sign_empty().is_none());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue